| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- from pathlib import Path
- import numpy as np
- import matplotlib.pyplot as plt
- from utils import main_prompt , sample_prompt
- from collections import defaultdict
- import torchvision.transforms as transforms
- import torch
- from torch import nn
- import torch.nn.functional as F
- from segment_anything.utils.transforms import ResizeLongestSide
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
- from einops import rearrange
- import random
- from tqdm import tqdm
- from time import sleep
- from data import *
- from time import time
- from PIL import Image
- from sklearn.model_selection import KFold
- from shutil import copyfile
- from torch.nn.functional import threshold, normalize
- import torchvision.transforms.functional as TF
- from args import get_arguments
-
# Parse command-line options once at import time; args.test_dir and
# args.test_labels_dir are passed to PanDataset further down.
args = get_arguments()
-
def dice_coefficient(logits, gt):
    """Return the batch-mean, add-one-smoothed Dice score of ``logits > 0`` vs ``gt``.

    Overlap and areas are summed over the last two (spatial) dimensions, so one
    score is produced per leading index; the result is their mean as a 0-d tensor.
    """
    smoothing = 1
    spatial_dims = (-2, -1)
    predicted = logits > 0

    overlap = (predicted * gt).sum(dim=spatial_dims)
    denominator = predicted.sum(dim=spatial_dims) + gt.sum(dim=spatial_dims) + smoothing
    per_item = (2.0 * overlap + smoothing) / denominator
    return per_item.mean()
-
-
def calculate_recall(pred, target):
    """Return the mean over dim 0 of the add-one-smoothed recall of ``pred > 0``.

    Per item: ``(TP + 1) / (TP + FN + 1)``, where ground-truth positives are the
    elements with ``target == 1``. Returns a Python float; raises
    ``ZeroDivisionError`` when ``pred`` has an empty first dimension.
    """
    smooth = 1
    predicted = pred > 0
    per_item = []
    for idx in range(pred.shape[0]):
        actual_pos = target[idx] == 1
        hits = (predicted[idx] & actual_pos).sum().item()
        misses = (~predicted[idx] & actual_pos).sum().item()
        per_item.append((hits + smooth) / ((hits + misses) + smooth))
    return sum(per_item) / len(per_item)
-
def calculate_precision(pred, target):
    """Return the mean over dim 0 of the add-one-smoothed precision of ``pred > 0``.

    Per item: ``(TP + 1) / (TP + FP + 1)``; a predicted positive counts as TP
    where ``target == 1`` and as FP where ``target == 0``. Returns a Python
    float; raises ``ZeroDivisionError`` when ``pred`` has an empty first dimension.
    """
    smooth = 1
    predicted = pred > 0
    per_item = []
    for idx in range(pred.shape[0]):
        fired = predicted[idx]
        tp = (fired & (target[idx] == 1)).sum().item()
        fp = (fired & (target[idx] == 0)).sum().item()
        per_item.append((tp + smooth) / ((tp + fp) + smooth))
    return sum(per_item) / len(per_item)
-
def calculate_jaccard(pred, target):
    """Return the mean over dim 0 of the add-one-smoothed Jaccard index of ``pred > 0``.

    Per item: ``(TP + 1) / (TP + FP + FN + 1)`` with positives at
    ``target == 1`` and negatives at ``target == 0``. Returns a Python float;
    raises ``ZeroDivisionError`` when ``pred`` has an empty first dimension.
    """
    smooth = 1
    predicted = pred > 0
    per_item = []
    for idx in range(pred.shape[0]):
        fired = predicted[idx]
        positive = target[idx] == 1
        negative = target[idx] == 0
        tp = (fired & positive).sum().item()
        fp = (fired & negative).sum().item()
        fn = (~fired & positive).sum().item()
        per_item.append((tp + smooth) / (tp + fp + fn + smooth))
    return sum(per_item) / len(per_item)
-
def calculate_specificity(pred, target):
    """Return the mean over dim 0 of the add-one-smoothed specificity of ``pred > 0``.

    Per item: ``(TN + 1) / (TN + FP + 1)``, counted over elements with
    ``target == 0``. Returns a Python float; raises ``ZeroDivisionError`` when
    ``pred`` has an empty first dimension.
    """
    smooth = 1
    predicted = pred > 0
    per_item = []
    for idx in range(pred.shape[0]):
        actual_neg = target[idx] == 0
        tn = (~predicted[idx] & actual_neg).sum().item()
        fp = (predicted[idx] & actual_neg).sum().item()
        per_item.append((tn + smooth) / (tn + fp + smooth))
    return sum(per_item) / len(per_item)
-
def what_the_f(low_res_masks, label):
    """Score mask logits against a label resized to the masks' spatial size.

    The label is resampled with ``F.interpolate`` (default mode) to
    ``low_res_masks.shape[-2:]`` before scoring.

    Returns:
        ``(dice, precision, recall, jaccard)`` -- precision comes before recall.
    """
    resized_label = F.interpolate(label, low_res_masks.shape[-2:])

    dice_score = dice_coefficient(low_res_masks, resized_label)
    recall_score = calculate_recall(low_res_masks, resized_label)
    precision_score = calculate_precision(low_res_masks, resized_label)
    jaccard_score = calculate_jaccard(low_res_masks, resized_label)

    return (dice_score, precision_score, recall_score, jaccard_score)
def save_img(img, dir):
    """Min-max scale a tensor into 0..255 and write it to ``dir`` with PIL.

    A 3-D tensor is treated as channels-first (``c h w``): it is moved to
    ``h w c`` and each channel is scaled independently to ``uint8``. Any other
    rank (e.g. a 2-D mask) gets one global min-max scaling and is saved as an
    8-bit grayscale ("L") image. A channel or image with zero range is left
    at 0 instead of being divided by zero.

    Args:
        img: torch tensor to save; it is copied and never modified.
        dir: destination path handed to ``PIL.Image.Image.save``.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    img = img.detach().clone().cpu().numpy()

    if len(img.shape) == 3:
        img = rearrange(img, "c h w -> h w c")
        img = img - np.amin(img, axis=(0, 1), keepdims=True)
        img_max = np.amax(img, axis=(0, 1), keepdims=True)
        # A constant channel has max 0 after the shift; dividing by it gave NaN
        # and an undefined uint8 cast. Divide by 1 instead so it stays 0.
        safe_max = np.where(img_max == 0, 1, img_max)
        img = (img / safe_max * 255).astype(np.uint8)
        img = Image.fromarray(img)
    else:
        img = img - img.min()
        img_max = img.max()
        if img_max != 0:
            img = img / img_max * 255
        img = Image.fromarray(img).convert("L")

    img.save(dir)
-
# Dataset configuration handed to PanDataset below (as slice_per_image and
# image_size respectively).
slice_per_image, image_size = 1, 1024
class panc_sam(nn.Module):
    """Two-stage SAM segmenter: a promptless pass proposes a mask, and point
    prompts derived from that mask drive a second, prompted pass.

    Submodules:
        prompt_encoder, mask_decoder: taken from the ``.sam`` attribute of the
            model pickled in ``sam_tuned_save.pth``; frozen (requires_grad=False).
        image_encoder, prompt_encoder2: from a model built by
            ``sam_model_registry[model_type](checkpoint=checkpoint)``; frozen.
        mask_decoder2: from that same registry model; not frozen here, so it
            keeps the ``requires_grad`` flags it was built with.

    Both stages decode the embeddings produced by ``self.image_encoder``.
    ``sam_model_registry``, ``model_type`` and ``checkpoint`` are not defined
    in the visible code; presumably they come from ``from data import *`` --
    confirm.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Stage 1 (promptless). The checkpoint is a fully pickled model, which
        # PyTorch >= 2.6 refuses to load under its new default
        # weights_only=True. Opting out requires PyTorch >= 1.13 and executes
        # pickle code, so only load trusted files.
        sam = torch.load("sam_tuned_save.pth", weights_only=False).sam

        self.prompt_encoder = sam.prompt_encoder
        self.mask_decoder = sam.mask_decoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False
        for param in self.mask_decoder.parameters():
            param.requires_grad = False

        # Stage 2 (point-prompted): components of a registry-built SAM model.
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.image_encoder = sam.image_encoder
        self.prompt_encoder2 = sam.prompt_encoder
        self.mask_decoder2 = sam.mask_decoder

        for param in self.image_encoder.parameters():
            param.requires_grad = False
        for param in self.prompt_encoder2.parameters():
            param.requires_grad = False

    def forward(self, input_images, box=None):
        """Run both stages on a batch of images.

        Args:
            input_images: batch fed to ``self.image_encoder``; every item of
                the resulting embedding batch is decoded separately.
            box: unused; accepted so callers may pass a box prompt.

        Returns:
            ``(low_res_masks, low_res_masks_prompt)``: the stage-2 (prompted,
            ``mask_decoder2``) and stage-1 (promptless, ``mask_decoder``) mask
            logits, each concatenated over the batch along dim 0.
        """
        with torch.no_grad():
            image_embeddings = self.image_encoder(input_images).detach()
            # None of these depend on the image, so compute them once per call
            # rather than once per batch item.
            empty_sparse, empty_dense = self.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )
            empty_sparse = empty_sparse.detach()
            empty_dense = empty_dense.detach()
            image_pe = self.prompt_encoder.get_dense_pe().detach()
            image_pe2 = self.prompt_encoder2.get_dense_pe().detach()

        outputs_prompt = []  # stage-1 (promptless) masks
        outputs = []  # stage-2 (point-prompted) masks

        for curr_embedding in image_embeddings:
            # Stage 1: decode with no prompt (grad mode left to the caller).
            low_res_masks, _ = self.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=empty_sparse,
                dense_prompt_embeddings=empty_dense,
                multimask_output=False,
            )
            outputs_prompt.append(low_res_masks)

            # Derive point prompts from the stage-1 logits. The x4 factor
            # presumably maps 256x256 mask coordinates onto the 1024-pixel
            # input frame -- confirm against main_prompt.
            points, point_labels = main_prompt(low_res_masks)
            points = (points * 4, point_labels)

            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder2(
                    points=points,
                    boxes=None,
                    masks=None,
                )

            # Stage 2: decode the same embedding using the derived point prompts.
            low_res_masks, _ = self.mask_decoder2(
                image_embeddings=curr_embedding,
                image_pe=image_pe2,
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append(low_res_masks)

        low_res_masks_prompt = torch.cat(outputs_prompt, dim=0)
        low_res_masks = torch.cat(outputs, dim=0)
        return low_res_masks, low_res_masks_prompt
-
# Evaluation dataset: one image directory and one label directory from the CLI.
test_dataset = PanDataset(
    [args.test_dir],
    [args.test_labels_dir],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=slice_per_image,
    train=False,
)

# Fall back to CPU so the script still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The checkpoint is a fully pickled model, not a state_dict. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses to unpickle it, so
# opt out explicitly (requires PyTorch >= 1.13). Unpickling can execute
# arbitrary code -- only load trusted files.
x = torch.load('sam_tuned_save.pth', map_location='cpu', weights_only=False)
# Rebind so this stays correct even if x were ever a tensor (tensor.to is not in-place).
x = x.to(device)
-
def process_model(main_model, test_dataset, train=0, save_output=0, max_samples=200):
    """Evaluate ``main_model`` on ``test_dataset`` and collect predicted masks.

    For each sample, ``main_model(image, box)`` returns two mask-logit tensors.
    ``main`` is the first and ``prompt`` the second, which for ``panc_sam`` are
    the prompted and promptless stages. Dice, precision, recall and Jaccard are
    accumulated for both against the label interpolated to the mask size.

    Args:
        main_model: callable ``(image, box) -> (masks_main, masks_prompt)``.
        test_dataset: sized iterable yielding ``(image, label, raw_data)``.
        train: unused; kept for backward compatibility.
        save_output: unused; kept for backward compatibility.
        max_samples: stop after this many samples (default 200, the former
            hard-coded limit); ``None`` evaluates the whole dataset.

    Returns:
        An 11-tuple ``(epoch_losses, results, results_prompt, average_dice,
        average_precision, average_recall, average_jaccard, average_dice_main,
        average_precision_main, average_recall_main, average_jaccard_main)``.
        The un-suffixed averages are for the ``prompt`` output and the
        ``_main`` ones for the ``main`` output. ``results`` and
        ``results_prompt`` have shape ``(2, N, 256, 256)``: index 0 holds the
        raw logits and index 1 the thresholded mask of the first item of each
        sample. ``epoch_losses`` is always empty. Averages are 0.0 when no
        sample was processed (previously an UnboundLocalError).
    """
    from itertools import islice

    epoch_losses = []  # no loss is computed here; returned for interface compatibility
    results = torch.zeros((2, 0, 256, 256))
    results_prompt = torch.zeros((2, 0, 256, 256))
    main_chunks = []
    prompt_chunks = []

    total_dice = total_precision = total_recall = total_jaccard = 0.0
    total_dice_main = total_precision_main = total_recall_main = total_jaccard_main = 0.0
    num_samples = 0

    # The box is identical for every sample, so build it once.
    box = torch.tensor([[200, 200, 750, 800]]).to(device)

    total = len(test_dataset) if max_samples is None else min(len(test_dataset), max_samples)
    for image, label, _raw_data in tqdm(islice(test_dataset, max_samples), total=total):
        num_samples += 1
        image = image.to(device)
        label = label.to(device).float()

        # Pure evaluation: don't build an autograd graph for trainable submodules.
        with torch.no_grad():
            low_res_masks_main, low_res_masks_prompt = main_model(image, box)

        low_res_label = F.interpolate(label, low_res_masks_main.shape[-2:])

        dice_prompt, precision_prompt, recall_prompt, jaccard_prompt = what_the_f(
            low_res_masks_prompt, low_res_label
        )
        dice_main, precision_main, recall_main, jaccard_main = what_the_f(
            low_res_masks_main, low_res_label
        )

        total_dice += dice_prompt
        total_precision += precision_prompt
        total_recall += recall_prompt
        total_jaccard += jaccard_prompt

        total_dice_main += dice_main
        total_precision_main += precision_main
        total_recall_main += recall_main
        total_jaccard_main += jaccard_main

        # threshold zeroes non-positive logits; normalize (L2 over dim 1)
        # presumably maps the rest to ~1 when dim 1 has size 1 -- confirm.
        binary_mask = normalize(threshold(low_res_masks_main, 0.0, 0))
        binary_mask_prompt = normalize(threshold(low_res_masks_prompt, 0.0, 0))

        # Keep only the first item of the batch: (raw logits, binary mask) on dim 0.
        main_chunks.append(
            torch.cat(
                (
                    low_res_masks_main[0].detach().cpu().reshape(1, 1, 256, 256),
                    binary_mask[0].detach().cpu().reshape(1, 1, 256, 256),
                ),
                dim=0,
            )
        )
        prompt_chunks.append(
            torch.cat(
                (
                    low_res_masks_prompt[0].detach().cpu().reshape(1, 1, 256, 256),
                    binary_mask_prompt[0].detach().cpu().reshape(1, 1, 256, 256),
                ),
                dim=0,
            )
        )

    # Concatenate once instead of re-copying the growing stack every iteration.
    results = torch.cat([results, *main_chunks], dim=1)
    results_prompt = torch.cat([results_prompt, *prompt_chunks], dim=1)

    def _average(total_value):
        # Guard the empty-dataset case instead of dividing by zero.
        return total_value / num_samples if num_samples else 0.0

    return (
        epoch_losses,
        results,
        results_prompt,
        _average(total_dice),
        _average(total_precision),
        _average(total_recall),
        _average(total_jaccard),
        _average(total_dice_main),
        _average(total_precision_main),
        _average(total_recall_main),
        _average(total_jaccard_main),
    )
-
-
import os
from itertools import islice

print("Testing:")
(
    test_epoch_losses,
    epoch_results,
    results_prompt,
    average_dice_test,
    average_precision,
    average_recall,
    average_jaccard,
    average_dice_test_main,
    average_precision_main,
    average_recall_main,
    average_jaccard_main,
) = process_model(x, test_dataset)

dice_test = []
dice_test_main = []

##############################save image#########################################
# Save at most 200 samples, and never more than process_model returned results for.
num_to_save = min(200, epoch_results.shape[1])
if num_to_save > 0:
    # The result stacks do not change inside the loop, so resize them once
    # instead of re-resizing the whole stack on every iteration.
    model_result_resized = TF.resize(epoch_results, size=(1024, 1024))
    model_result_resized_prompt = TF.resize(results_prompt, size=(1024, 1024))

for index, (image, label, raw_data) in enumerate(islice(tqdm(test_dataset), num_to_save)):
    # makedirs also creates the parent result_img/ and tolerates existing dirs.
    os.makedirs(f"result_img/batch_{index}", exist_ok=True)

    save_img(
        image[0],
        f"result_img/batch_{index}/img.png",
    )
    tensor_raw = torch.tensor(raw_data)
    save_img(
        tensor_raw.T,
        f"result_img/batch_{index}/raw_img.png",
    )

    # Overlay: channel 1 shows the label, channel 0 the thresholded prediction.
    result_canvas = torch.zeros_like(image[0])
    result_canvas[1] = label[0][0]
    result_canvas[0] = model_result_resized[1, index]
    blended_result = 0.2 * image[0] + 0.5 * result_canvas

    result_canvas_prompt = torch.zeros_like(image[0])
    result_canvas_prompt[1] = label[0][0]
    result_canvas_prompt[0] = model_result_resized_prompt[1, index]
    blended_result_prompt = 0.2 * image[0] + 0.5 * result_canvas_prompt

    save_img(blended_result, f"result_img/batch_{index}/comb.png")
    save_img(blended_result_prompt, f"result_img/batch_{index}/comb_prompt.png")

    save_img(
        epoch_results[1, index].clone(), f"result_img/batch_{index}/modelresult.png",
    )
    save_img(
        epoch_results[0, index].clone(), f"result_img/batch_{index}/prob_epoch_{index}.png",
    )


dice_test.append(average_dice_test)
dice_test_main.append(average_dice_test_main)
print("######################Prompt##########################")
print(f"Test Dice : {average_dice_test}")
print(f"Test presision : {average_precision}")
print(f"Test recall : {average_recall}")
print(f"Test jaccard : {average_jaccard}")

print("######################Main##########################")
print(f"Test Dice main : {average_dice_test_main}")
print(f"Test presision main : {average_precision_main}")
print(f"Test recall main : {average_recall_main}")
print(f"Test jaccard main : {average_jaccard_main}")
|