You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

double_decoder_infrence.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. from pathlib import Path
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from utils import main_prompt , sample_prompt
  5. from collections import defaultdict
  6. import torchvision.transforms as transforms
  7. import torch
  8. from torch import nn
  9. import torch.nn.functional as F
  10. from segment_anything.utils.transforms import ResizeLongestSide
  11. import albumentations as A
  12. from albumentations.pytorch import ToTensorV2
  13. from einops import rearrange
  14. import random
  15. from tqdm import tqdm
  16. from time import sleep
  17. from data import *
  18. from time import time
  19. from PIL import Image
  20. from sklearn.model_selection import KFold
  21. from shutil import copyfile
  22. from torch.nn.functional import threshold, normalize
  23. import torchvision.transforms.functional as TF
  24. from args import get_arguments
# Parse command-line options once, at import time; the test image and label
# directories (args.test_dir / args.test_labels_dir) are read from it when the
# test dataset is built further down.
args = get_arguments()
  26. def dice_coefficient(logits, gt):
  27. eps=1
  28. binary_mask = logits>0
  29. intersection = (binary_mask * gt).sum(dim=(-2,-1))
  30. dice_scores = (2.0 * intersection + eps) / (binary_mask.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
  31. return dice_scores.mean()
  32. def calculate_recall(pred, target):
  33. smooth = 1
  34. batch_size = pred.shape[0]
  35. recall_scores = []
  36. binary_mask = pred>0
  37. for i in range(batch_size):
  38. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  39. false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
  40. recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
  41. recall_scores.append(recall)
  42. return sum(recall_scores) / len(recall_scores)
  43. def calculate_precision(pred, target):
  44. smooth = 1
  45. batch_size = pred.shape[0]
  46. precision_scores = []
  47. binary_mask = pred>0
  48. for i in range(batch_size):
  49. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  50. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  51. precision = (true_positive + smooth) / ((true_positive + false_positive) + smooth)
  52. precision_scores.append(precision)
  53. return sum(precision_scores) / len(precision_scores)
  54. def calculate_jaccard(pred, target):
  55. smooth = 1
  56. batch_size = pred.shape[0]
  57. jaccard_scores = []
  58. binary_mask = pred>0
  59. for i in range(batch_size):
  60. true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
  61. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  62. false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
  63. jaccard = (true_positive + smooth) / (true_positive + false_positive + false_negative + smooth)
  64. jaccard_scores.append(jaccard)
  65. return sum(jaccard_scores) / len(jaccard_scores)
  66. def calculate_specificity(pred, target):
  67. smooth = 1
  68. batch_size = pred.shape[0]
  69. specificity_scores = []
  70. binary_mask = pred>0
  71. for i in range(batch_size):
  72. true_negative = ((binary_mask[i] == 0) & (target[i] == 0)).sum().item()
  73. false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
  74. specificity = (true_negative + smooth) / (true_negative + false_positive + smooth)
  75. specificity_scores.append(specificity)
  76. return sum(specificity_scores) / len(specificity_scores)
  77. def what_the_f(low_res_masks,label):
  78. low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
  79. dice = dice_coefficient(
  80. low_res_masks, low_res_label
  81. )
  82. recall=calculate_recall(low_res_masks, low_res_label)
  83. precision =calculate_precision(low_res_masks, low_res_label)
  84. jaccard = calculate_jaccard(low_res_masks, low_res_label)
  85. return dice , precision , recall , jaccard
  86. def save_img(img, dir):
  87. img = img.clone().cpu().numpy()
  88. if len(img.shape) == 3:
  89. img = rearrange(img, "c h w -> h w c")
  90. img_min = np.amin(img, axis=(0, 1), keepdims=True)
  91. img = img - img_min
  92. img_max = np.amax(img, axis=(0, 1), keepdims=True)
  93. img = (img / img_max * 255).astype(np.uint8)
  94. img = Image.fromarray(img)
  95. else:
  96. img_min = img.min()
  97. img = img - img_min
  98. img_max = img.max()
  99. if img_max != 0:
  100. img = img / img_max * 255
  101. img = Image.fromarray(img).convert("L")
  102. img.save(dir)
# Forwarded to PanDataset as slice_per_image= when the test dataset is built
# (its exact meaning is defined in the project's data module — not visible here).
slice_per_image = 1
# Passed positionally to PanDataset; presumably the square input resolution
# (it matches the 1024x1024 size used when saving visualisations) — confirm.
image_size = 1024
class panc_sam(nn.Module):
    """Two-stage SAM wrapper: a promptless mask pass, then a point-prompted pass.

    * Stage 1 uses the prompt encoder and mask decoder of the ``.sam``
      attribute of the object pickled in ``sam_tuned_save.pth``, called with
      no points, boxes or masks. Both modules are frozen.
    * Points derived from each stage-1 mask by ``main_prompt`` (imported from
      ``utils``) are fed to a second prompt encoder / mask decoder taken from
      a SAM model built via ``sam_model_registry[model_type](checkpoint=...)``.
    * The (frozen) image encoder of that second model produces the image
      embeddings consumed by both stages.
    """

    def __init__(self, *args, **kwargs) -> None:
        # *args/**kwargs are forwarded unchanged to nn.Module; inside this
        # method they shadow the module-level ``args`` namespace.
        super().__init__(*args, **kwargs)
        #Promptless
        # NOTE(review): torch.load unpickles the whole object, which can run
        # arbitrary code — only load a trusted checkpoint.
        sam = torch.load("sam_tuned_save.pth").sam
        self.prompt_encoder = sam.prompt_encoder
        self.mask_decoder = sam.mask_decoder
        # Freeze every stage-1 parameter.
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        for param in self.mask_decoder.parameters():
            param.requires_grad = False
        #with Prompt
        # NOTE(review): sam_model_registry, model_type and checkpoint are not
        # defined in the visible part of this file — presumably they arrive
        # through `from data import *`; confirm.
        sam=sam_model_registry[model_type](checkpoint=checkpoint)
        # sam = torch.load(
        # "sam_tuned_save.pth"
        # ).sam
        self.image_encoder = sam.image_encoder
        self.prompt_encoder2 = sam.prompt_encoder
        self.mask_decoder2 = sam.mask_decoder
        # The image encoder and the stage-2 prompt encoder are frozen;
        # mask_decoder2 is not frozen here.
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        for param in self.prompt_encoder2.parameters():
            param.requires_grad = False

    def forward(self, input_images,box=None):
        """Run both decoder stages for every image in the batch.

        Args:
            input_images: batch passed straight to the image encoder; the
                resulting embeddings are iterated along dim 0.
            box: unused — accepted only so callers can pass a box.

        Returns:
            ``(low_res_masks, low_res_masks_promtp)``: the point-prompted
            stage-2 masks and the promptless stage-1 masks respectively, each
            concatenated along dim 0. Note the naming is inverted relative to
            the content: ``*_promtp`` / ``outputs_prompt`` hold the masks
            produced WITHOUT prompts.
        """
        # input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        with torch.no_grad():
            image_embeddings = self.image_encoder(input_images).detach()
        outputs_prompt = []  # stage-1 (promptless) masks
        outputs = []  # stage-2 (point-prompted) masks
        # NOTE(review): indentation was lost in this copy of the file; the
        # decoder calls are reconstructed outside torch.no_grad() (the
        # .detach() calls on their inputs suggest that) — confirm.
        for curr_embedding in image_embeddings:
            # Stage 1: no points, boxes or masks.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=self.prompt_encoder.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs_prompt.append(low_res_masks)
            # points, point_labels = sample_prompt((low_res_masks > 0).float())
            # Derive point prompts from the stage-1 mask (main_prompt lives in
            # utils; its exact sampling strategy is not visible here).
            points, point_labels = main_prompt(low_res_masks)
            # Presumably rescales low-res (256) mask coordinates to the 1024
            # input frame — TODO confirm against main_prompt's output units.
            points = points * 4
            points = (points, point_labels)
            # Stage 2: same image embedding, now prompted with the points.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder2(
                    points=points,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder2(
                image_embeddings=curr_embedding,
                image_pe=self.prompt_encoder2.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append(low_res_masks)
        low_res_masks_promtp = torch.cat(outputs_prompt, dim=0)
        low_res_masks = torch.cat(outputs, dim=0)
        return low_res_masks, low_res_masks_promtp
  170. test_dataset = PanDataset(
  171. [
  172. args.test_dir
  173. ],
  174. [
  175. args.test_labels_dir
  176. ],
  177. [["NIH_PNG",1]],
  178. image_size,
  179. slice_per_image=slice_per_image,
  180. train=False,
  181. )
  182. device = "cuda:0"
  183. x = torch.load('sam_tuned_save.pth', map_location='cpu')
  184. # raise ValueError(x)
  185. x.to(device)
  186. # num_samples = 0
  187. # counterb=0
  188. # index=0
  189. # for image, label in tqdm(test_dataset, total=len(test_dataset)):
  190. # num_samples += 1
  191. # counterb += 1
  192. # index += 1
  193. # image = image.to(device)
  194. # label = label.to(device).float()
  195. # ############################model and dice########################################
  196. # box = torch.tensor([[200, 200, 750, 800]]).to(device)
  197. # low_res_masks_main,low_res_masks_prompt = x(image,box)
  198. def process_model(main_model , test_dataset, train=0, save_output=0):
  199. epoch_losses = []
  200. results=[]
  201. results_prompt=[]
  202. index = 0
  203. results = torch.zeros((2, 0, 256, 256))
  204. results_prompt = torch.zeros((2, 0, 256, 256))
  205. #############################
  206. total_dice = 0.0
  207. total_precision = 0.0
  208. total_recall =0.0
  209. total_jaccard = 0.0
  210. #############################
  211. num_samples = 0
  212. #############################
  213. total_dice_main =0.0
  214. total_precision_main = 0.0
  215. total_recall_main =0.0
  216. total_jaccard_main = 0.0
  217. counterb = 0
  218. for image, label , raw_data in tqdm(test_dataset, total=len(test_dataset)):
  219. # raise ValueError(image.shape , label.shape , raw_data.shape)
  220. num_samples += 1
  221. counterb += 1
  222. index += 1
  223. image = image.to(device)
  224. label = label.to(device).float()
  225. ############################model and dice########################################
  226. box = torch.tensor([[200, 200, 750, 800]]).to(device)
  227. low_res_masks_main,low_res_masks_prompt = main_model(image,box)
  228. low_res_label = F.interpolate(label, low_res_masks_main.shape[-2:])
  229. dice_prompt, precisio_prompt , recall_prompt , jaccard_prompt = what_the_f(low_res_masks_prompt,low_res_label)
  230. dice_main , precision_main , recall_main , jaccard_main = what_the_f(low_res_masks_main,low_res_label)
  231. # binary_mask = normalize(threshold(low_res_masks_main, 0.0,0))
  232. ##############prompt###############
  233. total_dice += dice_prompt
  234. total_precision += precisio_prompt
  235. total_recall += recall_prompt
  236. total_jaccard += jaccard_prompt
  237. average_dice = total_dice / num_samples
  238. average_precision = total_precision /num_samples
  239. average_recall = total_recall /num_samples
  240. average_jaccard = total_jaccard /num_samples
  241. ##############main##################
  242. total_dice_main+=dice_main
  243. total_precision_main +=precision_main
  244. total_recall_main +=recall_main
  245. total_jaccard_main += jaccard_main
  246. average_dice_main = total_dice_main / num_samples
  247. average_precision_main = total_precision_main /num_samples
  248. average_recall_main = total_recall_main /num_samples
  249. average_jaccard_main = total_jaccard_main /num_samples
  250. binary_mask = normalize(threshold(low_res_masks_main, 0.0,0))
  251. binary_mask_mask = normalize(threshold(low_res_masks_prompt, 0.0,0))
  252. ###################################
  253. result = torch.cat(
  254. (
  255. low_res_masks_main[0].detach().cpu().reshape(1, 1, 256, 256),
  256. binary_mask[0].detach().cpu().reshape(1, 1, 256, 256),
  257. ),
  258. dim=0,
  259. )
  260. results = torch.cat((results, result), dim=1)
  261. result_prompt = torch.cat(
  262. (
  263. low_res_masks_prompt[0].detach().cpu().reshape(1, 1, 256, 256),
  264. binary_mask_mask[0].detach().cpu().reshape(1, 1, 256, 256),
  265. ),
  266. dim=0,
  267. )
  268. results_prompt = torch.cat((results_prompt, result_prompt), dim=1)
  269. # if counterb == len(test_dataset)-5:
  270. # break
  271. if counterb == 200:
  272. break
  273. # elif counterb == sample_size and not train:
  274. # break
  275. return epoch_losses, results,results_prompt, average_dice,average_precision ,average_recall, average_jaccard,average_dice_main,average_precision_main,average_recall_main,average_jaccard_main
  276. print("Testing:")
  277. test_epoch_losses, epoch_results , results_prompt, average_dice_test,average_precision ,average_recall, average_jaccard,average_dice_test_main,average_precision_main,average_recall_main,average_jaccard_main = process_model(
  278. x,test_dataset
  279. )
  280. # raise ValueError(len(epoch_results[0]))
  281. train_losses = []
  282. train_epochs = []
  283. test_losses = []
  284. test_epochs = []
  285. dice = []
  286. dice_main = []
  287. dice_test = []
  288. dice_test_main =[]
  289. results = []
  290. index = 0
  291. ##############################save image#########################################
  292. for image, label , raw_data in tqdm(test_dataset):
  293. if index < 200:
  294. if not os.path.exists(f"result_img/batch_{index}"):
  295. os.mkdir(f"result_img/batch_{index}")
  296. save_img(
  297. image[0],
  298. f"result_img/batch_{index}/img.png",
  299. )
  300. tensor_raw = torch.tensor(raw_data)
  301. save_img(
  302. tensor_raw.T,
  303. f"result_img/batch_{index}/raw_img.png",
  304. )
  305. model_result_resized = TF.resize(epoch_results, size=(1024, 1024))
  306. result_canvas = torch.zeros_like(image[0])
  307. result_canvas[1] = label[0][0]
  308. result_canvas[0] = model_result_resized[1, index]
  309. blended_result = 0.2 * image[0] + 0.5 * result_canvas
  310. ###################################################################
  311. model_result_resized_prompt = TF.resize(results_prompt, size=(1024, 1024))
  312. result_canvas_prompt = torch.zeros_like(image[0])
  313. result_canvas_prompt[1] = label[0][0]
  314. # raise ValueError(model_result_resized_prompt.shape ,model_result_resized.shape )
  315. result_canvas_prompt[0] = model_result_resized_prompt[1, index]
  316. blended_result_prompt = 0.2 * image[0] + 0.5 * result_canvas_prompt
  317. save_img(blended_result, f"result_img/batch_{index}/comb.png")
  318. save_img(blended_result_prompt, f"result_img/batch_{index}/comb_prompt.png")
  319. save_img(
  320. epoch_results[1, index].clone(),f"result_img/batch_{index}/modelresult.png",
  321. )
  322. save_img(
  323. epoch_results[0, index].clone(),f"result_img/batch_{index}/prob_epoch_{index}.png",)
  324. index += 1
  325. if index == 200:
  326. break
  327. dice_test.append(average_dice_test)
  328. dice_test_main.append(average_dice_test_main)
  329. print("######################Prompt##########################")
  330. print(f"Test Dice : {average_dice_test}")
  331. print(f"Test presision : {average_precision}")
  332. print(f"Test recall : {average_recall}")
  333. print(f"Test jaccard : {average_jaccard}")
  334. print("######################Main##########################")
  335. print(f"Test Dice main : {average_dice_test_main}")
  336. print(f"Test presision main : {average_precision_main}")
  337. print(f"Test recall main : {average_recall_main}")
  338. print(f"Test jaccard main : {average_jaccard_main}")