| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- from pathlib import Path
- import numpy as np
- import matplotlib.pyplot as plt
-
- # import cv2
- from collections import defaultdict
- import torchvision.transforms as transforms
- import torch
- from torch import nn
-
- import torch.nn.functional as F
- from segment_anything.utils.transforms import ResizeLongestSide
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
- import numpy as np
-
- import random
- from tqdm import tqdm
- from time import sleep
- from data_load_group import *
- from time import time
- from PIL import Image
- from sklearn.model_selection import KFold
- from shutil import copyfile
- # import monai
- from utils import *
- from torch.autograd import Variable
- from Models.model_handler import panc_sam
- from Models.model_handler import UNet3D
- from Models.model_handler import Conv3DFilter
- from loss import *
- from args import get_arguments
- from segment_anything import SamPredictor, sam_model_registry
- from statistics import mean
- from copy import deepcopy
- from torch.nn.functional import threshold, normalize
- def calculate_recall(pred, target):
- smooth = 1
- batch_size = 1
- recall_scores = []
-
- binary_mask = pred>0
-
- true_positive = ((pred == 1) & (target == 1)).sum().item()
- false_negative = ((pred == 0) & (target == 1)).sum().item()
- recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
-
-
-
- return recall
-
- def calculate_precision(pred, target):
- smooth = 1
- batch_size = 1
- recall_scores = []
-
- binary_mask = pred>0
-
- true_positive = ((pred == 1) & (target == 1)).sum().item()
- false_negative = ((pred == 1) & (target == 0)).sum().item()
- recall = (true_positive + smooth) / ((true_positive + false_negative) + smooth)
-
-
-
- return recall
-
- def calculate_jaccard(pred, target):
- smooth = 1
- batch_size = pred.shape[0]
- jaccard_scores = []
- binary_mask = pred>0
- true_positive = ((pred == 1) & (target == 1)).sum().item()
- false_negative = ((pred == 0) & (target == 1)).sum().item()
- true_positive = ((pred == 1) & (target == 1)).sum().item()
- false_positive = ((pred == 1) & (target == 0)).sum().item()
- jaccard = (true_positive + smooth) / (true_positive + false_positive + false_negative + smooth)
- # jaccard_scores.append(jaccard)
- # for i in range(batch_size):
- # true_positive = ((binary_mask[i] == 1) & (target[i] == 1)).sum().item()
- # false_positive = ((binary_mask[i] == 1) & (target[i] == 0)).sum().item()
- # false_negative = ((binary_mask[i] == 0) & (target[i] == 1)).sum().item()
-
-
- return jaccard
-
- def save_img(img, dir):
- img = img.clone().cpu().numpy() + 100
-
- if len(img.shape) == 3:
- img = rearrange(img, "c h w -> h w c")
- img_min = np.amin(img, axis=(0, 1), keepdims=True)
- img = img - img_min
-
- img_max = np.amax(img, axis=(0, 1), keepdims=True)
- img = (img / img_max * 255).astype(np.uint8)
- # grey_img = Image.fromarray(img[:, :, 0])
-
- img = Image.fromarray(img)
-
- else:
- img_min = img.min()
- img = img - img_min
- img_max = img.max()
- if img_max != 0:
- img = img / img_max * 255
-
- img = Image.fromarray(img).convert("L")
-
- img.save(dir)
-
- def seed_everything(seed: int):
- import random, os
- import numpy as np
- import torch
-
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = True
-
- seed_everything(2024)
-
-
- global optimizer
- args = get_arguments()
-
-
- exp_id = 0
- found = 0
- user_input = args.run_name
- while found == 0:
- try:
- os.makedirs(f"exps/{exp_id}-{user_input}")
- found = 1
- except:
- exp_id = exp_id + 1
- copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")
-
- augmentation = A.Compose(
- [
- A.Rotate(limit=100, p=0.7),
- A.RandomScale(scale_limit=0.3, p=0.5),
- ]
- )
- device = "cuda:0"
-
-
-
- panc_sam_instance = torch.load(args.model_path)
- panc_sam_instance.to(device)
-
- conv3d_instance = UNet3D()
- kernel_size = [(1, 5, 5), (5, 5, 5), (5, 5, 5), (5, 5, 5)]
-
- conv3d_instance = Conv3DFilter(
- 1,
- 5,
- kernel_size,
- np.array(kernel_size) // 2,
- custom_bias=args.custom_bias,
- )
-
- conv3d_instance.to(device)
- conv3d_instance.train()
-
- train_dataset = PanDataset(
- dirs=[f"{args.train_dir}/train"],
- datasets=[["NIH_PNG", 1]],
- target_image_size=args.image_size,
- slice_per_image=args.slice_per_image,
- train=True,
- val=False, # Enable validation data splitting
- augmentation=augmentation,
- )
-
- val_dataset = PanDataset(
- [f"{args.train_dir}/train"],
- [["NIH_PNG", 1]],
- args.image_size,
- slice_per_image=args.slice_per_image,
- val=True
- )
-
- test_dataset = PanDataset(
- [f"{args.test_dir}/test"],
- [["NIH_PNG", 1]],
- args.image_size,
- slice_per_image=args.slice_per_image,
- train=False,
- )
- train_loader = DataLoader(
- train_dataset,
- batch_size=args.batch_size,
- collate_fn=train_dataset.collate_fn,
- shuffle=True,
- drop_last=False,
- num_workers=args.num_workers,
- )
-
-
- val_loader = DataLoader(
- val_dataset,
- batch_size=args.batch_size,
- collate_fn=val_dataset.collate_fn,
- shuffle=False,
- drop_last=False,
- num_workers=args.num_workers,
- )
-
- test_loader = DataLoader(
- test_dataset,
- batch_size=args.batch_size,
- collate_fn=test_dataset.collate_fn,
- shuffle=False,
- drop_last=False,
- num_workers=args.num_workers,
- )
-
-
- # Set up the optimizer, hyperparameter tuning will improve performance here
- lr = args.lr
- max_lr = 3e-4
- wd = 5e-4
-
-
- all_parameters = list(conv3d_instance.parameters())
-
- optimizer = torch.optim.Adam(
- all_parameters,
- lr=lr,
- weight_decay=wd,
- )
-
-
- scheduler = torch.optim.lr_scheduler.StepLR(
- optimizer, step_size=10, gamma=0.7, verbose=True
- )
-
- loss_function = loss_fn(alpha=0.5, gamma=2.0)
- loss_function.to(device)
-
- from time import time
- import time as s_time
-
- log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
-
-
- def process_model(data_loader, train=0, save_output=0, epoch=None, scheduler=None):
- epoch_losses = []
-
- index = 0
- results = []
- dice_sam_lists = []
- dice_sam_prompt_lists = []
- dice_lists = []
- dice_prompt_lists = []
-
- num_samples = 0
-
- counterb = 0
-
- for batch in tqdm(data_loader, total=args.sample_size // args.batch_size):
-
- for batched_input in batch:
-
- num_samples += 1
-
- # raise ValueError(len(batched_input))
- low_res_masks = torch.zeros((1, 1, 0, 256, 256))
- # s_time.sleep(0.6)
- counterb += len(batch)
-
- index += 1
- label = []
- label = [i["label"] for i in batched_input]
-
- # Only correct if gray scale
- label = torch.cat(label, dim=1)
- # raise ValueError(la)
- label = label.float()
-
- true_indexes = torch.where((torch.amax(label, dim=(2, 3)) > 0).view(-1))[0]
-
- low_res_label = F.interpolate(label, low_res_masks.shape[-2:]).to("cuda:0")
- low_res_masks, low_res_masks_promtp = panc_sam_instance(
- batched_input, device
- )
-
- low_res_shape = low_res_masks.shape[-2:]
- low_res_label_prompt=low_res_label
- if train:
- transformed = augmentation(
- image=low_res_masks_promtp[0].permute(1, 2, 0).cpu().numpy(),
- mask=low_res_label[0].permute(1, 2, 0).cpu().numpy(),
- )
-
- low_res_masks_promtp = (
- torch.tensor(transformed["image"])
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device)
- )
-
- low_res_label_prompt = (
- torch.tensor(transformed["mask"])
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device)
- )
-
- transformed = augmentation(
- image=low_res_masks[0].permute(1, 2, 0).cpu().numpy(),
- mask=low_res_label[0].permute(1, 2, 0).cpu().numpy(),
- )
-
- low_res_masks = (
- torch.tensor(transformed["image"])
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device)
- )
-
- low_res_label = (
- torch.tensor(transformed["mask"])
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device)
- )
- low_res_masks = F.interpolate(low_res_masks, low_res_shape).to(device)
- low_res_label = F.interpolate(low_res_label, low_res_shape).to(device)
-
- low_res_masks = F.interpolate(low_res_masks, low_res_shape).to(device)
- low_res_label = F.interpolate(low_res_label, low_res_shape).to(device)
- low_res_masks_promtp = F.interpolate(
- low_res_masks_promtp, low_res_shape
- ).to(device)
- low_res_label_prompt = F.interpolate(
- low_res_label_prompt, low_res_shape
- ).to(device)
- low_res_masks = low_res_masks.detach()
- low_res_masks_promtp = low_res_masks_promtp.detach()
-
- dice_sam = dice_coefficient(low_res_masks , low_res_label).detach().cpu()
- dice_sam_prompt = (
- dice_coefficient(low_res_masks_promtp, low_res_label_prompt)
- .detach()
- .cpu()
- )
- low_res_masks_promtp = conv3d_instance(
- low_res_masks_promtp.detach().to(device)
- )
-
- loss = loss_function(low_res_masks_promtp, low_res_masks_promtp)
- loss /= args.accumulative_batch_size / args.batch_size
-
- binary_mask = low_res_masks > 0.5
- binary_mask_prompt = low_res_masks_promtp > 0.5
-
- dice = dice_coefficient(binary_mask, low_res_label).detach().cpu()
- dice_prompt = (
- dice_coefficient(binary_mask_prompt, low_res_label_prompt)
- .detach()
- .cpu()
- )
-
- dice_sam_lists.append(dice_sam)
- dice_sam_prompt_lists.append(dice_sam_prompt)
- dice_lists.append(dice)
- dice_prompt_lists.append(dice_prompt)
-
- log_file.flush()
- if train:
- loss.backward()
- if index % (args.accumulative_batch_size / args.batch_size) == 0:
-
- optimizer.step()
- # if epoch==40:
- # scheduler.step()
- optimizer.zero_grad()
- index = 0
-
- else:
-
- result = torch.cat(
- (
- low_res_masks[:, ::10].detach().cpu().reshape(1, -1, 256, 256),
- binary_mask[:, ::10].detach().cpu().reshape(1, -1, 256, 256),
- ),
- dim=0,
- )
- results.append(result)
-
- if index % (args.accumulative_batch_size / args.batch_size) == 0:
- epoch_losses.append(loss.item())
- if counterb == (args.sample_size // args.batch_size) and train:
- break
-
-
- return (
- epoch_losses,
- results,
- dice_lists,
- dice_prompt_lists,
- dice_sam_lists,
- dice_sam_prompt_lists,
- )
-
-
- def train_model(train_loader, val_loader,test_loader, K_fold=False, N_fold=7, epoch_num_start=7):
- global optimizer
- index=0
- if args.inference:
- with torch.no_grad():
- conv = torch.load(f'{args.conv_path}')
-
- recall_list=[]
- percision_list=[]
- jaccard_list=[]
-
- for input in tqdm(test_loader):
-
-
- low_res_masks_sam, low_res_masks_promtp_sam = panc_sam_instance(
- input[0], device
- )
- low_res_masks_sam = F.interpolate(low_res_masks_sam, 512).cpu()
- low_res_masks_promtp_sam = F.interpolate(low_res_masks_promtp_sam, 512).cpu()
- low_res_masks_promtp = conv(low_res_masks_promtp_sam.to(device)).detach().cpu()
-
- for slice_id,(batched_input,mask_sam,mask_prompt_sam,mask_prompt) in enumerate(zip(input[0],low_res_masks_sam[0],low_res_masks_promtp_sam[0],low_res_masks_promtp[0])):
-
- if not os.path.exists(f"ims/batch_{index}"):
- os.mkdir(f"ims/batch_{index}")
- image = batched_input["image"]
-
-
- label = batched_input["label"][0,0,::2,::2].to(bool)
- binary_mask_sam = (mask_sam > 0)
- binary_mask_prompt_sam = (mask_prompt_sam > 0)
- binary_mask_prompt = (mask_prompt > 0.5)
- recall = calculate_recall(label, binary_mask_prompt)
- percision = calculate_precision(label, binary_mask_prompt)
- jaccard = calculate_jaccard(label, binary_mask_prompt)
- percision_list.append(percision)
- recall_list.append(recall)
- jaccard_list.append(jaccard)
- image_mask = image.clone().to(torch.long)
- image_label = image.clone().to(torch.long)
- image_mask[binary_mask_sam]=255
- image_label[label]=255
- save_img(
- torch.stack((image_mask,image_label,image),dim=0),
- f"ims/batch_{index}/sam{slice_id}.png",
- )
-
- image_mask = image.clone().to(torch.long)
- image_mask[binary_mask_prompt_sam]=255
-
- save_img(
- torch.stack((image_mask,image_label,image),dim=0),
- f"ims/batch_{index}/sam_prompt{slice_id}.png",
- )
-
- image_mask = image.clone().to(torch.long)
- image_mask[binary_mask_prompt]=255
- save_img(
- torch.stack((image_mask,image_label,image),dim=0),
- f"ims/batch_{index}/prompt_{slice_id}.png",
- )
- print(f'Recall={np.mean(recall_list)}')
- print(f'Percision={np.mean(percision_list)}')
- print(f'Jaccard={np.mean(jaccard_list)}')
- index += 1
- print(f'Recall={np.mean(recall_list)}')
- print(f'Percision={np.mean(percision_list)}')
- print(f'Jaccard={np.mean(jaccard_list)}')
- else:
-
- print("Train model started.")
-
- train_losses = []
- train_epochs = []
- val_losses = []
- val_epochs = []
- dice = []
- dice_val = []
- results = []
- last_best_dice=0
- for epoch in range(args.num_epochs):
-
- print(f"=====================EPOCH: {epoch + 1}=====================")
- log_file.write(f"=====================EPOCH: {epoch + 1}===================\n")
-
- print("Training:")
- (
- train_epoch_losses,
- results,
- dice_list,
- dice_prompt_list,
- dice_sam_list,
- dice_sam_prompt_list,
- ) = process_model(train_loader, train=1, epoch=epoch, scheduler=scheduler)
-
- dice_mean = np.mean(dice_list)
- dice_prompt_mean = np.mean(dice_prompt_list)
- dice_sam_mean = np.mean(dice_sam_list)
- dice_sam_prompt_mean = np.mean(dice_sam_prompt_list)
-
- print("Validating:")
- (
- _,
- _,
- val_dice_list,
- val_dice_prompt_list,
- val_dice_sam_list,
- val_dice_sam_prompt_list,
- ) = process_model(val_loader)
- val_dice_mean = np.mean(val_dice_list)
- val_dice_prompt_mean = np.mean(val_dice_prompt_list)
- val_dice_sam_mean = np.mean(val_dice_sam_list)
- val_dice_sam_prompt_mean = np.mean(val_dice_sam_prompt_list)
-
- train_mean_losses = [mean(x) for x in train_losses]
-
-
- logs = ""
-
- logs += f"Train Dice_sam: {dice_sam_mean}\n"
- logs += f"Train Dice: {dice_mean}\n"
- logs += f"Train Dice_sam_prompt: {dice_sam_prompt_mean}\n"
- logs += f"Train Dice_prompt: {dice_prompt_mean}\n"
- logs += f"Mean train loss: {mean(train_epoch_losses)}\n"
-
-
-
- logs += f"val Dice_sam: {val_dice_sam_mean}\n"
- logs += f"val Dice: {val_dice_mean}\n"
- logs += f"val Dice_sam_prompt: {val_dice_sam_prompt_mean}\n"
- logs += f"val Dice_prompt: {val_dice_prompt_mean}\n"
-
- # plt.plot(val_epochs, val_mean_losses, train_epochs, train_mean_losses)
- if val_dice_prompt_mean > last_best_dice:
- torch.save(
- conv3d_instance,
- f"exps/{exp_id}-{user_input}/conv_save.pth",
- )
- print("Model saved")
- last_best_dice = val_dice_prompt_mean
-
-
- print(logs)
- log_file.write(logs)
- scheduler.step()
- ## training with k-fold cross validation:
-
-
- fff = time()
- train_model(train_loader, val_loader,test_loader)
- log_file.close()
-
- # train and also test the model
|