# test_inference.py
debug = 1
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# import cv2
from collections import defaultdict
import torchvision.transforms as transforms
import torch
from torch import nn
import torch.nn.functional as F
from segment_anything.utils.transforms import ResizeLongestSide
import albumentations as A
from albumentations.pytorch import ToTensorV2
from einops import rearrange
import random
from tqdm import tqdm
from time import sleep, time
from data import *
from PIL import Image
from sklearn.model_selection import KFold
from shutil import copyfile
import monai
from utils import sample_prompt
from torch.autograd import Variable
from args import get_arguments
# import wandb_handler

args = get_arguments()
def save_img(img, dir):
    # Shift, min-max normalize to [0, 255], and save a tensor as an image for debugging.
    img = img.clone().cpu().numpy() + 100
    if len(img.shape) == 3:
        img = rearrange(img, "c h w -> h w c")
        img_min = np.amin(img, axis=(0, 1), keepdims=True)
        img = img - img_min
        img_max = np.amax(img, axis=(0, 1), keepdims=True)
        img = (img / img_max * 255).astype(np.uint8)
        img = Image.fromarray(img)
    else:
        img_min = img.min()
        img = img - img_min
        img_max = img.max()
        if img_max != 0:
            img = img / img_max * 255
        img = Image.fromarray(img.astype(np.uint8)).convert("L")
    img.save(dir)
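
# A minimal usage sketch (hypothetical tensor and path; save_img expects a CHW
# or HW torch tensor and writes a normalized 8-bit image):
#     save_img(torch.rand(3, 64, 64), "debug_sample.png")
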
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def dice_loss(self, logits, gt, eps=1):
        # Convert logits to probabilities
        probs = torch.sigmoid(logits)
        # Compute Dice coefficient
        intersection = (probs * gt).sum()
        dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
        # Compute Dice loss
        loss = 1 - dice_coeff
        return loss

    def focal_loss(self, pred, mask):
        """
        pred: [B, 1, H, W]
        mask: [B, 1, H, W]
        """
        p = torch.sigmoid(pred)
        num_pos = torch.sum(mask)
        num_neg = mask.numel() - num_pos
        w_pos = (1 - p) ** self.gamma
        w_neg = p ** self.gamma
        loss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12)
        loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12)
        loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12)
        return loss

    def forward(self, logits, target):
        logits = logits.squeeze(1)
        target = target.squeeze(1)
        # Dice Loss
        dice_loss = self.dice_loss(logits, target)
        # Focal Loss
        focal_loss = self.focal_loss(logits, target.squeeze(-1))
        # Weight the focal term relative to the dice term.
        alpha = 20.0
        # Combined Loss
        combined_loss = alpha * focal_loss + dice_loss
        return combined_loss
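
# Quick sanity check (a sketch, not part of the test flow): the combined
# focal + dice loss should be a finite scalar for random logits and targets.
#     _logits = torch.randn(2, 1, 32, 32)
#     _target = torch.randint(0, 2, (2, 1, 32, 32)).float()
#     assert torch.isfinite(FocalLoss()(_logits, _target))
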
class loss_fn(torch.nn.Module):
    def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
        super(loss_fn, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def dice_loss(self, logits, gt, eps=1):
        # Convert logits to probabilities and flatten the tensors
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        gt = gt.view(-1)
        # Compute Dice coefficient
        intersection = (probs * gt).sum()
        dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
        # Compute Dice loss
        loss = 1 - dice_coeff
        return loss

    def focal_loss(self, logits, gt, gamma=4):
        logits = logits.reshape(-1, 1)
        gt = gt.reshape(-1, 1)
        # Build two-class scores and gather the probability of the true class.
        logits = torch.cat((1 - logits, logits), dim=1)
        probs = torch.sigmoid(logits)
        pt = probs.gather(1, gt.long())
        modulating_factor = (1 - pt) ** gamma
        focal_loss = -modulating_factor * torch.log(pt + 1e-12)
        # Compute the mean focal loss
        loss = focal_loss.mean()
        return loss

    def forward(self, logits, target):
        logits = logits.squeeze(1)
        target = target.squeeze(1)
        # Dice Loss
        dice_loss = self.dice_loss(logits, target)
        # Focal Loss
        focal_loss = self.focal_loss(logits, target.squeeze(-1))
        # Weight the focal term relative to the dice term.
        alpha = 20.0
        # Combined Loss
        combined_loss = alpha * focal_loss + dice_loss
        return combined_loss
def img_enhance(img2, coef=0.2):
    # Clip intensities to compress the dynamic range around the mean.
    img_mean = np.mean(img2)
    img_max = np.max(img2)
    val = (img_max - img_mean) * coef + img_mean
    img2[img2 < img_mean * 0.7] = img_mean * 0.7
    img2[img2 > val] = val
    return img2
def dice_coefficient(pred, target):
    smooth = 1  # Smoothing constant to avoid division by zero
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return dice.item()
def calculate_accuracy(pred, target):
    correct = (pred == target).sum().item()
    total = target.numel()
    return correct / total


def calculate_sensitivity(pred, target):
    # Also known as recall.
    smooth = 1
    true_positive = ((pred == 1) & (target == 1)).sum().item()
    false_negative = ((pred == 0) & (target == 1)).sum().item()
    return (true_positive + smooth) / ((true_positive + false_negative) + smooth)


def calculate_specificity(pred, target):
    smooth = 1
    true_negative = ((pred == 0) & (target == 0)).sum().item()
    false_positive = ((pred == 1) & (target == 0)).sum().item()
    return (true_negative + smooth) / ((true_negative + false_positive) + smooth)


# calculate_recall would be identical to calculate_sensitivity.
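
# Toy example of the metric helpers on hard binary masks (a sketch; the test
# loop below passes thresholded model outputs instead):
#     _pred = torch.tensor([[1, 0], [1, 1]])
#     _gt = torch.tensor([[1, 0], [0, 1]])
#     calculate_accuracy(_pred, _gt)     # 3 of 4 pixels match -> 0.75
#     calculate_sensitivity(_pred, _gt)  # (2 TP + 1) / (2 TP + 0 FN + 1) = 1.0
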
accumulative_batch_size = 512
batch_size = 2
num_workers = 4
slice_per_image = 1
num_epochs = 40
sample_size = 1800
# image_size = sam_model.image_encoder.img_size
image_size = 1024
exp_id = 0
found = 0
# if debug:
#     user_input = "debug"
# else:
#     user_input = input("Related changes: ")
# while found == 0:
#     try:
#         os.makedirs(f"exps/{exp_id}-{user_input}")
#         found = 1
#     except:
#         exp_id = exp_id + 1
# copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")

layer_n = 4
L = layer_n
a = np.full(L, layer_n)
params = {"M": 255, "a": a, "p": 0.35}

model_type = "vit_h"
checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
device = "cuda:0"
from segment_anything import SamPredictor, sam_model_registry


class panc_sam(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Load an already fine-tuned SAM instead of the raw registry checkpoint.
        # self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam = torch.load("sam_tuned_save.pth").sam

    def forward(self, batched_input):
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        # The image encoder stays frozen at inference time.
        with torch.no_grad():
            image_embeddings = self.sam.image_encoder(input_images).detach()
        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (
                    image_record["point_coords"].unsqueeze(0),
                    image_record["point_labels"].unsqueeze(0),
                )
            else:
                raise ValueError("point prompts are required in batched_input")
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                    points=points,
                    boxes=image_record.get("boxes", None),
                    masks=image_record.get("mask_inputs", None),
                )
            # Down-weight the prompt embeddings.
            sparse_embeddings = sparse_embeddings / 5
            dense_embeddings = dense_embeddings / 5
            low_res_masks, _ = self.sam.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.sam.prompt_encoder.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append({"low_res_logits": low_res_masks})
        low_res_masks = torch.stack([x["low_res_logits"] for x in outputs], dim=0)
        return low_res_masks.squeeze(1)
panc_sam_instance = panc_sam()
panc_sam_instance.to(device)
panc_sam_instance.eval()  # Set the model to evaluation mode
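
# The forward pass expects SAM-style batched input: one dict per image holding
# a 1024x1024 image tensor plus point prompts (a sketch with hypothetical
# shapes; a single foreground point at the image center):
#     batched_input = [{
#         "image": torch.rand(3, 1024, 1024, device=device),
#         "point_coords": torch.tensor([[512.0, 512.0]], device=device),
#         "point_labels": torch.tensor([1], device=device),
#         "original_size": (1024, 1024),
#     }]
#     low_res_logits = panc_sam_instance(batched_input)  # [1, 1, 256, 256]
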
test_dataset = PanDataset(
    [args.test_dir],
    [args.test_labels_dir],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=slice_per_image,
    train=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=test_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)
lr = 1e-4
max_lr = 1e-3
wd = 5e-4
optimizer = torch.optim.Adam(
    list(panc_sam_instance.sam.mask_decoder.parameters()),
    lr=lr,
    weight_decay=wd,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    epochs=num_epochs,
    steps_per_epoch=sample_size // (accumulative_batch_size // batch_size),
)
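# With gradient accumulation, the optimizer steps once every
# accumulative_batch_size / batch_size = 512 / 2 = 256 mini-batches, which is
# why the scheduler is sized for sample_size // 256 steps per epoch.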
loss_function = loss_fn(alpha=0.5, gamma=2.0)
loss_function.to(device)
from statistics import mean
from torch.nn.functional import threshold, normalize
def process_model(data_loader, train=False, save_output=0):
    epoch_losses = []
    index = 0
    results = torch.zeros((2, 0, 256, 256))
    total_dice = 0.0
    total_accuracy = 0.0
    total_sensitivity = 0.0
    total_specificity = 0.0
    num_samples = 0
    counterb = 0
    for image, label in tqdm(data_loader, total=sample_size):
        counterb += 1
        index += 1
        image = image.to(device)
        label = label.to(device).float()
        # Sample point prompts from the ground-truth mask.
        points, point_labels = sample_prompt(label)
        batched_input = []
        for ibatch in range(batch_size):
            batched_input.append(
                {
                    "image": image[ibatch],
                    "point_coords": points[ibatch],
                    "point_labels": point_labels[ibatch],
                    "original_size": (1024, 1024),
                }
            )
        low_res_masks = panc_sam_instance(batched_input)
        low_res_label = F.interpolate(label, low_res_masks.shape[-2:])
        binary_mask = normalize(threshold(low_res_masks, 0.0, 0))
        loss = loss_function(low_res_masks, low_res_label)
        loss /= accumulative_batch_size / batch_size
        opened_binary_mask = torch.zeros_like(binary_mask).cpu()
        for j, mask in enumerate(binary_mask[:, 0]):
            numpy_mask = mask.detach().cpu().numpy().astype(np.uint8)
            opened_binary_mask[j][0] = torch.from_numpy(numpy_mask)
        dice = dice_coefficient(
            opened_binary_mask.numpy(), low_res_label.cpu().detach().numpy()
        )
        accuracy = calculate_accuracy(binary_mask, low_res_label)
        sensitivity = calculate_sensitivity(binary_mask, low_res_label)
        specificity = calculate_specificity(binary_mask, low_res_label)
        total_accuracy += accuracy
        total_sensitivity += sensitivity
        total_specificity += specificity
        total_dice += dice
        num_samples += 1
        average_dice = total_dice / num_samples
        average_accuracy = total_accuracy / num_samples
        average_sensitivity = total_sensitivity / num_samples
        average_specificity = total_specificity / num_samples
        if train:
            loss.backward()
            if index % (accumulative_batch_size / batch_size) == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                index = 0
        else:
            # Collect logits and thresholded masks for inspection.
            result = torch.cat(
                (
                    low_res_masks[0].detach().cpu().reshape(1, 1, 256, 256),
                    opened_binary_mask[0].reshape(1, 1, 256, 256),
                ),
                dim=0,
            )
            results = torch.cat((results, result), dim=1)
        if index % (accumulative_batch_size / batch_size) == 0:
            epoch_losses.append(loss.item())
        if counterb == sample_size:
            break
    return (
        epoch_losses,
        results,
        average_dice,
        average_accuracy,
        average_sensitivity,
        average_specificity,
    )
def test_model(test_loader):
    print("Testing started.")
    test_losses = []
    dice_test = []
    results = []
    (
        test_epoch_losses,
        epoch_results,
        average_dice_test,
        average_accuracy_test,
        average_sensitivity_test,
        average_specificity_test,
    ) = process_model(test_loader)
    test_losses.append(test_epoch_losses)
    results.append(epoch_results)
    dice_test.append(average_dice_test)
    print(dice_test)
    return test_losses, results


test_losses, results = test_model(test_loader)
def double_conv_3d(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class UNet3D(nn.Module):
    def __init__(self):
        super(UNet3D, self).__init__()
        self.dconv_down1 = double_conv_3d(3, 64)
        self.dconv_down2 = double_conv_3d(64, 128)
        self.dconv_down3 = double_conv_3d(128, 256)
        self.dconv_down4 = double_conv_3d(256, 512)
        self.maxpool = nn.MaxPool3d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.dconv_up3 = double_conv_3d(256 + 512, 256)
        self.dconv_up2 = double_conv_3d(128 + 256, 128)
        self.dconv_up1 = double_conv_3d(128 + 64, 64)
        self.conv_last = nn.Conv3d(64, 1, kernel_size=1)

    def forward(self, x):
        # Encoder: three downsampling stages, keeping features for skip connections.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        # Decoder: upsample and concatenate the matching encoder features.
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out
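
# Smoke test (a sketch; spatial dims must be divisible by 8 to survive the
# three pooling/upsampling stages):
#     net = UNet3D()
#     out = net(torch.rand(1, 3, 16, 64, 64))
#     assert out.shape == (1, 1, 16, 64, 64)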