| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541 |
# Debug switch: when truthy, the script skips the interactive
# "Related changes" prompt (the run is named "debug" instead) and
# train_model skips dumping reference validation images into ims/.
debug = 0
- from pathlib import Path
- import numpy as np
- import matplotlib.pyplot as plt
- from utils import sample_prompt , main_prompt
- from collections import defaultdict
- import torchvision.transforms as transforms
- import torch
- from torch import nn
- import torch.nn.functional as F
- from segment_anything.utils.transforms import ResizeLongestSide
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
- import numpy as np
- from einops import rearrange
- import random
- from tqdm import tqdm
- from time import sleep
- from data import *
- from time import time
- from PIL import Image
- from sklearn.model_selection import KFold
- from shutil import copyfile
- from args import get_arguments
-
# Parse command-line options once at import time; the module-level `args`
# namespace is read throughout the rest of the script (image size, data
# directories, batch sizes, checkpoints, epoch counts, ...).
args = get_arguments()
def save_img(img, dir):
    """Min-max normalise a tensor and write it to disk as an 8-bit image.

    Args:
        img: torch tensor. A 3-D tensor is treated as channels-first
            ``(c, h, w)`` and normalised per channel; any other shape is
            normalised globally and saved as greyscale ("L").
        dir: destination *file* path (despite the name); PIL picks the
            format from the extension.

    Flat inputs (or flat channels) are written as zeros instead of
    turning into NaN through a 0/0 division.
    """
    # detach() so tensors that still require grad can be converted as well;
    # float32 avoids integer wrap-around (e.g. uint8) during the shift below.
    arr = img.detach().clone().cpu().numpy().astype(np.float32)

    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))  # c h w -> h w c
        arr = arr - np.amin(arr, axis=(0, 1), keepdims=True)
        peak = np.amax(arr, axis=(0, 1), keepdims=True)
        # Guard against 0/0 for constant channels; they simply stay at 0.
        arr = (arr / np.where(peak == 0, 1, peak) * 255).astype(np.uint8)
        if arr.shape[-1] == 1:
            # PIL cannot build an image from an (h, w, 1) array.
            arr = arr[:, :, 0]
        pil_img = Image.fromarray(arr)
    else:
        arr = arr - arr.min()
        peak = arr.max()
        if peak != 0:
            arr = arr / peak * 255
        pil_img = Image.fromarray(arr).convert("L")

    pil_img.save(dir)
-
-
-
class loss_fn(torch.nn.Module):
    """Combined segmentation loss for single-channel mask logits.

    ``forward`` returns ``focal_weight * focal_loss + dice_loss``.

    Args:
        alpha, gamma, epsilon: stored on the instance but not read by
            ``forward``; kept so existing constructor calls keep working.
        focal_weight: weight of the focal term in ``forward`` (defaults to
            the value that used to be hard-coded, 20.0).
    """

    def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5, focal_weight=20.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.focal_weight = focal_weight

    def tversky_loss(self, y_pred, y_true, alpha=0.8, beta=0.2, smooth=1e-2):
        """Return ``1 - Tversky index`` of sigmoid(``y_pred``) vs ``y_true``.

        ``alpha`` weights false negatives and ``beta`` false positives; all
        elements are pooled into a single score.
        """
        probs = torch.flatten(torch.sigmoid(y_pred))
        target = torch.flatten(y_true)
        true_pos = torch.sum(target * probs)
        false_neg = torch.sum(target * (1 - probs))
        false_pos = torch.sum((1 - target) * probs)
        tversky_index = (true_pos + smooth) / (
            true_pos + alpha * false_neg + beta * false_pos + smooth
        )
        return 1 - tversky_index

    def focal_tversky(self, y_pred, y_true, gamma=0.75):
        """Focal Tversky loss ``(1 - Tversky index) ** gamma``.

        ``tversky_loss`` already returns ``1 - index``; raising it directly
        keeps the loss at 0 for a perfect prediction (the previous
        ``(1 - tversky_loss) ** gamma`` was inverted).
        """
        return torch.pow(self.tversky_loss(y_pred, y_true), gamma)

    def dice_loss(self, logits, gt, eps=1):
        """Return ``1 - soft Dice`` of sigmoid(``logits``) vs ``gt`` over all elements."""
        # reshape (not view) so non-contiguous inputs are accepted.
        probs = torch.sigmoid(logits).reshape(-1)
        gt = gt.reshape(-1)
        intersection = (probs * gt).sum()
        dice_coeff = (2.0 * intersection + eps) / (probs.sum() + gt.sum() + eps)
        return 1 - dice_coeff

    def focal_loss(self, logits, gt, gamma=2):
        """Mean binary focal loss ``-(1 - p_t) ** gamma * log(p_t)``.

        ``gt`` must hold class indices 0/1 (it is used as a gather index).
        """
        logits = logits.reshape(-1, 1)
        gt = gt.reshape(-1, 1).long()
        # Column 0: log P(background) = log(1 - sigmoid(x)) = logsigmoid(-x)
        # Column 1: log P(foreground) = logsigmoid(x)
        # (previously column 0 used sigmoid(1 - x), which is not a probability
        # complement of sigmoid(x)). logsigmoid is also numerically stable, so
        # no epsilon is needed inside the log.
        log_probs = F.logsigmoid(torch.cat((-logits, logits), dim=1))
        log_pt = log_probs.gather(1, gt)
        pt = log_pt.exp()
        modulating_factor = (1 - pt) ** gamma
        return (-modulating_factor * log_pt).mean()

    def forward(self, logits, target):
        """Return ``focal_weight * focal_loss + dice_loss``.

        Inputs are squeezed along dim 1 first (a singleton channel dim, if
        present).
        """
        logits = logits.squeeze(1)
        target = target.squeeze(1)
        dice = self.dice_loss(logits, target)
        focal = self.focal_loss(logits, target)
        return self.focal_weight * focal + dice
-
-
def img_enhance(img2, coef=0.2):
    """Clamp the intensities of ``img2`` in place and return the same array.

    Values below ``0.7 * mean`` are raised to that floor; then values above
    ``mean + coef * (max - mean)`` are lowered to that ceiling. Mean and max
    are taken from the input before any clamping.
    """
    center = np.mean(img2)
    floor = center * 0.7
    ceiling = center + coef * (np.max(img2) - center)
    img2[img2 < floor] = floor
    img2[img2 > ceiling] = ceiling
    return img2
-
-
def dice_coefficient(logits, gt):
    """Hard Dice score of ``logits > 0`` against ``gt``.

    Scores are computed per slice over the last two dims with +1 smoothing
    (so an empty prediction on an empty target scores 1), then averaged.
    """
    smooth = 1
    spatial = (-2, -1)
    pred = logits > 0
    overlap = (pred * gt).sum(dim=spatial)
    denom = pred.sum(dim=spatial) + gt.sum(dim=spatial) + smooth
    return ((2.0 * overlap + smooth) / denom).mean()
-
def what_the_f(low_res_masks, label):
    """Dice of mask logits against ``label`` resized to the masks' size.

    ``label`` is resized with ``F.interpolate``'s default (nearest) mode to
    the last two dims of ``low_res_masks``, then scored with
    ``dice_coefficient`` (logits thresholded at 0).
    """
    target = F.interpolate(label, low_res_masks.shape[-2:])
    return dice_coefficient(low_res_masks, target)
-
-
image_size = args.image_size

# Create a fresh experiment directory exps/<id>-<description>, bumping the id
# until os.makedirs succeeds.
exp_id = 0
found = 0
if debug:
    user_input = 'debug'
else:
    user_input = input("Related changes: ")
while found == 0:
    try:
        os.makedirs(f"exps/{exp_id}-{user_input}")
    except FileExistsError:
        # Only "name already taken" means "try the next id". Any other error
        # (permissions, invalid characters in the description, Ctrl-C) now
        # propagates instead of spinning forever under a bare except.
        exp_id = exp_id + 1
    else:
        found = 1
# Snapshot this script next to the run's outputs.
copyfile(os.path.realpath(__file__), f"exps/{exp_id}-{user_input}/code.py")

# NOTE(review): layer_n / L / a / params are not referenced elsewhere in the
# visible part of this file; confirm they are still needed.
layer_n = 4
L = layer_n
a = np.full(L, layer_n)
params = {"M": 255, "a": a, "p": 0.35}

device = "cuda:1"
- from segment_anything import SamPredictor, sam_model_registry
class panc_sam(nn.Module):
    """Two-stage SAM: a promptless pass proposes a mask, which is converted
    into point prompts for a second, prompted pass.

    Stage 1 (frozen): prompt encoder and mask decoder taken from the whole
    model object loaded from ``args.promptprovider``; run with no prompts.
    Stage 2: image encoder and prompt encoder (frozen) plus ``mask_decoder2``
    (left trainable) from ``sam_model_registry[args.model_type]`` built from
    ``args.checkpoint``. Both stages decode the embeddings produced by the
    stage-2 image encoder.
    """

    def __init__(self, *module_args, **module_kwargs) -> None:
        # These varargs used to be called ``args``, which shadowed the global
        # argparse namespace: ``args.promptprovider`` then hit an empty tuple
        # and raised AttributeError. Renamed so the global config is read.
        super().__init__(*module_args, **module_kwargs)

        # Promptless stage. NOTE: torch.load unpickles a whole model object,
        # so only point args.promptprovider at trusted files.
        prompt_sam = torch.load(args.promptprovider)
        self.prompt_encoder = prompt_sam.prompt_encoder
        self.mask_decoder = prompt_sam.mask_decoder
        for module in (self.prompt_encoder, self.mask_decoder):
            module.requires_grad_(False)

        # Prompted stage; only mask_decoder2 keeps requires_grad=True.
        sam = sam_model_registry[args.model_type](args.checkpoint)
        self.image_encoder = sam.image_encoder
        self.prompt_encoder2 = sam.prompt_encoder
        self.mask_decoder2 = sam.mask_decoder
        for module in (self.image_encoder, self.prompt_encoder2):
            module.requires_grad_(False)

    @staticmethod
    def _decode(prompt_encoder, mask_decoder, embedding, points):
        """Run one prompt-encoder + mask-decoder pass for a single embedding.

        The prompt encoder runs under no_grad and its outputs are detached;
        the decoder runs with autograd enabled so a trainable decoder
        receives gradients. Returns the single-mask low-res logits.
        """
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = prompt_encoder(
                points=points,
                boxes=None,
                masks=None,
            )
        low_res_masks, _ = mask_decoder(
            image_embeddings=embedding,
            image_pe=prompt_encoder.get_dense_pe().detach(),
            sparse_prompt_embeddings=sparse_embeddings.detach(),
            dense_prompt_embeddings=dense_embeddings.detach(),
            multimask_output=False,
        )
        return low_res_masks

    def forward(self, input_images, box=None):
        """Predict masks for a batch of images.

        Args:
            input_images: batch fed straight to ``self.image_encoder``.
            box: unused; kept for call compatibility.

        Returns:
            ``(low_res_masks, low_res_masks_promtp)``: stage-2 (prompted) and
            stage-1 (promptless) mask logits, each concatenated along dim 0
            over the batch.
        """
        with torch.no_grad():
            image_embeddings = self.image_encoder(input_images).detach()

        outputs_prompt = []
        outputs = []
        for curr_embedding in image_embeddings:
            stage1_masks = self._decode(
                self.prompt_encoder, self.mask_decoder, curr_embedding, points=None
            )
            outputs_prompt.append(stage1_masks)

            # Turn the stage-1 mask into point prompts (two samplers, joined
            # along dim 1).
            points, point_labels = main_prompt(stage1_masks)
            extra_points, extra_labels = sample_prompt(stage1_masks)
            points = torch.cat([points, extra_points], dim=1)
            point_labels = torch.cat([point_labels, extra_labels], dim=1)
            # NOTE(review): presumably maps low-res (256 px) mask coordinates
            # onto the 1024 px encoder input — confirm against main_prompt /
            # sample_prompt.
            points = points * 4

            stage2_masks = self._decode(
                self.prompt_encoder2,
                self.mask_decoder2,
                curr_embedding,
                points=(points, point_labels),
            )
            outputs.append(stage2_masks)

        low_res_masks_promtp = torch.cat(outputs_prompt, dim=0)
        low_res_masks = torch.cat(outputs, dim=0)
        return low_res_masks, low_res_masks_promtp
-
# Training-time augmentation pipeline (albumentations); transforms are applied
# in list order, each with probability ``p``.
# NOTE(review): RandomScale runs after the 1024x1024 RandomResizedCrop, so the
# output is presumably not 1024x1024 any more — confirm PanDataset resizes
# afterwards.
augmentation = A.Compose(
    [
        A.Rotate(limit=30, p=0.5),
        A.RandomBrightnessContrast(p=1, brightness_limit=0.3, contrast_limit=0.3),
        A.RandomResizedCrop(1024, 1024, p=1, scale=(0.9, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.CLAHE(p=0.5, clip_limit=2.0, tile_grid_size=(8, 8)),
        A.CoarseDropout(
            p=0.5,
            fill_value=0,
            max_holes=8,
            min_height=8,
            max_height=16,
            min_width=8,
            max_width=16,
        ),
        A.RandomScale(p=0.5, scale_limit=0.3),
    ]
)
panc_sam_instance = panc_sam()
panc_sam_instance.to(device)
panc_sam_instance.train()

train_dataset = PanDataset(
    [args.dir_train],
    [args.dir_labels],
    [["NIH_PNG", 1]],
    args.image_size,
    slice_per_image=args.slice_per_image,
    train=True,
    augmentation=augmentation,
)
# NOTE(review): validation reads the same directories as training
# (args.dir_train / args.dir_labels) and differs only by train=False and no
# augmentation — confirm this is intended; otherwise val scores are optimistic.
val_dataset = PanDataset(
    [args.dir_train],
    [args.dir_labels],
    [["NIH_PNG", 1]],
    image_size,
    slice_per_image=args.slice_per_image,
    train=False,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=train_dataset.collate_fn,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,
)
# shuffle=False: train_model saves ground truth from one pass over val_loader
# into ims/batch_<i>/ and later writes each epoch's i-th validation prediction
# into the same folder. A shuffled loader paired predictions with the wrong
# ground truth and evaluated a different subset every epoch (assuming
# PanDataset itself is deterministic with train=False).
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    collate_fn=val_dataset.collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
)

# NOTE: OneCycleLR below overrides this initial lr (it starts from
# max_lr / div_factor), so `lr` has no effect on the schedule.
lr = 1e-4
max_lr = 3e-4
wd = 5e-4
# Only the prompted (stage-2) mask decoder is optimised.
optimizer_main = torch.optim.Adam(
    list(panc_sam_instance.mask_decoder2.parameters()),
    lr=lr,
    weight_decay=wd,
)
# process_model steps the optimizer once every
# accumulative_batch_size / batch_size mini-batches.
scheduler_main = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_main,
    max_lr=max_lr,
    epochs=args.num_epochs,
    steps_per_epoch=args.sample_size // (args.accumulative_batch_size // args.batch_size),
)
- #####################################################
-
- from statistics import mean
-
- from tqdm import tqdm
- from torch.nn.functional import threshold, normalize
-
# Loss used by process_model. NOTE(review): loss_fn.forward does not read the
# alpha/gamma constructor arguments, so these values do not affect the loss.
# Module.to() returns the module itself, so this binds the moved instance.
loss_function = loss_fn(alpha=0.5, gamma=4.0).to(device)
-
- from time import time
- import time as s_time
-
# Per-run log inside the experiment directory created earlier, opened in
# append mode and closed at the end of the script. process_model appends the
# running average dice after every batch; train_model writes epoch headers.
log_file = open(f"exps/{exp_id}-{user_input}/log.txt", "a")
-
-
def process_model(main_model, data_loader, train=0, save_output=0):
    """Run one pass of ``main_model`` over ``data_loader`` (train or eval).

    Args:
        main_model: called as ``main_model(image, box)`` and expected to
            return ``(stage2_logits, stage1_logits)``.
        data_loader: yields ``(image, label, raw_data)`` batches.
        train: truthy -> backprop with gradient accumulation and step the
            global ``optimizer_main`` / ``scheduler_main``, stopping after
            ``args.sample_size`` batches. Falsy -> no autograd, collect
            predictions, stop after ``args.sample_size // 2`` batches.
        save_output: unused; kept for call compatibility.

    Returns:
        ``(epoch_losses, results, average_dice, average_dice_main)``:
        losses recorded once per accumulation window; in eval mode a
        ``(2, N, 256, 256)`` tensor stacking, per batch, the first item's raw
        stage-2 logits and its thresholded mask (empty when training); and
        the running mean dice of the stage-1 and stage-2 predictions as
        Python floats (0.0 if the loader is empty).
    """
    accum_ratio = args.accumulative_batch_size / args.batch_size
    max_batches = args.sample_size if train else args.sample_size // 2

    epoch_losses = []
    results = torch.zeros((2, 0, 256, 256))
    total_dice = 0.0
    total_dice_main = 0.0
    average_dice = 0.0
    average_dice_main = 0.0
    num_samples = 0
    index = 0  # mini-batches since the last optimizer step

    if train:
        # Drop any partial accumulation left from a previous call so it does
        # not leak into this epoch's first update.
        optimizer_main.zero_grad()

    # Autograd is only needed for training; eval skips building graphs.
    with torch.set_grad_enabled(bool(train)):
        for image, label, raw_data in tqdm(data_loader, total=max_batches):
            num_samples += 1
            index += 1
            image = image.to(device)
            label = label.to(device).float()

            # Box prompt is passed along but panc_sam.forward ignores it.
            box = torch.tensor([[200, 200, 750, 800]]).to(device)
            low_res_masks_main, low_res_masks_prompt = main_model(image, box)
            low_res_label = F.interpolate(label, low_res_masks_main.shape[-2:])

            # .item() keeps running totals as floats, so the log gets plain
            # numbers instead of "tensor(..., device=...)" text.
            total_dice += what_the_f(low_res_masks_prompt, low_res_label).item()
            total_dice_main += what_the_f(low_res_masks_main, low_res_label).item()
            average_dice = total_dice / num_samples
            average_dice_main = total_dice_main / num_samples
            log_file.write(str(average_dice) + "\n")
            log_file.flush()

            loss = loss_function.forward(low_res_masks_main, low_res_label)
            loss /= accum_ratio

            if train:
                loss.backward()
                if index % accum_ratio == 0:
                    optimizer_main.step()
                    scheduler_main.step()
                    optimizer_main.zero_grad()
                    index = 0
            else:
                # threshold -> zero out non-positive logits; normalize over the
                # singleton channel dim maps positive values to ~1.
                binary_mask = normalize(threshold(low_res_masks_main, 0.0, 0))
                result = torch.cat(
                    (
                        low_res_masks_main[0].detach().cpu().reshape(1, 1, 256, 256),
                        binary_mask[0].detach().cpu().reshape(1, 1, 256, 256),
                    ),
                    dim=0,
                )
                results = torch.cat((results, result), dim=1)

            if index % accum_ratio == 0:
                epoch_losses.append(loss.item())
            if num_samples == max_batches:
                break

    return epoch_losses, results, average_dice, average_dice_main
-
-
def _safe_mean(values):
    """Return ``statistics.mean(values)``, or NaN for an empty list (mean() raises)."""
    return mean(values) if values else float("nan")


def _save_reference_images(val_loader, limit=100):
    """Save the first image of the first ``limit`` validation batches and an
    image+label overlay to ims/batch_<i>/img_0.png and gt_0.png."""
    for index, (image, label, raw_data) in enumerate(tqdm(val_loader)):
        os.makedirs(f"ims/batch_{index}", exist_ok=True)
        save_img(image[0], f"ims/batch_{index}/img_0.png")
        save_img(0.2 * image[0][0] + label[0][0], f"ims/batch_{index}/gt_0.png")
        if index + 1 == limit:
            break


def _save_epoch_predictions(epoch_results, epoch):
    """Save each validation batch's stage-2 logits and thresholded mask to
    ims/batch_<i>/prob_epoch_<epoch>.png and pred_epoch_<epoch>.png."""
    for i in tqdm(range(len(epoch_results[0]))):
        os.makedirs(f"ims/batch_{i}", exist_ok=True)
        save_img(epoch_results[0, i].clone(), f"ims/batch_{i}/prob_epoch_{epoch}.png")
        save_img(epoch_results[1, i].clone(), f"ims/batch_{i}/pred_epoch_{epoch}.png")


def train_model(train_loader, val_loader, K_fold=False, N_fold=7, epoch_num_start=7):
    """Train ``panc_sam_instance`` for ``args.num_epochs`` epochs.

    Each epoch runs a training pass; validation runs only when that epoch's
    training stage-1 dice exceeds 0.5. The whole model is saved to
    exps/<id>-<desc>/sam_tuned_save.pth whenever the validation stage-2
    dice beats the best seen so far. Per-epoch mean losses are written to
    train_losses.npy / val_losses.npy and plotted once to ``result``.

    Args:
        train_loader, val_loader: loaders passed to process_model.
        K_fold: only the non-K-fold path is implemented; with a truthy value
            no training happens.
        N_fold, epoch_num_start: unused.

    Returns:
        ``(train_losses, val_losses, results)``: per-epoch lists of recorded
        losses, and the validation result tensors of validated epochs.
    """
    print("Train model started.")

    train_losses = []
    train_epochs = []
    val_losses = []
    val_epochs = []
    dice = []
    dice_main = []
    dice_val = []
    dice_val_main = []
    results = []

    if debug == 0:
        _save_reference_images(val_loader)

    best_val_dice = 0
    if not K_fold:
        for epoch in range(args.num_epochs):
            print(f"=====================EPOCH: {epoch + 1}=====================")
            log_file.write(
                f"=====================EPOCH: {epoch + 1}===================\n"
            )
            print("Training:")
            train_epoch_losses, epoch_results, average_dice, average_dice_main = process_model(
                panc_sam_instance, train_loader, train=1
            )
            dice.append(average_dice)
            dice_main.append(average_dice_main)
            train_losses.append(train_epoch_losses)
            train_epochs.append(epoch)

            # Explicit flag replaces the old bare try/except that detected a
            # skipped validation via NameError and hid unrelated failures.
            validated = average_dice > 0.5
            if validated:
                print("validating:")
                val_epoch_losses, epoch_results, average_dice_val, average_dice_val_main = process_model(
                    panc_sam_instance, val_loader
                )
                val_losses.append(val_epoch_losses)
                val_epochs.append(epoch)
                _save_epoch_predictions(epoch_results, epoch)

            np.save("train_losses.npy", [_safe_mean(x) for x in train_losses])
            np.save("val_losses.npy", [_safe_mean(x) for x in val_losses])

            print(f"Train Dice: {average_dice}")
            print(f"Train Dice main: {average_dice_main}")
            print(f"Mean train loss: {_safe_mean(train_epoch_losses)}")

            if validated:
                dice_val.append(average_dice_val)
                dice_val_main.append(average_dice_val_main)
                print(f"val Dice : {average_dice_val}")
                print(f"val Dice main : {average_dice_val_main}")
                print(f"Mean val loss: {_safe_mean(val_epoch_losses)}")
                results.append(epoch_results)
                # Track and compare the same metric (stage-2 dice); the old
                # code compared stage-2 but stored the stage-1 dice.
                if average_dice_val_main > best_val_dice:
                    torch.save(
                        panc_sam_instance,
                        f"exps/{exp_id}-{user_input}/sam_tuned_save.pth",
                    )
                    best_val_dice = average_dice_val_main
            print(f"=================End of EPOCH: {epoch + 1}==================\n")

    # Plot once at the end; plotting every epoch stacked duplicate lines.
    train_mean_losses = [_safe_mean(x) for x in train_losses]
    val_mean_losses = [_safe_mean(x) for x in val_losses]
    if val_epochs:
        plt.plot(val_epochs, val_mean_losses, label="val")
    plt.plot(train_epochs, train_mean_losses, label="train")
    plt.legend()
    plt.yscale("log")
    plt.title("Mean epoch loss")
    plt.xlabel("Epoch Number")
    plt.ylabel("Loss")
    plt.savefig("result")

    return train_losses, val_losses, results
-
-
# Run training; close the log file even if training raises, so the handle is
# released and buffered lines are not lost.
try:
    train_losses, val_losses, results = train_model(train_loader, val_loader)
finally:
    log_file.close()
-
|