You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

Inference_individually.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. from pathlib import Path
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from utils import sample_prompt
  5. from collections import defaultdict
  6. import torchvision.transforms as transforms
  7. import torch
  8. from torch import nn
  9. import torch.nn.functional as F
  10. from segment_anything.utils.transforms import ResizeLongestSide
  11. import albumentations as A
  12. from albumentations.pytorch import ToTensorV2
  13. from einops import rearrange
  14. import random
  15. from tqdm import tqdm
  16. from time import sleep
  17. from data import *
  18. from time import time
  19. from PIL import Image
  20. from sklearn.model_selection import KFold
  21. from shutil import copyfile
  22. from args import get_arguments
# import wandb_handler
# Parsed command-line arguments (checkpoint paths, dataset directories, ...);
# read throughout the module, including inside panc_sam.__init__.
args = get_arguments()
  25. def save_img(img, dir):
  26. img = img.clone().cpu().numpy() + 100
  27. if len(img.shape) == 3:
  28. img = rearrange(img, "c h w -> h w c")
  29. img_min = np.amin(img, axis=(0, 1), keepdims=True)
  30. img = img - img_min
  31. img_max = np.amax(img, axis=(0, 1), keepdims=True)
  32. img = (img / img_max * 255).astype(np.uint8)
  33. img = Image.fromarray(img)
  34. else:
  35. img_min = img.min()
  36. img = img - img_min
  37. img_max = img.max()
  38. if img_max != 0:
  39. img = img / img_max * 255
  40. img = Image.fromarray(img).convert("L")
  41. img.save(dir)
  42. class loss_fn(torch.nn.Module):
  43. def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
  44. super(loss_fn, self).__init__()
  45. self.alpha = alpha
  46. self.gamma = gamma
  47. self.epsilon = epsilon
  48. def dice_loss(self, logits, gt, eps=1):
  49. probs = torch.sigmoid(logits)
  50. probs = probs.view(-1)
  51. gt = gt.view(-1)
  52. intersection = (probs * gt).sum()
  53. dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
  54. loss = 1 - dice_coeff
  55. return loss
  56. def focal_loss(self, logits, gt, gamma=2):
  57. logits = logits.reshape(-1, 1)
  58. gt = gt.reshape(-1, 1)
  59. logits = torch.cat((1 - logits, logits), dim=1)
  60. probs = torch.sigmoid(logits)
  61. pt = probs.gather(1, gt.long())
  62. modulating_factor = (1 - pt) ** gamma
  63. focal_loss = -modulating_factor * torch.log(pt + 1e-12)
  64. loss = focal_loss.mean()
  65. return loss # Store as a Python number to save memory
  66. def forward(self, logits, target):
  67. logits = logits.squeeze(1)
  68. target = target.squeeze(1)
  69. # Dice Loss
  70. # prob = F.softmax(logits, dim=1)[:, 1, ...]
  71. dice_loss = self.dice_loss(logits, target)
  72. # Focal Loss
  73. focal_loss = self.focal_loss(logits, target.squeeze(-1))
  74. alpha = 20.0
  75. # Combined Loss
  76. combined_loss = alpha * focal_loss + dice_loss
  77. return combined_loss
  78. def img_enhance(img2, coef=0.2):
  79. img_mean = np.mean(img2)
  80. img_max = np.max(img2)
  81. val = (img_max - img_mean) * coef + img_mean
  82. img2[img2 < img_mean * 0.7] = img_mean * 0.7
  83. img2[img2 > val] = val
  84. return img2
  85. def dice_coefficient(logits, gt):
  86. eps=1
  87. binary_mask = logits>0
  88. intersection = (binary_mask * gt).sum(dim=(-2,-1))
  89. dice_scores = (2.0 * intersection + eps) / (binary_mask.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
  90. return dice_scores.mean()
  91. def calculate_recall(pred, target):
  92. smooth = 1
  93. batch_size = pred.shape[0]
  94. recall_scores = []
  95. binary_mask = pred>0
  96. for i in range(batch_size):
  97. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  98. false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
  99. recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
  100. recall_scores.append(recall)
  101. return sum(recall_scores) / len(recall_scores)
  102. def calculate_precision(pred, target):
  103. smooth = 1
  104. batch_size = pred.shape[0]
  105. precision_scores = []
  106. binary_mask = pred>0
  107. for i in range(batch_size):
  108. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  109. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  110. precision = (true_positive + smooth) / ((true_positive + false_positive) + smooth)
  111. precision_scores.append(precision)
  112. return sum(precision_scores) / len(precision_scores)
  113. def calculate_jaccard(pred, target):
  114. smooth = 1
  115. batch_size = pred.shape[0]
  116. jaccard_scores = []
  117. binary_mask = pred>0
  118. for i in range(batch_size):
  119. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  120. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  121. false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
  122. jaccard = (true_positive + smooth) / (true_positive + false_positive + false_negative + smooth)
  123. jaccard_scores.append(jaccard)
  124. return sum(jaccard_scores) / len(jaccard_scores)
  125. def calculate_specificity(pred, target):
  126. smooth = 1
  127. batch_size = pred.shape[0]
  128. specificity_scores = []
  129. binary_mask = pred>0
  130. for i in range(batch_size):
  131. true_negative = ((binary_mask[i] == 0) & (target[i] == 0)).sum().item()
  132. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  133. specificity = (true_negative + smooth) / (true_negative + false_positive + smooth)
  134. specificity_scores.append(specificity)
  135. return sum(specificity_scores) / len(specificity_scores)
  136. def what_the_f(low_res_masks,label):
  137. low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
  138. dice = dice_coefficient(
  139. low_res_masks, low_res_label
  140. )
  141. recall=calculate_recall(low_res_masks, low_res_label)
  142. precision =calculate_precision(low_res_masks, low_res_label)
  143. jaccard = calculate_jaccard(low_res_masks, low_res_label)
  144. return dice , precision , recall , jaccard
# --- run configuration ---------------------------------------------------
accumaltive_batch_size = 8  # (sic: "accumulative") batches per optimizer step
batch_size = 1  # per-iteration batch size fed to the DataLoader
num_workers = 2  # DataLoader worker processes
slice_per_image = 1  # slices taken from each image/volume
num_epochs = 40  # used only to configure the LR scheduler below
sample_size = 3660  # max number of batches processed in one pass
# sample_size = 43300
# image_size=sam_model.image_encoder.img_size
image_size = 1024  # SAM input resolution
exp_id = 0  # NOTE(review): unused in this file
found = 0  # NOTE(review): unused in this file
layer_n = 4
L = layer_n
a = np.full(L, layer_n)  # only referenced via ``params`` below
params = {"M": 255, "a": a, "p": 0.35}  # NOTE(review): unused in this file
model_type = "vit_h"  # SAM backbone variant
checkpoint = "checkpoints/sam_vit_h_4b8939.pth"  # base SAM weights path
device = "cuda:0"
  163. from segment_anything import SamPredictor, sam_model_registry
  164. ##################################main model#######################################
class panc_sam(nn.Module):
    """Two-stage SAM pipeline.

    Stage 1 ("promptless"): a frozen prompt encoder + mask decoder loaded
    from ``args.pointbasemodel`` produces a coarse mask with no prompts.
    Stage 2 ("with prompt"): points sampled from the coarse mask are fed to
    a second prompt encoder, and ``mask_decoder2`` (from
    ``args.promptprovider``) produces the refined mask.  Everything except
    ``mask_decoder2`` has requires_grad switched off.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Promptless stage: frozen prompt-encoder/decoder pair.
        # NOTE(review): torch.load of a whole model object is pickle-based and
        # needs the original class definitions importable — confirm checkpoints.
        sam = torch.load(args.pointbasemodel).sam
        self.prompt_encoder = sam.prompt_encoder
        self.mask_decoder = sam.mask_decoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        for param in self.mask_decoder.parameters():
            param.requires_grad = False
        # Prompted stage: its image encoder is shared by both stages; only
        # mask_decoder2 is left trainable.
        sam = torch.load(
            args.promptprovider
        ).sam
        self.image_encoder = sam.image_encoder
        self.prompt_encoder2 = sam.prompt_encoder
        self.mask_decoder2 = sam.mask_decoder
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        for param in self.prompt_encoder2.parameters():
            param.requires_grad = False

    def forward(self, input_images, box=None):
        """Return (refined_masks, coarse_masks) as low-resolution logits.

        ``box`` is accepted but never used — callers pass one anyway.
        """
        # input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        # raise ValueError(input_images.shape)
        with torch.no_grad():
            image_embeddings = self.image_encoder(input_images).detach()
        outputs_prompt = []
        outputs = []
        for curr_embedding in image_embeddings:
            # Stage 1: promptless coarse prediction; decoder inputs detached.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=self.prompt_encoder.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs_prompt.append(low_res_masks)
            # raise ValueError(low_res_masks)
            # points, point_labels = sample_prompt((low_res_masks > 0).float())
            # Sample point prompts from the coarse mask; *4 presumably maps
            # low-res mask coordinates to the 1024-res image frame — TODO confirm.
            points, point_labels = sample_prompt(low_res_masks)
            points = points * 4
            points = (points, point_labels)
            # Stage 2: point-prompted refinement through the second decoder.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder2(
                    points=points,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder2(
                image_embeddings=curr_embedding,
                image_pe=self.prompt_encoder2.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append(low_res_masks)
        # (sic) "promtp" typo kept: renaming would be a code change.
        low_res_masks_promtp = torch.cat(outputs_prompt, dim=0)
        low_res_masks = torch.cat(outputs, dim=0)
        return low_res_masks, low_res_masks_promtp
##################################end#######################################
##################################Augmentation#######################################
# Albumentations training-time augmentation pipeline.
# NOTE(review): defined but never applied in this file (the test dataset is
# built with train=False) — confirm whether it is still needed here.
augmentation = A.Compose(
    [
        A.Rotate(limit=30, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
        A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
        A.HorizontalFlip(p=0.5),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
        A.CoarseDropout(
            max_holes=8,
            max_height=16,
            max_width=16,
            min_height=8,
            min_width=8,
            fill_value=0,
            p=0.5,
        ),
        A.RandomScale(scale_limit=0.3, p=0.5),
        # A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
        # A.GridDistortion(p=0.5),
    ]
)
##################model load#####################
# Instantiate the two-stage model and move it to the GPU.
panc_sam_instance = panc_sam()
# for param in panc_sam_instance_point.parameters():
# param.requires_grad = False
panc_sam_instance.to(device)
# NOTE(review): .train() is called even though this script only evaluates;
# any dropout/norm layers will be in training mode — confirm intent.
panc_sam_instance.train()
##################load data#######################
# Test split only: NIH_PNG dataset, evaluation transforms (train=False).
test_dataset = PanDataset(
    [args.test_dir],
    [args.test_labels_dir],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=slice_per_image,
    train=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=test_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)
##################end load data#######################
# Optimizer/scheduler over the trainable decoder head.
# NOTE(review): neither optimizer_main nor scheduler_main is stepped anywhere
# in this file — dead weight in an inference script; confirm before removing.
lr = 1e-4
max_lr = 5e-5
wd = 5e-4
optimizer_main = torch.optim.Adam(
    # parameters,
    list(panc_sam_instance.mask_decoder2.parameters()),
    lr=lr,
    weight_decay=wd,
)
scheduler_main = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_main,
    max_lr=max_lr,
    epochs=num_epochs,
    steps_per_epoch=sample_size // (accumaltive_batch_size // batch_size),
)
#####################################################
from statistics import mean
from tqdm import tqdm
from torch.nn.functional import threshold, normalize
# Combined focal+dice criterion (not used during pure evaluation below).
loss_function = loss_fn(alpha=0.5, gamma=2.0)
loss_function.to(device)
from time import time
import time as s_time
  301. def process_model(main_model , data_loader, train=0, save_output=0):
  302. epoch_losses = []
  303. results=[]
  304. index = 0
  305. results = torch.zeros((2, 0, 256, 256))
  306. #############################
  307. total_dice = 0.0
  308. total_precision = 0.0
  309. total_recall =0.0
  310. total_jaccard = 0.0
  311. #############################
  312. num_samples = 0
  313. #############################
  314. total_dice_main =0.0
  315. total_precision_main = 0.0
  316. total_recall_main =0.0
  317. total_jaccard_main = 0.0
  318. counterb = 0
  319. for image, label in tqdm(data_loader, total=sample_size):
  320. num_samples += 1
  321. counterb += 1
  322. index += 1
  323. image = image.to(device)
  324. label = label.to(device).float()
  325. ############################model and dice########################################
  326. box = torch.tensor([[200, 200, 750, 800]]).to(device)
  327. low_res_masks_main,low_res_masks_prompt = main_model(image,box)
  328. low_res_label = F.interpolate(label, low_res_masks_main.shape[-2:])
  329. dice_prompt, precisio_prompt , recall_prompt , jaccard_prompt = what_the_f(low_res_masks_prompt,low_res_label)
  330. dice_main , precision_main , recall_main , jaccard_main = what_the_f(low_res_masks_main,low_res_label)
  331. binary_mask = normalize(threshold(low_res_masks_main, 0.0,0))
  332. ##############prompt###############
  333. total_dice += dice_prompt
  334. total_precision += precisio_prompt
  335. total_recall += recall_prompt
  336. total_jaccard += jaccard_prompt
  337. average_dice = total_dice / num_samples
  338. average_precision = total_precision /num_samples
  339. average_recall = total_recall /num_samples
  340. average_jaccard = total_jaccard /num_samples
  341. ##############main##################
  342. total_dice_main+=dice_main
  343. total_precision_main +=precision_main
  344. total_recall_main +=recall_main
  345. total_jaccard_main += jaccard_main
  346. average_dice_main = total_dice_main / num_samples
  347. average_precision_main = total_precision_main /num_samples
  348. average_recall_main = total_recall_main /num_samples
  349. average_jaccard_main = total_jaccard_main /num_samples
  350. ###################################
  351. # result = torch.cat(
  352. # (
  353. # # low_res_masks_main[0].detach().cpu().reshape(1, 1, 256, 256),
  354. # binary_mask[0].detach().cpu().reshape(1, 1, 256, 256),
  355. # ),
  356. # dim=0,
  357. # )
  358. # results = torch.cat((results, result), dim=1)
  359. if counterb == sample_size and train:
  360. break
  361. elif counterb == sample_size and not train:
  362. break
  363. return epoch_losses, results, average_dice,average_precision ,average_recall, average_jaccard,average_dice_main,average_precision_main,average_recall_main,average_jaccard_main
  364. def train_model( test_loader, K_fold=False, N_fold=7, epoch_num_start=7):
  365. print("Train model started.")
  366. test_losses = []
  367. test_epochs = []
  368. dice = []
  369. dice_main = []
  370. dice_test = []
  371. dice_test_main =[]
  372. results = []
  373. index = 0
  374. print("Testing:")
  375. test_epoch_losses, epoch_results, average_dice_test,average_precision ,average_recall, average_jaccard,average_dice_test_main,average_precision_main,average_recall_main,average_jaccard_main = process_model(
  376. panc_sam_instance,test_loader
  377. )
  378. import torchvision.transforms.functional as TF
  379. dice_test.append(average_dice_test)
  380. dice_test_main.append(average_dice_test_main)
  381. print("######################Prompt##########################")
  382. print(f"Test Dice : {average_dice_test}")
  383. print(f"Test presision : {average_precision}")
  384. print(f"Test recall : {average_recall}")
  385. print(f"Test jaccard : {average_jaccard}")
  386. print("######################Main##########################")
  387. print(f"Test Dice main : {average_dice_test_main}")
  388. print(f"Test presision main : {average_precision_main}")
  389. print(f"Test recall main : {average_recall_main}")
  390. print(f"Test jaccard main : {average_jaccard_main}")
  391. # results.append(epoch_results)
  392. # del epoch_results
  393. del average_dice_test
  394. # return train_losses, results
  395. train_model(test_loader)