# fine_tune_good_unet.py
# Fine-tunes a 3-D convolutional refinement head (Conv3DFilter) on top of a
# frozen, pre-trained SAM-based pancreas segmentation model (panc_sam).
  1. from pathlib import Path
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. # import cv2
  5. from collections import defaultdict
  6. import torchvision.transforms as transforms
  7. import torch
  8. from torch import nn
  9. import torch.nn.functional as F
  10. from segment_anything.utils.transforms import ResizeLongestSide
  11. import albumentations as A
  12. from albumentations.pytorch import ToTensorV2
  13. import numpy as np
  14. import random
  15. from tqdm import tqdm
  16. from time import sleep
  17. from data_load_group import *
  18. from time import time
  19. from PIL import Image
  20. from sklearn.model_selection import KFold
  21. from shutil import copyfile
  22. # import monai
  23. from utils import *
  24. from torch.autograd import Variable
  25. from Models.model_handler import panc_sam
  26. from Models.model_handler import UNet3D
  27. from Models.model_handler import Conv3DFilter
  28. from loss import *
  29. from args import get_arguments
  30. from segment_anything import SamPredictor, sam_model_registry
  31. from statistics import mean
  32. from copy import deepcopy
  33. from torch.nn.functional import threshold, normalize
  34. def calculate_recall(pred, target):
  35. smooth = 1
  36. batch_size = 1
  37. recall_scores = []
  38. binary_mask = pred>0
  39. true_positive = ((pred == 1) & (target == 1)).sum().item()
  40. false_negative = ((pred == 0) & (target == 1)).sum().item()
  41. recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
  42. return recall
  43. def calculate_precision(pred, target):
  44. smooth = 1
  45. batch_size = 1
  46. recall_scores = []
  47. binary_mask = pred>0
  48. true_positive = ((pred == 1) & (target == 1)).sum().item()
  49. false_negative = ((pred == 1) & (target == 0)).sum().item()
  50. recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
  51. return recall
  52. def calculate_jaccard(pred, target):
  53. smooth = 1
  54. batch_size = pred.shape[0]
  55. jaccard_scores = []
  56. binary_mask = pred>0
  57. true_positive = ((pred == 1) & (target == 1)).sum().item()
  58. false_negative = ((pred == 0) & (target == 1)).sum().item()
  59. true_positive = ((pred == 1) & (target == 1)).sum().item()
  60. false_positive = ((pred == 1) & (target == 0)).sum().item()
  61. jaccard = (true_positive + smooth) / (true_positive + false_positive + false_negative + smooth)
  62. # jaccard_scores.append(jaccard)
  63. # for i in range(batch_size):
  64. # true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  65. # false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  66. # false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
  67. return jaccard
  68. def save_img(img, dir):
  69. img = img.clone().cpu().numpy() + 100
  70. if len(img.shape) == 3:
  71. img = rearrange(img, "c h w -> h w c")
  72. img_min = np.amin(img, axis=(0, 1), keepdims=True)
  73. img = img - img_min
  74. img_max = np.amax(img, axis=(0, 1), keepdims=True)
  75. img = (img / img_max * 255).astype(np.uint8)
  76. # grey_img = Image.fromarray(img[:, :, 0])
  77. img = Image.fromarray(img)
  78. else:
  79. img_min = img.min()
  80. img = img - img_min
  81. img_max = img.max()
  82. if img_max != 0:
  83. img = img / img_max * 255
  84. img = Image.fromarray(img).convert("L")
  85. img.save(dir)
  86. def seed_everything(seed: int):
  87. import random, os
  88. import numpy as np
  89. import torch
  90. random.seed(seed)
  91. os.environ['PYTHONHASHSEED'] = str(seed)
  92. np.random.seed(seed)
  93. torch.manual_seed(seed)
  94. torch.cuda.manual_seed(seed)
  95. torch.backends.cudnn.deterministic = True
  96. torch.backends.cudnn.benchmark = True
  97. seed_everything(2024)
  98. global optimizer
  99. args = get_arguments()
  100. exp_id = 0
  101. found = 0
  102. user_input = args.run_name
  103. while found == 0:
  104. try:
  105. os.makedirs(f"exps/{exp_id}-{user_input}")
  106. found = 1
  107. except:
  108. exp_id = exp_id + 1
  109. copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")
# Geometric augmentations; also reused inside process_model on the low-res
# SAM outputs and labels (applied jointly so mask/label stay aligned).
augmentation = A.Compose(
    [
        A.Rotate(limit=100, p=0.7),
        A.RandomScale(scale_limit=0.3, p=0.5),
    ]
)

device = "cuda:0"

# Load the pre-trained SAM-based segmentation model as a whole module object
# (torch.load of a full model requires its class to be importable here).
panc_sam_instance = torch.load(args.model_path)
panc_sam_instance.to(device)

# NOTE(review): this UNet3D instance is immediately overwritten by the
# Conv3DFilter below and appears to be dead code — confirm before removing.
conv3d_instance = UNet3D()

# 3-D convolutional refinement head — the only module trained by this script.
kernel_size = [(1, 5, 5), (5, 5, 5), (5, 5, 5), (5, 5, 5)]
conv3d_instance = Conv3DFilter(
    1,
    5,
    kernel_size,
    np.array(kernel_size) // 2,  # half-kernel padding per dimension ("same"-style)
    custom_bias=args.custom_bias,
)
conv3d_instance.to(device)
conv3d_instance.train()
# ---------------- datasets and loaders ----------------
# Training split of the NIH_PNG dataset; augmentation is passed to the dataset.
train_dataset = PanDataset(
    dirs=[f"{args.train_dir}/train"],
    datasets=[["NIH_PNG", 1]],
    target_image_size=args.image_size,
    slice_per_image=args.slice_per_image,
    train=True,
    val=False, # Enable validation data splitting
    augmentation=augmentation,
)
# Validation split — same directory as training; val=True presumably selects
# the held-out portion inside PanDataset (verify against data_load_group).
val_dataset = PanDataset(
    [f"{args.train_dir}/train"],
    [["NIH_PNG", 1]],
    args.image_size,
    slice_per_image=args.slice_per_image,
    val=True
)
# Test split from the separate test directory; no augmentation.
test_dataset = PanDataset(
    [f"{args.test_dir}/test"],
    [["NIH_PNG", 1]],
    args.image_size,
    slice_per_image=args.slice_per_image,
    train=False,
)
# Only the training loader shuffles; all three use the dataset's collate_fn.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=train_dataset.collate_fn,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    collate_fn=val_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    collate_fn=test_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)
# Set up the optimizer, hyperparameter tuning will improve performance here
lr = args.lr
max_lr = 3e-4  # NOTE(review): unused — possibly left over from a cyclic-LR setup
wd = 5e-4

# Only the conv3d refinement head is optimized; the SAM model stays frozen
# (its outputs are detached in process_model).
all_parameters = list(conv3d_instance.parameters())
optimizer = torch.optim.Adam(
    all_parameters,
    lr=lr,
    weight_decay=wd,
)
# Decay LR by 0.7 every 10 epochs (stepped once per epoch in train_model).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=10, gamma=0.7, verbose=True
)

# Loss from loss.py; alpha/gamma suggest a focal-style loss — verify there.
loss_function = loss_fn(alpha=0.5, gamma=2.0)
loss_function.to(device)

from time import time  # redundant re-import; `time` was already imported at the top
import time as s_time  # module alias (used only by a commented-out sleep below)

# Append-mode experiment log shared by the training/validation code below.
log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
  195. def process_model(data_loader, train=0, save_output=0, epoch=None, scheduler=None):
  196. epoch_losses = []
  197. index = 0
  198. results = []
  199. dice_sam_lists = []
  200. dice_sam_prompt_lists = []
  201. dice_lists = []
  202. dice_prompt_lists = []
  203. num_samples = 0
  204. counterb = 0
  205. for batch in tqdm(data_loader, total=args.sample_size // args.batch_size):
  206. for batched_input in batch:
  207. num_samples += 1
  208. # raise ValueError(len(batched_input))
  209. low_res_masks = torch.zeros((1, 1, 0, 256, 256))
  210. # s_time.sleep(0.6)
  211. counterb += len(batch)
  212. index += 1
  213. label = []
  214. label = [i["label"] for i in batched_input]
  215. # Only correct if gray scale
  216. label = torch.cat(label, dim=1)
  217. # raise ValueError(la)
  218. label = label.float()
  219. true_indexes = torch.where((torch.amax(label, dim=(2, 3)) > 0).view(-1))[0]
  220. low_res_label = F.interpolate(label, low_res_masks.shape[-2:]).to("cuda:0")
  221. low_res_masks, low_res_masks_promtp = panc_sam_instance(
  222. batched_input, device
  223. )
  224. low_res_shape = low_res_masks.shape[-2:]
  225. low_res_label_prompt=low_res_label
  226. if train:
  227. transformed = augmentation(
  228. image=low_res_masks_promtp[0].permute(1, 2, 0).cpu().numpy(),
  229. mask=low_res_label[0].permute(1, 2, 0).cpu().numpy(),
  230. )
  231. low_res_masks_promtp = (
  232. torch.tensor(transformed["image"])
  233. .permute(2, 0, 1)
  234. .unsqueeze(0)
  235. .to(device)
  236. )
  237. low_res_label_prompt = (
  238. torch.tensor(transformed["mask"])
  239. .permute(2, 0, 1)
  240. .unsqueeze(0)
  241. .to(device)
  242. )
  243. transformed = augmentation(
  244. image=low_res_masks[0].permute(1, 2, 0).cpu().numpy(),
  245. mask=low_res_label[0].permute(1, 2, 0).cpu().numpy(),
  246. )
  247. low_res_masks = (
  248. torch.tensor(transformed["image"])
  249. .permute(2, 0, 1)
  250. .unsqueeze(0)
  251. .to(device)
  252. )
  253. low_res_label = (
  254. torch.tensor(transformed["mask"])
  255. .permute(2, 0, 1)
  256. .unsqueeze(0)
  257. .to(device)
  258. )
  259. low_res_masks = F.interpolate(low_res_masks, low_res_shape).to(device)
  260. low_res_label = F.interpolate(low_res_label, low_res_shape).to(device)
  261. low_res_masks = F.interpolate(low_res_masks, low_res_shape).to(device)
  262. low_res_label = F.interpolate(low_res_label, low_res_shape).to(device)
  263. low_res_masks_promtp = F.interpolate(
  264. low_res_masks_promtp, low_res_shape
  265. ).to(device)
  266. low_res_label_prompt = F.interpolate(
  267. low_res_label_prompt, low_res_shape
  268. ).to(device)
  269. low_res_masks = low_res_masks.detach()
  270. low_res_masks_promtp = low_res_masks_promtp.detach()
  271. dice_sam = dice_coefficient(low_res_masks , low_res_label).detach().cpu()
  272. dice_sam_prompt = (
  273. dice_coefficient(low_res_masks_promtp, low_res_label_prompt)
  274. .detach()
  275. .cpu()
  276. )
  277. low_res_masks_promtp = conv3d_instance(
  278. low_res_masks_promtp.detach().to(device)
  279. )
  280. loss = loss_function(low_res_masks_promtp, low_res_masks_promtp)
  281. loss /= args.accumulative_batch_size / args.batch_size
  282. binary_mask = low_res_masks > 0.5
  283. binary_mask_prompt = low_res_masks_promtp > 0.5
  284. dice = dice_coefficient(binary_mask, low_res_label).detach().cpu()
  285. dice_prompt = (
  286. dice_coefficient(binary_mask_prompt, low_res_label_prompt)
  287. .detach()
  288. .cpu()
  289. )
  290. dice_sam_lists.append(dice_sam)
  291. dice_sam_prompt_lists.append(dice_sam_prompt)
  292. dice_lists.append(dice)
  293. dice_prompt_lists.append(dice_prompt)
  294. log_file.flush()
  295. if train:
  296. loss.backward()
  297. if index % (args.accumulative_batch_size / args.batch_size) == 0:
  298. optimizer.step()
  299. # if epoch==40:
  300. # scheduler.step()
  301. optimizer.zero_grad()
  302. index = 0
  303. else:
  304. result = torch.cat(
  305. (
  306. low_res_masks[:, ::10].detach().cpu().reshape(1, -1, 256, 256),
  307. binary_mask[:, ::10].detach().cpu().reshape(1, -1, 256, 256),
  308. ),
  309. dim=0,
  310. )
  311. results.append(result)
  312. if index % (args.accumulative_batch_size / args.batch_size) == 0:
  313. epoch_losses.append(loss.item())
  314. if counterb == (args.sample_size // args.batch_size) and train:
  315. break
  316. return (
  317. epoch_losses,
  318. results,
  319. dice_lists,
  320. dice_prompt_lists,
  321. dice_sam_lists,
  322. dice_sam_prompt_lists,
  323. )
def train_model(train_loader, val_loader, test_loader, K_fold=False, N_fold=7, epoch_num_start=7):
    """Either run inference on the test set (args.inference) or train the
    conv3d head with per-epoch validation and best-model checkpointing.

    K_fold, N_fold and epoch_num_start are currently unused.
    """
    global optimizer
    index = 0
    if args.inference:
        # ---- inference-only path: load a saved conv head, dump overlay
        # images and recall/precision/jaccard metrics ----
        with torch.no_grad():
            conv = torch.load(f'{args.conv_path}')
            recall_list = []
            percision_list = []
            jaccard_list = []
            for input in tqdm(test_loader):
                # SAM masks (plain + prompted), upsampled to 512 for overlays.
                low_res_masks_sam, low_res_masks_promtp_sam = panc_sam_instance(
                    input[0], device
                )
                low_res_masks_sam = F.interpolate(low_res_masks_sam, 512).cpu()
                low_res_masks_promtp_sam = F.interpolate(low_res_masks_promtp_sam, 512).cpu()
                # Refined prediction from the loaded conv head.
                low_res_masks_promtp = conv(low_res_masks_promtp_sam.to(device)).detach().cpu()
                for slice_id, (batched_input, mask_sam, mask_prompt_sam, mask_prompt) in enumerate(zip(input[0], low_res_masks_sam[0], low_res_masks_promtp_sam[0], low_res_masks_promtp[0])):
                    if not os.path.exists(f"ims/batch_{index}"):
                        os.mkdir(f"ims/batch_{index}")
                    image = batched_input["image"]
                    # ::2 downsampling — presumably to match the 512 grid; verify.
                    label = batched_input["label"][0, 0, ::2, ::2].to(bool)
                    # Raw SAM outputs thresholded at 0; refined output at 0.5.
                    binary_mask_sam = (mask_sam > 0)
                    binary_mask_prompt_sam = (mask_prompt_sam > 0)
                    binary_mask_prompt = (mask_prompt > 0.5)
                    # NOTE(review): calculate_* take (pred, target) but are
                    # called here as (label, prediction) — the arguments look
                    # swapped (recall/precision would be exchanged); confirm.
                    recall = calculate_recall(label, binary_mask_prompt)
                    percision = calculate_precision(label, binary_mask_prompt)
                    jaccard = calculate_jaccard(label, binary_mask_prompt)
                    percision_list.append(percision)
                    recall_list.append(recall)
                    jaccard_list.append(jaccard)
                    # Overlay channels: (prediction-burned image, label-burned
                    # image, raw image) stacked and saved as one picture.
                    image_mask = image.clone().to(torch.long)
                    image_label = image.clone().to(torch.long)
                    image_mask[binary_mask_sam] = 255
                    image_label[label] = 255
                    save_img(
                        torch.stack((image_mask, image_label, image), dim=0),
                        f"ims/batch_{index}/sam{slice_id}.png",
                    )
                    image_mask = image.clone().to(torch.long)
                    image_mask[binary_mask_prompt_sam] = 255
                    save_img(
                        torch.stack((image_mask, image_label, image), dim=0),
                        f"ims/batch_{index}/sam_prompt{slice_id}.png",
                    )
                    image_mask = image.clone().to(torch.long)
                    image_mask[binary_mask_prompt] = 255
                    save_img(
                        torch.stack((image_mask, image_label, image), dim=0),
                        f"ims/batch_{index}/prompt_{slice_id}.png",
                    )
                # Running means after each batch, then final means below.
                print(f'Recall={np.mean(recall_list)}')
                print(f'Percision={np.mean(percision_list)}')
                print(f'Jaccard={np.mean(jaccard_list)}')
                index += 1
            print(f'Recall={np.mean(recall_list)}')
            print(f'Percision={np.mean(percision_list)}')
            print(f'Jaccard={np.mean(jaccard_list)}')
    else:
        # ---- training path ----
        print("Train model started.")
        train_losses = []  # NOTE(review): never appended to (see below)
        train_epochs = []
        val_losses = []
        val_epochs = []
        dice = []
        dice_val = []
        results = []
        last_best_dice = 0
        for epoch in range(args.num_epochs):
            print(f"=====================EPOCH: {epoch + 1}=====================")
            log_file.write(f"=====================EPOCH: {epoch + 1}===================\n")
            print("Training:")
            (
                train_epoch_losses,
                results,
                dice_list,
                dice_prompt_list,
                dice_sam_list,
                dice_sam_prompt_list,
            ) = process_model(train_loader, train=1, epoch=epoch, scheduler=scheduler)
            dice_mean = np.mean(dice_list)
            dice_prompt_mean = np.mean(dice_prompt_list)
            dice_sam_mean = np.mean(dice_sam_list)
            dice_sam_prompt_mean = np.mean(dice_sam_prompt_list)
            print("Validating:")
            (
                _,
                _,
                val_dice_list,
                val_dice_prompt_list,
                val_dice_sam_list,
                val_dice_sam_prompt_list,
            ) = process_model(val_loader)
            val_dice_mean = np.mean(val_dice_list)
            val_dice_prompt_mean = np.mean(val_dice_prompt_list)
            val_dice_sam_mean = np.mean(val_dice_sam_list)
            val_dice_sam_prompt_mean = np.mean(val_dice_sam_prompt_list)
            # NOTE(review): train_losses is always empty, so this is always [].
            train_mean_losses = [mean(x) for x in train_losses]
            logs = ""
            logs += f"Train Dice_sam: {dice_sam_mean}\n"
            logs += f"Train Dice: {dice_mean}\n"
            logs += f"Train Dice_sam_prompt: {dice_sam_prompt_mean}\n"
            logs += f"Train Dice_prompt: {dice_prompt_mean}\n"
            logs += f"Mean train loss: {mean(train_epoch_losses)}\n"
            logs += f"val Dice_sam: {val_dice_sam_mean}\n"
            logs += f"val Dice: {val_dice_mean}\n"
            logs += f"val Dice_sam_prompt: {val_dice_sam_prompt_mean}\n"
            logs += f"val Dice_prompt: {val_dice_prompt_mean}\n"
            # Save a checkpoint whenever validation prompt-dice improves.
            if val_dice_prompt_mean > last_best_dice:
                torch.save(
                    conv3d_instance,
                    f"exps/{exp_id}-{user_input}/conv_save.pth",
                )
                print("Model saved")
                last_best_dice = val_dice_prompt_mean
            print(logs)
            log_file.write(logs)
            scheduler.step()
## training with k-fold cross validation:
# NOTE(review): despite the comment above, train_model's K_fold argument is
# unused — no cross-validation actually happens.
fff = time()  # start timestamp; its value is never read afterwards
train_model(train_loader, val_loader, test_loader)
log_file.close()
# train and also test the model