
fine_tune_good.py

debug = 0

from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from utils import sample_prompt, main_prompt
from collections import defaultdict
import torchvision.transforms as transforms
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from segment_anything.utils.transforms import ResizeLongestSide
import albumentations as A
from albumentations.pytorch import ToTensorV2
from einops import rearrange
import random
from tqdm import tqdm
from time import sleep, time
from data import *
from PIL import Image
from sklearn.model_selection import KFold
from shutil import copyfile
from args import get_arguments

args = get_arguments()
def save_img(img, dir):
    # Normalize a tensor to [0, 255] and save it as an image file.
    img = img.clone().cpu().numpy() + 100
    if len(img.shape) == 3:
        img = rearrange(img, "c h w -> h w c")
        img_min = np.amin(img, axis=(0, 1), keepdims=True)
        img = img - img_min
        img_max = np.amax(img, axis=(0, 1), keepdims=True)
        img = (img / img_max * 255).astype(np.uint8)
        img = Image.fromarray(img)
    else:
        img_min = img.min()
        img = img - img_min
        img_max = img.max()
        if img_max != 0:
            img = img / img_max * 255
        # Cast to uint8 so PIL accepts the array regardless of input dtype.
        img = Image.fromarray(img.astype(np.uint8)).convert("L")
    img.save(dir)
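

# Combined segmentation loss. Tversky and focal-Tversky variants are defined,
# but forward() only combines Dice with a heavily weighted focal term:
# combined = 20 * focal + dice.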
class loss_fn(torch.nn.Module):
    def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
        super(loss_fn, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def tversky_loss(self, y_pred, y_true, alpha=0.8, beta=0.2, smooth=1e-2):
        y_pred = torch.sigmoid(y_pred)
        y_true_pos = torch.flatten(y_true)
        y_pred_pos = torch.flatten(y_pred)
        true_pos = torch.sum(y_true_pos * y_pred_pos)
        false_neg = torch.sum(y_true_pos * (1 - y_pred_pos))
        false_pos = torch.sum((1 - y_true_pos) * y_pred_pos)
        tversky_index = (true_pos + smooth) / (
            true_pos + alpha * false_neg + beta * false_pos + smooth
        )
        return 1 - tversky_index

    def focal_tversky(self, y_pred, y_true, gamma=0.75):
        pt_1 = self.tversky_loss(y_pred, y_true)
        return torch.pow((1 - pt_1), gamma)

    def dice_loss(self, logits, gt, eps=1):
        # Convert logits to probabilities and flatten the tensors
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        gt = gt.view(-1)
        # Compute the Dice coefficient
        intersection = (probs * gt).sum()
        dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
        # Dice loss is one minus the Dice coefficient
        loss = 1 - dice_coeff
        return loss

    def focal_loss(self, logits, gt, gamma=2):
        logits = logits.reshape(-1, 1)
        gt = gt.reshape(-1, 1)
        logits = torch.cat((1 - logits, logits), dim=1)
        probs = torch.sigmoid(logits)
        # Probability assigned to the true class of each pixel
        pt = probs.gather(1, gt.long())
        modulating_factor = (1 - pt) ** gamma
        focal_loss = -modulating_factor * torch.log(pt + 1e-12)
        # Mean focal loss over all pixels
        loss = focal_loss.mean()
        return loss

    def forward(self, logits, target):
        logits = logits.squeeze(1)
        target = target.squeeze(1)
        dice_loss = self.dice_loss(logits, target)
        tversky_loss = self.tversky_loss(logits, target)  # unused below
        focal_loss = self.focal_loss(logits, target.squeeze(-1))
        alpha = 20.0
        combined_loss = alpha * focal_loss + dice_loss
        return combined_loss
def img_enhance(img2, coef=0.2):
    # Window the intensities: clip below 0.7*mean and above mean + coef*(max - mean).
    img_mean = np.mean(img2)
    img_max = np.max(img2)
    val = (img_max - img_mean) * coef + img_mean
    img2[img2 < img_mean * 0.7] = img_mean * 0.7
    img2[img2 > val] = val
    return img2


def dice_coefficient(logits, gt):
    eps = 1
    # logits > 0 is equivalent to sigmoid(logits) > 0.5
    binary_mask = logits > 0
    intersection = (binary_mask * gt).sum(dim=(-2, -1))
    dice_scores = (2.0 * intersection + eps) / (
        binary_mask.sum(dim=(-2, -1)) + gt.sum(dim=(-2, -1)) + eps
    )
    return dice_scores.mean()
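

# Resize the ground-truth mask to the decoder's low-res output grid and score
# dice there, so prediction and label are compared at the same resolution.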
def what_the_f(low_res_masks, label):
    low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
    dice = dice_coefficient(low_res_masks, low_res_label)
    return dice
image_size = args.image_size
exp_id = 0
found = 0
if debug:
    user_input = "debug"
else:
    user_input = input("Related changes: ")
# Find the first free experiment id, create its folder, and snapshot this script.
while found == 0:
    try:
        os.makedirs(f"exps/{exp_id}-{user_input}")
        found = 1
    except FileExistsError:
        exp_id = exp_id + 1
copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")

layer_n = 4
L = layer_n
a = np.full(L, layer_n)
params = {"M": 255, "a": a, "p": 0.35}
device = "cuda:1"
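

# Two-stage SAM cascade:
#   Stage 1: a frozen, promptless SAM (loaded from args.promptprovider) predicts
#            a coarse low-res mask from the image embedding alone.
#   Stage 2: point prompts sampled from that coarse mask (via main_prompt and
#            sample_prompt from utils) drive a second SAM whose mask decoder
#            (mask_decoder2) is the only trainable module.
# A single ViT image encoder (from the second SAM) runs once per batch under
# no_grad; both decoders consume its embeddings.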
from segment_anything import SamPredictor, sam_model_registry


class panc_sam(nn.Module):
    # *init_args is used instead of *args to avoid shadowing the CLI `args` object.
    def __init__(self, *init_args, **kwargs) -> None:
        super().__init__(*init_args, **kwargs)
        # Promptless stage: fully frozen.
        sam = torch.load(args.promptprovider)
        self.prompt_encoder = sam.prompt_encoder
        self.mask_decoder = sam.mask_decoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        for param in self.mask_decoder.parameters():
            param.requires_grad = False
        # Prompted stage: only mask_decoder2 remains trainable.
        sam = sam_model_registry[args.model_type](args.checkpoint)
        self.image_encoder = sam.image_encoder
        self.prompt_encoder2 = sam.prompt_encoder
        self.mask_decoder2 = sam.mask_decoder
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        for param in self.prompt_encoder2.parameters():
            param.requires_grad = False

    def forward(self, input_images, box=None):  # box is accepted but currently unused
        with torch.no_grad():
            image_embeddings = self.image_encoder(input_images).detach()
        outputs_prompt = []
        outputs = []
        for curr_embedding in image_embeddings:
            # Stage 1: coarse mask from the frozen, promptless SAM.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                )
                low_res_masks, _ = self.mask_decoder(
                    image_embeddings=curr_embedding,
                    image_pe=self.prompt_encoder.get_dense_pe().detach(),
                    sparse_prompt_embeddings=sparse_embeddings.detach(),
                    dense_prompt_embeddings=dense_embeddings.detach(),
                    multimask_output=False,
                )
            outputs_prompt.append(low_res_masks)
            # Stage 2: sample point prompts from the coarse mask.
            points, point_labels = main_prompt(low_res_masks)
            points_sconed, point_labels_sconed = sample_prompt(low_res_masks)
            points = torch.cat([points, points_sconed], dim=1)
            point_labels = torch.cat([point_labels, point_labels_sconed], dim=1)
            points = points * 4  # 256x256 mask frame -> 1024x1024 input frame
            points = (points, point_labels)
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder2(
                    points=points,
                    boxes=None,
                    masks=None,
                )
            # The trainable decoder runs outside no_grad so gradients can flow.
            low_res_masks, _ = self.mask_decoder2(
                image_embeddings=curr_embedding,
                image_pe=self.prompt_encoder2.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append(low_res_masks)
        low_res_masks_prompt = torch.cat(outputs_prompt, dim=0)
        low_res_masks = torch.cat(outputs, dim=0)
        return low_res_masks, low_res_masks_prompt
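

# Training-time augmentations. Note that A.RandomScale runs after the fixed
# 1024x1024 RandomResizedCrop and can change the output resolution; this
# assumes the dataset (or its collate_fn) resizes samples back to image_size.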
augmentation = A.Compose(
    [
        A.Rotate(limit=30, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
        A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
        A.HorizontalFlip(p=0.5),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
        A.CoarseDropout(
            max_holes=8,
            max_height=16,
            max_width=16,
            min_height=8,
            min_width=8,
            fill_value=0,
            p=0.5,
        ),
        A.RandomScale(scale_limit=0.3, p=0.5),
    ]
)
panc_sam_instance = panc_sam()
panc_sam_instance.to(device)
panc_sam_instance.train()

train_dataset = PanDataset(
    [args.dir_train],
    [args.dir_labels],
    [["NIH_PNG", 1]],
    args.image_size,
    slice_per_image=args.slice_per_image,
    train=True,
    augmentation=augmentation,
)
val_dataset = PanDataset(
    [args.dir_train],
    [args.dir_labels],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=args.slice_per_image,
    train=False,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=train_dataset.collate_fn,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    collate_fn=val_dataset.collate_fn,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,
)
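
# Only the prompted stage's mask decoder (mask_decoder2) is optimized;
# every other module stays frozen.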
lr = 1e-4
max_lr = 3e-4
wd = 5e-4
optimizer_main = torch.optim.Adam(
    list(panc_sam_instance.mask_decoder2.parameters()),
    lr=lr,
    weight_decay=wd,
)
scheduler_main = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_main,
    max_lr=max_lr,
    epochs=args.num_epochs,
    steps_per_epoch=args.sample_size // (args.accumulative_batch_size // args.batch_size),
)
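# The scheduler is stepped once per optimizer update (inside the gradient
# accumulation branch of process_model), so steps_per_epoch counts updates,
# not batches. With hypothetical values sample_size=400,
# accumulative_batch_size=16, batch_size=4: 400 // (16 // 4) = 100 steps/epoch.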
#####################################################
from statistics import mean
from torch.nn.functional import threshold, normalize
import time as s_time

loss_function = loss_fn(alpha=0.5, gamma=4.0)
loss_function.to(device)

log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
def process_model(main_model, data_loader, train=0, save_output=0):
    epoch_losses = []
    index = 0
    results = torch.zeros((2, 0, 256, 256))
    total_dice = 0.0
    num_samples = 0
    total_dice_main = 0.0
    counterb = 0
    for image, label, raw_data in tqdm(data_loader, total=args.sample_size):
        num_samples += 1
        counterb += 1
        index += 1
        image = image.to(device)
        label = label.to(device).float()
        # Fixed box prompt (passed through, but unused inside the model's forward).
        box = torch.tensor([[200, 200, 750, 800]]).to(device)
        low_res_masks_main, low_res_masks_prompt = main_model(image, box)
        low_res_label = F.interpolate(label, low_res_masks_main.shape[-2:])
        dice_prompt = what_the_f(low_res_masks_prompt, low_res_label)
        dice_main = what_the_f(low_res_masks_main, low_res_label)
        binary_mask = normalize(threshold(low_res_masks_main, 0.0, 0))
        total_dice += dice_prompt
        total_dice_main += dice_main
        average_dice = total_dice / num_samples
        average_dice_main = total_dice_main / num_samples
        log_file.write(str(average_dice) + "\n")
        log_file.flush()
        loss = loss_function.forward(low_res_masks_main, low_res_label)
        # Scale the loss so accumulated gradients average over the full batch.
        loss /= args.accumulative_batch_size / args.batch_size
        if train:
            loss.backward()
            if index % (args.accumulative_batch_size / args.batch_size) == 0:
                optimizer_main.step()
                scheduler_main.step()
                optimizer_main.zero_grad()
                index = 0
        else:
            result = torch.cat(
                (
                    low_res_masks_main[0].detach().cpu().reshape(1, 1, 256, 256),
                    binary_mask[0].detach().cpu().reshape(1, 1, 256, 256),
                ),
                dim=0,
            )
            results = torch.cat((results, result), dim=1)
        if index % (args.accumulative_batch_size / args.batch_size) == 0:
            epoch_losses.append(loss.item())
        if counterb == args.sample_size and train:
            break
        elif counterb == args.sample_size // 2 and not train:
            break
    return epoch_losses, results, average_dice, average_dice_main
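

# Training driver: trains each epoch, validates only once training dice exceeds
# 0.5, saves per-epoch probability/prediction images, checkpoints the model when
# the prompted stage's validation dice improves, and plots mean epoch losses.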
def train_model(train_loader, val_loader, K_fold=False, N_fold=7, epoch_num_start=7):
    print("Train model started.")
    train_losses = []
    train_epochs = []
    val_losses = []
    val_epochs = []
    dice = []
    dice_main = []
    dice_val = []
    dice_val_main = []
    results = []
    # Save the first 100 validation images and ground truths for visual inspection.
    index = 0
    if debug == 0:
        for image, label, raw_data in tqdm(val_loader):
            if index < 100:
                if not os.path.exists(f"ims/batch_{index}"):
                    os.mkdir(f"ims/batch_{index}")
                save_img(
                    image[0],
                    f"ims/batch_{index}/img_0.png",
                )
                save_img(0.2 * image[0][0] + label[0][0], f"ims/batch_{index}/gt_0.png")
            index += 1
            if index == 100:
                break
    # Each epoch trains the model, then validates it once training dice passes 0.5.
    last_best_dice = 0
    if not K_fold:
        for epoch in range(args.num_epochs):
            print(f"=====================EPOCH: {epoch + 1}=====================")
            log_file.write(
                f"=====================EPOCH: {epoch + 1}===================\n"
            )
            print("Training:")
            train_epoch_losses, epoch_results, average_dice, average_dice_main = process_model(
                panc_sam_instance, train_loader, train=1
            )
            dice.append(average_dice)
            dice_main.append(average_dice_main)
            train_losses.append(train_epoch_losses)
            if average_dice > 0.5:
                print("validating:")
                val_epoch_losses, epoch_results, average_dice_val, average_dice_val_main = process_model(
                    panc_sam_instance, val_loader
                )
                val_losses.append(val_epoch_losses)
                for i in tqdm(range(len(epoch_results[0]))):
                    if not os.path.exists(f"ims/batch_{i}"):
                        os.mkdir(f"ims/batch_{i}")
                    save_img(
                        epoch_results[0, i].clone(),
                        f"ims/batch_{i}/prob_epoch_{epoch}.png",
                    )
                    save_img(
                        epoch_results[1, i].clone(),
                        f"ims/batch_{i}/pred_epoch_{epoch}.png",
                    )
            train_mean_losses = [mean(x) for x in train_losses]
            val_mean_losses = [mean(x) for x in val_losses]
            np.save("train_losses.npy", train_mean_losses)
            np.save("val_losses.npy", val_mean_losses)
            print(f"Train Dice: {average_dice}")
            print(f"Train Dice main: {average_dice_main}")
            print(f"Mean train loss: {mean(train_epoch_losses)}")
            try:
                dice_val.append(average_dice_val)
                dice_val_main.append(average_dice_val_main)
                print(f"val Dice: {average_dice_val}")
                print(f"val Dice main: {average_dice_val_main}")
                print(f"Mean val loss: {mean(val_epoch_losses)}")
                results.append(epoch_results)
                val_epochs.append(epoch)
                train_epochs.append(epoch)
                plt.plot(val_epochs, val_mean_losses, train_epochs, train_mean_losses)
                # Checkpoint when the prompted stage's validation dice improves.
                if average_dice_val_main > last_best_dice:
                    torch.save(
                        panc_sam_instance,
                        f"exps/{exp_id}-{user_input}/sam_tuned_save.pth",
                    )
                    last_best_dice = average_dice_val_main
                del epoch_results
                # Deleted so a stale value is not reused in epochs without validation.
                del average_dice_val
            except NameError:
                # Validation did not run this epoch (training dice still below 0.5).
                train_epochs.append(epoch)
                plt.plot(train_epochs, train_mean_losses)
            print(f"=================End of EPOCH: {epoch}==================\n")
            plt.yscale("log")
            plt.title("Mean epoch loss")
            plt.xlabel("Epoch Number")
            plt.ylabel("Loss")
            plt.savefig("result")
    # The k-fold cross-validation branch is not implemented.
    return train_losses, val_losses, results


train_losses, val_losses, results = train_model(train_loader, val_loader)
log_file.close()