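"""SAM_without_prompt.py

Fine-tune SAM's mask decoder for pancreas segmentation without prompts:
the prompt encoder is frozen and called with empty prompts, the image
encoder runs under no_grad, and only the mask decoder is optimised with
a combined focal + dice loss, gradient accumulation, a OneCycleLR
schedule, and checkpointing on the best validation Dice score.
"""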
import os
from shutil import copyfile
from statistics import mean

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import albumentations as A
from einops import rearrange
from tqdm import tqdm
from PIL import Image
from segment_anything import sam_model_registry

from data import *  # provides PanDataset and its collate helpers
from args import get_arguments

args = get_arguments()
def save_img(img, path):
    """Normalise a tensor to [0, 255] and save it as a PNG."""
    img = img.clone().cpu().numpy()
    if len(img.shape) == 3:
        img = rearrange(img, "c h w -> h w c")
        img_min = np.amin(img, axis=(0, 1), keepdims=True)
        img = img - img_min
        img_max = np.amax(img, axis=(0, 1), keepdims=True)
        img = (img / np.maximum(img_max, 1e-12) * 255).astype(np.uint8)
        img = Image.fromarray(img)
    else:
        img_min = img.min()
        img = img - img_min
        img_max = img.max()
        if img_max != 0:
            img = img / img_max * 255
        img = Image.fromarray(img.astype(np.uint8)).convert("L")
    img.save(path)
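# Example (illustrative): dump the first image of a batch for inspection.
#   save_img(image[0], "exps/debug_input.png")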
class loss_fn(torch.nn.Module):
    """Combined focal + dice loss for binary segmentation logits."""

    def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
        super(loss_fn, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def dice_loss(self, logits, gt, eps=1):
        # Convert logits to probabilities and flatten the tensors.
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        gt = gt.view(-1)
        # Compute the Dice coefficient, then the Dice loss.
        intersection = (probs * gt).sum()
        dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
        return 1 - dice_coeff

    def focal_loss(self, logits, gt):
        # Convert logits to probabilities first, then stack the
        # background/foreground columns so gather() picks p_t for
        # each pixel's true class.
        probs = torch.sigmoid(logits.reshape(-1, 1))
        probs = torch.cat((1 - probs, probs), dim=1)
        gt = gt.reshape(-1, 1)
        pt = probs.gather(1, gt.long())
        modulating_factor = (1 - pt) ** self.gamma
        focal_loss = -modulating_factor * torch.log(pt + 1e-12)
        return focal_loss.mean()

    def forward(self, logits, target):
        logits = logits.squeeze(1)
        target = target.squeeze(1)
        dice_loss = self.dice_loss(logits, target)
        focal_loss = self.focal_loss(logits, target)
        # Note: this fixed weight overrides the alpha passed to __init__.
        alpha = 20.0
        combined_loss = alpha * focal_loss + dice_loss
        return combined_loss
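# A minimal usage sketch of loss_fn (illustrative; shapes match the
# (B, 1, 256, 256) low-res masks this script produces):
#   criterion = loss_fn(alpha=0.5, gamma=2.0)
#   logits = torch.randn(2, 1, 256, 256)
#   gt = (torch.rand(2, 1, 256, 256) > 0.5).float()
#   loss = criterion(logits, gt)  # scalar tensor: 20 * focal + dice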
def img_enhance(img2, coef=0.2):
    # Clip intensities into [0.7 * mean, mean + coef * (max - mean)]
    # to suppress outliers and boost soft-tissue contrast.
    img_mean = np.mean(img2)
    img_max = np.max(img2)
    val = (img_max - img_mean) * coef + img_mean
    img2[img2 < img_mean * 0.7] = img_mean * 0.7
    img2[img2 > val] = val
    return img2
def dice_coefficient(logits, gt):
    # Threshold logits at 0 (i.e. probability 0.5) to get a binary mask.
    eps = 1
    binary_mask = logits > 0
    intersection = (binary_mask * gt).sum(dim=(-2, -1))
    dice_scores = (2.0 * intersection + eps) / (
        binary_mask.sum(dim=(-2, -1)) + gt.sum(dim=(-2, -1)) + eps
    )
    return dice_scores.mean()


def low_res_dice(low_res_masks, label):
    # Downsample the label to the predicted mask resolution before scoring.
    low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
    return dice_coefficient(low_res_masks, low_res_label)
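# Example (illustrative): a perfect all-foreground prediction scores ~1.
#   masks = torch.full((1, 1, 256, 256), 5.0)  # strongly positive logits
#   label = torch.ones((1, 1, 1024, 1024))
#   low_res_dice(masks, label)  # -> ~1.0, up to the eps smoothing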
accumulative_batch_size = 8
batch_size = 1
num_workers = 4
slice_per_image = 1
num_epochs = 80
sample_size = 2000
image_size = 1024
exp_id = 0
found = 0
debug = 0
if debug:
    user_input = "debug"
else:
    user_input = input("Related changes: ")

# Find the first unused experiment id and create its directory.
while found == 0:
    try:
        os.makedirs(f"exps/{exp_id}-{user_input}/")
        found = 1
    except FileExistsError:
        exp_id = exp_id + 1

# Snapshot this script into the experiment directory for reproducibility.
copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")
device = "cuda:1"
# //////////////////
class panc_sam(nn.Module):
    """SAM wrapper that runs the mask decoder with empty (no) prompts."""

    def __init__(self) -> None:
        super().__init__()
        self.sam = sam_model_registry[args.model_type](args.checkpoint)
        # self.sam = torch.load('exps/sam_tuned_save.pth').sam
        self.prompt_encoder = self.sam.prompt_encoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False

    def forward(self, image, box):
        # box is accepted for interface compatibility but unused: the
        # prompt encoder is always called with empty prompts.
        with torch.no_grad():
            image_embedding = self.sam.image_encoder(image).detach()
        outputs_prompt = []
        for curr_embedding in image_embedding:
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.sam.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=self.sam.prompt_encoder.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs_prompt.append(low_res_masks)
        low_res_masks_prompt = torch.cat(outputs_prompt, dim=0)
        return low_res_masks_prompt
# ///////////////
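# Usage sketch (illustrative): images in, low-res mask logits out.
#   model = panc_sam().to(device)
#   masks = model(images.to(device), box=None)  # (B, 1, 256, 256)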
augmentation = A.Compose(
    [
        A.Rotate(limit=90, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
        A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
        A.HorizontalFlip(p=0.5),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
        A.CoarseDropout(
            max_holes=8, max_height=16, max_width=16,
            min_height=8, min_width=8, fill_value=0, p=0.5,
        ),
        A.RandomScale(scale_limit=0.3, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
        A.GridDistortion(p=0.5),
    ]
)
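# Note: RandomResizedCrop keeps the 1024x1024 input size SAM expects, but
# RandomScale can change it afterwards; the dataset's own resizing is
# assumed to restore 1024x1024 before batching.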
panc_sam_instance = panc_sam()
panc_sam_instance.to(device)
panc_sam_instance.train()

train_dataset = PanDataset(
    [args.train_dir],
    [args.train_labels_dir],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=slice_per_image,
    train=True,
    augmentation=augmentation,
)
val_dataset = PanDataset(
    [args.val_dir],
    [args.val_labels_dir],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=slice_per_image,
    train=False,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=train_dataset.collate_fn,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=val_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)
# Set up the optimizer; hyperparameter tuning may improve performance here.
# OneCycleLR manages the learning rate per step, so Adam's lr is superseded.
lr = 1e-4
max_lr = 5e-5
wd = 5e-4
optimizer = torch.optim.Adam(
    list(panc_sam_instance.sam.mask_decoder.parameters()),
    lr=lr,
    weight_decay=wd,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    epochs=num_epochs,
    steps_per_epoch=sample_size // (accumulative_batch_size // batch_size),
)
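# With batch_size = 1 and accumulative_batch_size = 8, OneCycleLR takes
# sample_size // 8 = 250 steps per epoch, one per accumulated optimizer step.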
loss_function = loss_fn(alpha=0.5, gamma=2.0)
loss_function.to(device)

log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
def process_model(data_loader, train=0, save_output=0):
    epoch_losses = []
    index = 0
    # Placeholder for per-batch outputs; not populated in this version.
    results = torch.zeros((2, 0, 256, 256))
    total_dice = 0.0
    num_samples = 0
    counterb = 0
    for image, label in tqdm(data_loader, total=sample_size):
        counterb += 1
        num_samples += 1
        index += 1
        image = image.to(device)
        label = label.to(device).float()
        box = torch.tensor([[200, 200, 750, 800]]).to(device)
        low_res_masks = panc_sam_instance(image, box)
        low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
        dice = low_res_dice(low_res_masks, label)
        total_dice += dice.item()
        average_dice = total_dice / num_samples
        log_file.write(str(average_dice) + "\n")
        log_file.flush()
        loss = loss_function(low_res_masks, low_res_label)
        # Scale the loss so accumulated gradients average over the
        # effective batch.
        loss /= accumulative_batch_size // batch_size
        if train:
            loss.backward()
            if index % (accumulative_batch_size // batch_size) == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                index = 0
        if index % (accumulative_batch_size // batch_size) == 0:
            epoch_losses.append(loss.item())
        # Cap each epoch: sample_size training batches, a tenth for validation.
        if counterb == sample_size and train:
            break
        elif counterb == sample_size // 10 and not train:
            break
    return epoch_losses, results, average_dice
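# Usage sketch (illustrative): one training pass, then a shorter validation pass.
#   train_epoch_losses, _, train_dice = process_model(train_loader, train=1)
#   val_epoch_losses, _, val_dice = process_model(val_loader)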
def train_model(train_loader, val_loader, K_fold=False, N_fold=7, epoch_num_start=7):
    print("Train model started.")
    train_losses = []
    train_epochs = []
    val_losses = []
    val_epochs = []
    dice = []
    dice_val = []
    results = []
    last_best_dice = 0
    for epoch in range(num_epochs):
        print(f"=====================EPOCH: {epoch + 1}=====================")
        log_file.write(
            f"=====================EPOCH: {epoch + 1}===================\n"
        )
        print("Training:")
        train_epoch_losses, epoch_results, average_dice = process_model(
            train_loader, train=1
        )
        dice.append(average_dice)
        train_losses.append(train_epoch_losses)
        # Only validate once training Dice clears 0.5.
        if average_dice > 0.5:
            print("Validating:")
            val_epoch_losses, epoch_results, average_dice_val = process_model(
                val_loader
            )
            val_losses.append(val_epoch_losses)
            for i in tqdm(range(len(epoch_results[0]))):
                if not os.path.exists(f"ims/batch_{i}"):
                    os.mkdir(f"ims/batch_{i}")
                save_img(epoch_results[0, i].clone(), f"ims/batch_{i}/prob_epoch_{epoch}.png")
                save_img(epoch_results[1, i].clone(), f"ims/batch_{i}/pred_epoch_{epoch}.png")
        train_mean_losses = [mean(x) for x in train_losses]
        val_mean_losses = [mean(x) for x in val_losses]
        np.save("train_losses.npy", train_mean_losses)
        np.save("val_losses.npy", val_mean_losses)
        print(f"Train Dice: {average_dice}")
        print(f"Mean train loss: {mean(train_epoch_losses)}")
        try:
            dice_val.append(average_dice_val)
            print(f"Val Dice: {average_dice_val}")
            print(f"Mean val loss: {mean(val_epoch_losses)}")
            results.append(epoch_results)
            val_epochs.append(epoch)
            train_epochs.append(epoch)
            plt.plot(val_epochs, val_mean_losses, train_epochs, train_mean_losses)
            print(last_best_dice)
            log_file.write(f"best weight: {last_best_dice}")
            # Save a checkpoint whenever validation Dice improves.
            if average_dice_val > last_best_dice:
                torch.save(panc_sam_instance, f"exps/{exp_id}-{user_input}/sam_tuned_save.pth")
                last_best_dice = average_dice_val
            del epoch_results
            del average_dice_val
        except NameError:
            # Validation was skipped this epoch, so only plot training loss.
            train_epochs.append(epoch)
            plt.plot(train_epochs, train_mean_losses)
        print(f"=================End of EPOCH: {epoch}==================\n")
    plt.yscale("log")
    plt.title("Mean epoch loss")
    plt.xlabel("Epoch Number")
    plt.ylabel("Loss")
    plt.savefig("result")
    return train_losses, val_losses, results
# Train and also validate the model.
train_losses, val_losses, results = train_model(train_loader, val_loader)
log_file.close()