# SAM_with_prompt.py
# Fine-tunes Segment Anything's mask decoder for pancreas segmentation using
# point prompts derived from the ground-truth masks.
# (Repository-browser chrome that was pasted into this header has been removed;
# it was not part of the script.)
  1. debug = 0
  2. from pathlib import Path
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. # import cv2
  6. from collections import defaultdict
  7. import torchvision.transforms as transforms
  8. import torch
  9. from torch import nn
  10. import torch.nn.functional as F
  11. from segment_anything.utils.transforms import ResizeLongestSide
  12. import albumentations as A
  13. from albumentations.pytorch import ToTensorV2
  14. import numpy as np
  15. from einops import rearrange
  16. import random
  17. from tqdm import tqdm
  18. from time import sleep
  19. from data import *
  20. from time import time
  21. from PIL import Image
  22. from sklearn.model_selection import KFold
  23. from shutil import copyfile
  24. # import monai
  25. from tqdm import tqdm
  26. from utils import main_prompt,main_prompt_for_ground_true
  27. from torch.autograd import Variable
  28. from args import get_arguments
  29. # import wandb_handler
  30. args = get_arguments()
  31. def save_img(img, dir):
  32. img = img.clone().cpu().numpy() + 100
  33. if len(img.shape) == 3:
  34. img = rearrange(img, "c h w -> h w c")
  35. img_min = np.amin(img, axis=(0, 1), keepdims=True)
  36. img = img - img_min
  37. img_max = np.amax(img, axis=(0, 1), keepdims=True)
  38. img = (img / img_max * 255).astype(np.uint8)
  39. grey_img = Image.fromarray(img[:, :, 0])
  40. img = Image.fromarray(img)
  41. else:
  42. img_min = img.min()
  43. img = img - img_min
  44. img_max = img.max()
  45. if img_max != 0:
  46. img = img / img_max * 255
  47. img = Image.fromarray(img).convert("L")
  48. img.save(dir)
  49. class FocalLoss(nn.Module):
  50. def __init__(self, gamma=2.0, alpha=0.25):
  51. super(FocalLoss, self).__init__()
  52. self.gamma = gamma
  53. self.alpha = alpha
  54. def dice_loss(self, logits, gt, eps=1):
  55. # Convert logits to probabilities
  56. # Flatten the tensors
  57. # probs = probs.view(-1)
  58. # gt = gt.view(-1)
  59. probs = torch.sigmoid(logits)
  60. # Compute Dice coefficient
  61. intersection = (probs * gt).sum()
  62. dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
  63. # Compute Dice Los[s
  64. loss = 1 - dice_coeff
  65. return loss
  66. def focal_loss(self, pred, mask):
  67. """
  68. pred: [B, 1, H, W]
  69. mask: [B, 1, H, W]
  70. """
  71. # pred=pred.reshape(-1,1)
  72. # mask = mask.reshape(-1,1)
  73. # assert pred.shape == mask.shape, "pred and mask should have the same shape."
  74. p = torch.sigmoid(pred)
  75. num_pos = torch.sum(mask)
  76. num_neg = mask.numel() - num_pos
  77. w_pos = (1 - p) ** self.gamma
  78. w_neg = p**self.gamma
  79. loss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12)
  80. loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12)
  81. loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12)
  82. return loss
  83. def forward(self, logits, target):
  84. logits = logits.squeeze(1)
  85. target = target.squeeze(1)
  86. # Dice Loss
  87. # prob = F.softmax(logits, dim=1)[:, 1, ...]
  88. dice_loss = self.dice_loss(logits, target)
  89. # Focal Loss
  90. focal_loss = self.focal_loss(logits, target.squeeze(-1))
  91. alpha = 20.0
  92. # Combined Loss
  93. combined_loss = alpha * focal_loss + dice_loss
  94. return combined_loss
  95. class loss_fn(torch.nn.Module):
  96. def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
  97. super(loss_fn, self).__init__()
  98. self.alpha = alpha
  99. self.gamma = gamma
  100. self.epsilon = epsilon
  101. def tversky_loss(self, y_pred, y_true, alpha=0.8, beta=0.2, smooth=1e-2):
  102. y_pred = torch.sigmoid(y_pred)
  103. # raise ValueError(y_pred)
  104. y_true_pos = torch.flatten(y_true)
  105. y_pred_pos = torch.flatten(y_pred)
  106. true_pos = torch.sum(y_true_pos * y_pred_pos)
  107. false_neg = torch.sum(y_true_pos * (1 - y_pred_pos))
  108. false_pos = torch.sum((1 - y_true_pos) * y_pred_pos)
  109. tversky_index = (true_pos + smooth) / (
  110. true_pos + alpha * false_neg + beta * false_pos + smooth
  111. )
  112. return 1 - tversky_index
  113. def focal_tversky(self, y_pred, y_true, gamma=0.75):
  114. pt_1 = self.tversky_loss(y_pred, y_true)
  115. return torch.pow((1 - pt_1), gamma)
  116. def dice_loss(self, logits, gt, eps=1):
  117. # Convert logits to probabilities
  118. # Flatten the tensorsx
  119. probs = torch.sigmoid(logits)
  120. probs = probs.view(-1)
  121. gt = gt.view(-1)
  122. # Compute Dice coefficient
  123. intersection = (probs * gt).sum()
  124. dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
  125. # Compute Dice Los[s
  126. loss = 1 - dice_coeff
  127. return loss
  128. def focal_loss(self, logits, gt, gamma=2):
  129. logits = logits.reshape(-1, 1)
  130. gt = gt.reshape(-1, 1)
  131. logits = torch.cat((1 - logits, logits), dim=1)
  132. probs = torch.sigmoid(logits)
  133. pt = probs.gather(1, gt.long())
  134. modulating_factor = (1 - pt) ** gamma
  135. # pt_false= pt<=0.5
  136. # modulating_factor[pt_false] *= 2
  137. focal_loss = -modulating_factor * torch.log(pt + 1e-12)
  138. # Compute the mean focal loss
  139. loss = focal_loss.mean()
  140. return loss # Store as a Python number to save memory
  141. def forward(self, logits, target):
  142. logits = logits.squeeze(1)
  143. target = target.squeeze(1)
  144. # Dice Loss
  145. # prob = F.softmax(logits, dim=1)[:, 1, ...]
  146. dice_loss = self.dice_loss(logits, target)
  147. tversky_loss = self.tversky_loss(logits, target)
  148. # Focal Loss
  149. focal_loss = self.focal_loss(logits, target.squeeze(-1))
  150. alpha = 20.0
  151. # Combined Loss
  152. combined_loss = alpha * focal_loss + dice_loss
  153. return combined_loss
  154. def img_enhance(img2, coef=0.2):
  155. img_mean = np.mean(img2)
  156. img_max = np.max(img2)
  157. val = (img_max - img_mean) * coef + img_mean
  158. img2[img2 < img_mean * 0.7] = img_mean * 0.7
  159. img2[img2 > val] = val
  160. return img2
  161. def dice_coefficient(pred, target):
  162. smooth = 1 # Smoothing constant to avoid division by zero
  163. dice = 0
  164. pred_index = pred
  165. target_index = target
  166. intersection = (pred_index * target_index).sum()
  167. union = pred_index.sum() + target_index.sum()
  168. dice += (2.0 * intersection + smooth) / (union + smooth)
  169. return dice.item()
  170. num_workers = 4
  171. slice_per_image = 1
  172. num_epochs = 80
  173. sample_size = 2000
  174. # image_size=sam_model.image_encoder.img_size
  175. image_size = 1024
  176. exp_id = 0
  177. found = 0
  178. if debug:
  179. user_input = "debug"
  180. else:
  181. user_input = input("Related changes: ")
  182. while found == 0:
  183. try:
  184. os.makedirs(f"exps/{exp_id}-{user_input}")
  185. found = 1
  186. except:
  187. exp_id = exp_id + 1
  188. copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")
  189. layer_n = 4
  190. L = layer_n
  191. a = np.full(L, layer_n)
  192. params = {"M": 255, "a": a, "p": 0.35}
  193. device = "cuda:0"
  194. from segment_anything import SamPredictor, sam_model_registry
  195. # //////////////////
  196. class panc_sam(nn.Module):
  197. def __init__(self, *args, **kwargs) -> None:
  198. super().__init__(*args, **kwargs)
  199. sam=sam_model_registry[args.model_type](args.checkpoint)
  200. def forward(self, batched_input):
  201. # with torch.no_grad():
  202. # raise ValueError(10)
  203. input_images = torch.stack([x["image"] for x in batched_input], dim=0)
  204. with torch.no_grad():
  205. image_embeddings = self.sam.image_encoder(input_images).detach()
  206. outputs = []
  207. for image_record, curr_embedding in zip(batched_input, image_embeddings):
  208. if "point_coords" in image_record:
  209. points = (image_record["point_coords"].unsqueeze(0), image_record["point_labels"].unsqueeze(0))
  210. # raise ValueError(points)
  211. else:
  212. raise ValueError('what the f?')
  213. points = None
  214. # raise ValueError(image_record["point_coords"].shape)
  215. with torch.no_grad():
  216. sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
  217. points=points,
  218. boxes=image_record.get("boxes", None),
  219. masks=image_record.get("mask_inputs", None),
  220. )
  221. low_res_masks, _ = self.sam.mask_decoder(
  222. image_embeddings=curr_embedding.unsqueeze(0),
  223. image_pe=self.sam.prompt_encoder.get_dense_pe().detach(),
  224. sparse_prompt_embeddings=sparse_embeddings.detach(),
  225. dense_prompt_embeddings=dense_embeddings.detach(),
  226. multimask_output=False,
  227. )
  228. outputs.append(
  229. {
  230. "low_res_logits": low_res_masks,
  231. }
  232. )
  233. low_res_masks = torch.stack([x["low_res_logits"] for x in outputs], dim=0)
  234. return low_res_masks.squeeze(1)
  235. # ///////////////
  236. augmentation = A.Compose(
  237. [
  238. A.Rotate(limit=90, p=0.5),
  239. A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
  240. A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
  241. A.HorizontalFlip(p=0.5),
  242. A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
  243. A.CoarseDropout(
  244. max_holes=8,
  245. max_height=16,
  246. max_width=16,
  247. min_height=8,
  248. min_width=8,
  249. fill_value=0,
  250. p=0.5,
  251. ),
  252. A.RandomScale(scale_limit=0.1, p=0.5),
  253. ]
  254. )
  255. panc_sam_instance = panc_sam()
  256. panc_sam_instance.to(device)
  257. panc_sam_instance.train()
  258. train_dataset = PanDataset(
  259. [args.train_dir],
  260. [args.train_labels_dir],
  261. [["NIH_PNG",1]],
  262. image_size,
  263. slice_per_image=slice_per_image,
  264. train=True,
  265. augmentation=augmentation,
  266. )
  267. val_dataset = PanDataset(
  268. [args.val_dir],
  269. [args.train_dir],
  270. [["NIH_PNG",1]],
  271. image_size,
  272. slice_per_image=slice_per_image,
  273. train=False,
  274. )
  275. train_loader = DataLoader(
  276. train_dataset,
  277. batch_size=args.batch_size,
  278. collate_fn=train_dataset.collate_fn,
  279. shuffle=True,
  280. drop_last=False,
  281. num_workers=num_workers,
  282. )
  283. val_loader = DataLoader(
  284. val_dataset,
  285. batch_size=args.batch_size,
  286. collate_fn=val_dataset.collate_fn,
  287. shuffle=False,
  288. drop_last=False,
  289. num_workers=num_workers,
  290. )
  291. # Set up the optimizer, hyperparameter tuning will improve performance here
  292. lr = 1e-4
  293. max_lr = 5e-5
  294. wd = 5e-4
  295. optimizer = torch.optim.Adam(
  296. # parameters,
  297. list(panc_sam_instance.sam.mask_decoder.parameters()),
  298. # list(panc_sam_instance.mask_decoder.parameters()),
  299. lr=lr,
  300. weight_decay=wd,
  301. )
  302. scheduler = torch.optim.lr_scheduler.OneCycleLR(
  303. optimizer,
  304. max_lr=max_lr,
  305. epochs=num_epochs,
  306. steps_per_epoch=sample_size // (args.accumulative_batch_size // args.batch_size),
  307. )
  308. from statistics import mean
  309. from tqdm import tqdm
  310. from torch.nn.functional import threshold, normalize
  311. loss_function = loss_fn(alpha=0.5, gamma=2.0)
  312. loss_function.to(device)
  313. from time import time
  314. import time as s_time
  315. log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
  316. def process_model(data_loader, train=0, save_output=0):
  317. epoch_losses = []
  318. index = 0
  319. results = torch.zeros((2, 0, 256, 256))
  320. total_dice = 0.0
  321. num_samples = 0
  322. counterb = 0
  323. for image, label in tqdm(data_loader, total=sample_size):
  324. s_time.sleep(0.6)
  325. counterb += 1
  326. index += 1
  327. image = image.to(device)
  328. label = label.to(device).float()
  329. input_size = (1024, 1024)
  330. box = torch.tensor([[200, 200, 750, 800]]).to(device)
  331. points, point_labels = main_prompt_for_ground_true(label)
  332. # raise ValueError(points)
  333. batched_input = []
  334. for ibatch in range(args.batch_size):
  335. batched_input.append(
  336. {
  337. "image": image[ibatch],
  338. "point_coords": points[ibatch],
  339. "point_labels": point_labels[ibatch],
  340. "original_size": (1024, 1024)
  341. # 'original_size': image1.shape[:2]
  342. },
  343. )
  344. # raise ValueError(batched_input)
  345. low_res_masks = panc_sam_instance(batched_input)
  346. low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
  347. binary_mask = normalize(threshold(low_res_masks, 0.0,0))
  348. loss = loss_function(low_res_masks, low_res_label)
  349. loss /= (args.accumulative_batch_size / args.batch_size)
  350. opened_binary_mask = torch.zeros_like(binary_mask).cpu()
  351. for j, mask in enumerate(binary_mask[:, 0]):
  352. numpy_mask = mask.detach().cpu().numpy().astype(np.uint8)
  353. opened_binary_mask[j][0] = torch.from_numpy(numpy_mask)
  354. dice = dice_coefficient(
  355. opened_binary_mask.numpy(), low_res_label.cpu().detach().numpy()
  356. )
  357. # print(dice)
  358. total_dice += dice
  359. num_samples += 1
  360. average_dice = total_dice / num_samples
  361. log_file.write(str(average_dice) + "\n")
  362. log_file.flush()
  363. if train:
  364. loss.backward()
  365. if index % (args.accumulative_batch_size / args.batch_size) == 0:
  366. # print(loss)
  367. optimizer.step()
  368. scheduler.step()
  369. optimizer.zero_grad()
  370. index = 0
  371. else:
  372. result = torch.cat(
  373. (
  374. low_res_masks[0].detach().cpu().reshape(1, 1, 256, 256),
  375. opened_binary_mask[0].reshape(1, 1, 256, 256),
  376. ),
  377. dim=0,
  378. )
  379. results = torch.cat((results, result), dim=1)
  380. if index % (args.accumulative_batch_size / args.batch_size) == 0:
  381. epoch_losses.append(loss.item())
  382. if counterb == sample_size and train:
  383. break
  384. elif counterb == sample_size // 5 and not train:
  385. break
  386. return epoch_losses, results, average_dice
  387. def train_model(train_loader, val_loader, K_fold=False, N_fold=7, epoch_num_start=7):
  388. print("Train model started.")
  389. train_losses = []
  390. train_epochs = []
  391. val_losses = []
  392. val_epochs = []
  393. dice = []
  394. dice_val = []
  395. results = []
  396. if debug==0:
  397. index = 0
  398. ## training with k-fold cross validation:
  399. last_best_dice = 0
  400. for epoch in range(num_epochs):
  401. if epoch > epoch_num_start:
  402. kf = KFold(n_splits=N_fold, shuffle=True)
  403. for i, (train_index, val_index) in enumerate(kf.split(train_loader)):
  404. print(
  405. f"=====================EPOCH: {epoch} fold: {i}====================="
  406. )
  407. print("Training:")
  408. x_train, x_val = (
  409. train_loader[train_index],
  410. train_loader[val_index],
  411. )
  412. train_epoch_losses, epoch_results, average_dice = process_model(
  413. x_train, train=1
  414. )
  415. dice.append(average_dice)
  416. train_losses.append(train_epoch_losses)
  417. if (average_dice) > 0.6:
  418. print("validating:")
  419. (
  420. val_epoch_losses,
  421. epoch_results,
  422. average_dice_val,
  423. ) = process_model(x_val)
  424. val_losses.append(val_epoch_losses)
  425. for i in tqdm(range(len(epoch_results[0]))):
  426. if not os.path.exists(f"ims/batch_{i}"):
  427. os.mkdir(f"ims/batch_{i}")
  428. save_img(
  429. epoch_results[0, i].clone(),
  430. f"ims/batch_{i}/prob_epoch_{epoch}.png",
  431. )
  432. save_img(
  433. epoch_results[1, i].clone(),
  434. f"ims/batch_{i}/pred_epoch_{epoch}.png",
  435. )
  436. train_mean_losses = [mean(x) for x in train_losses]
  437. val_mean_losses = [mean(x) for x in val_losses]
  438. np.save("train_losses.npy", train_mean_losses)
  439. np.save("val_losses.npy", val_mean_losses)
  440. print(f"Train Dice: {average_dice}")
  441. print(f"Mean train loss: {mean(train_epoch_losses)}")
  442. try:
  443. dice_val.append(average_dice_val)
  444. print(f"val Dice : {average_dice_val}")
  445. print(f"Mean val loss: {mean(val_epoch_losses)}")
  446. results.append(epoch_results)
  447. val_epochs.append(epoch)
  448. train_epochs.append(epoch)
  449. plt.plot(
  450. val_epochs,
  451. val_mean_losses,
  452. train_epochs,
  453. train_mean_losses,
  454. )
  455. if average_dice_val > last_best_dice:
  456. torch.save(
  457. panc_sam_instance,
  458. f"exps/{exp_id}-{user_input}/sam_tuned_save.pth",
  459. )
  460. last_best_dice = average_dice_val
  461. del epoch_results
  462. del average_dice_val
  463. except:
  464. train_epochs.append(epoch)
  465. plt.plot(train_epochs, train_mean_losses)
  466. print(
  467. f"=================End of EPOCH: {epoch} Fold :{i}==================\n"
  468. )
  469. plt.yscale("log")
  470. plt.title("Mean epoch loss")
  471. plt.xlabel("Epoch Number")
  472. plt.ylabel("Loss")
  473. plt.savefig("result")
  474. else:
  475. print(f"=====================EPOCH: {epoch}=====================")
  476. last_best_dice = 0
  477. print("Training:")
  478. train_epoch_losses, epoch_results, average_dice = process_model(
  479. train_loader, train=1
  480. )
  481. dice.append(average_dice)
  482. train_losses.append(train_epoch_losses)
  483. if (average_dice) > 0.6:
  484. print("validating:")
  485. val_epoch_losses, epoch_results, average_dice_val = process_model(
  486. val_loader
  487. )
  488. val_losses.append(val_epoch_losses)
  489. # for i in tqdm(range(len(epoch_results[0]))):
  490. # if not os.path.exists(f"ims/batch_{i}"):
  491. # os.mkdir(f"ims/batch_{i}")
  492. # save_img(
  493. # epoch_results[0, i].clone(),
  494. # f"ims/batch_{i}/prob_epoch_{epoch}.png",
  495. # )
  496. # save_img(
  497. # epoch_results[1, i].clone(),
  498. # f"ims/batch_{i}/pred_epoch_{epoch}.png",
  499. # )
  500. train_mean_losses = [mean(x) for x in train_losses]
  501. val_mean_losses = [mean(x) for x in val_losses]
  502. np.save("train_losses.npy", train_mean_losses)
  503. np.save("val_losses.npy", val_mean_losses)
  504. print(f"Train Dice: {average_dice}")
  505. print(f"Mean train loss: {mean(train_epoch_losses)}")
  506. try:
  507. dice_val.append(average_dice_val)
  508. print(f"val Dice : {average_dice_val}")
  509. print(f"Mean val loss: {mean(val_epoch_losses)}")
  510. results.append(epoch_results)
  511. val_epochs.append(epoch)
  512. train_epochs.append(epoch)
  513. plt.plot(
  514. val_epochs, val_mean_losses, train_epochs, train_mean_losses
  515. )
  516. if average_dice_val > last_best_dice:
  517. torch.save(
  518. panc_sam_instance,
  519. f"exps/{exp_id}-{user_input}/sam_tuned_save.pth",
  520. )
  521. last_best_dice = average_dice_val
  522. del epoch_results
  523. del average_dice_val
  524. except:
  525. train_epochs.append(epoch)
  526. plt.plot(train_epochs, train_mean_losses)
  527. print(f"=================End of EPOCH: {epoch}==================\n")
  528. plt.yscale("log")
  529. plt.title("Mean epoch loss")
  530. plt.xlabel("Epoch Number")
  531. plt.ylabel("Loss")
  532. plt.savefig("result")
  533. return train_losses, val_losses, results
  534. train_losses, val_losses, results = train_model(train_loader, val_loader)
  535. log_file.close()
  536. # train and also test the model