In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict
import shutil
import time
import copy
import math
import random
from imutils import paths
from collections import OrderedDict

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpy import unravel_index

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

print(torch.cuda.is_available())

from torchvision import transforms
from torchvision import datasets

from PIL import *
import albumentations as A

# from torchsummary import summary
import segmentation_models_pytorch as smp
import captum

In [None]:
!pip install imutils
!pip install segmentation_models_pytorch
!pip install captum
!pip install albumentations
!pip install gdown 
import gdown 
url = 'https://drive.google.com/uc?id=1Jw1kwRLrXbE1OLGvAiwpz9gXTuEOy4Mc' 
output = 'data.zip'
gdown.download(url, output)
# url = 'https://drive.google.com/uc?id=1Jw1kwRLrXbE1OLGvAiwpz9gXTuEOy4Mc' 
# output = 'best_model.pth'
# gdown.download(url, output)
!unzip data.zip

In [None]:
# a = np.random.permutation(range(1, 1001))
# train, val, test = a[:800], a[800:900], a[900:]
# for i in train:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/train/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/train/masks/' + file_name)
# for i in val:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/val/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/val/masks/' + file_name)
# for i in test:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/test/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/test/masks/' + file_name)

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row, one subplot each.

    Each keyword becomes one panel: the keyword name (underscores replaced by
    spaces, title-cased) is the panel title and the value is drawn with
    ``imshow``.

    Args:
        **images: Images to draw (tensors in this notebook). A value whose
            first dimension is 2 or 3 is treated as channel-first and permuted
            to channel-last before plotting; anything else is plotted as-is.
            Singleton dimensions are squeezed away in both cases.

    Returns:
        None. Calling it without any image is a no-op.
    """
    n_images = len(images)
    if n_images == 0:
        # Nothing to draw; plt.subplots cannot build a grid with zero columns.
        return
    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when a single image is passed (plt.subplots would otherwise hand back a
    # bare Axes object that cannot be subscripted).
    _, axes = plt.subplots(1, n_images, figsize=(4 * n_images, 4), squeeze=False)
    for ax, (name, image) in zip(axes[0], images.items()):
        if image.shape[0] in (2, 3):
            image = image.permute(1, 2, 0)
        ax.imshow(np.squeeze(image))
        ax.set_title(name.replace('_', ' ').title(), fontsize=20)
    plt.show()
    
class EndoscopyDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two parallel path lists.

    ``images[i]`` and ``masks[i]`` must refer to the same sample. Every sample
    is resized to 400x400 and returned as float tensors in channel-first
    layout (3 channels for the image, 1 for the mask).

    Args:
        images: Sequence of image file paths.
        masks: Sequence of mask file paths, aligned index-by-index with ``images``.
        augmentations: Optional callable invoked as
            ``augmentations(image=<HWC array>, mask=<HWC array>)`` and returning
            a mapping with ``'image'`` and ``'mask'`` entries. A falsy value
            disables augmentation.

    Raises:
        ValueError: If ``images`` and ``masks`` differ in length.
    """

    # Target (height, width) every image and mask is resized to.
    IMAGE_SIZE = (400, 400)

    def __init__(self, images, masks, augmentations=None):
        # Pairs are matched purely by position, so a length mismatch means at
        # least one image would be trained against the wrong (or no) mask.
        if len(images) != len(masks):
            raise ValueError(
                'images and masks must have the same length, got {} and {}'.format(
                    len(images), len(masks)
                )
            )
        self.input_images = images
        self.target_masks = masks
        self.augmentations = augmentations

        # Build the preprocessing pipelines once instead of on every item.
        # NOTE(review): nearest-neighbour resampling is also applied to the
        # photo, not only to the mask; kept unchanged so inputs match what any
        # existing checkpoint was trained on — confirm whether that was intended.
        resize = transforms.Resize(
            self.IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST
        )
        self._image_transform = transforms.Compose([resize, transforms.ToTensor()])
        self._mask_transform = transforms.Compose(
            [resize, transforms.Grayscale(), transforms.ToTensor()]
        )

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``[image, mask]`` as float tensors."""
        # The context managers close the underlying files as soon as the
        # pixels have been decoded, instead of leaving them to the garbage collector.
        with Image.open(self.input_images[idx]) as raw_image:
            img = self._image_transform(raw_image.convert('RGB'))
        # NOTE(review): the mask is converted to grayscale but not thresholded
        # here, so its values may not be strictly binary — confirm the loss and
        # metrics are meant to see soft targets.
        with Image.open(self.target_masks[idx]) as raw_mask:
            mask = self._mask_transform(raw_mask.convert('RGB'))

        # The augmentation callable is fed channel-last NumPy arrays.
        img = img.permute((1, 2, 0)).numpy()
        mask = mask.permute((1, 2, 0)).numpy()

        if self.augmentations:
            augmented = self.augmentations(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Back to channel-first float tensors for the network.
        img = torch.tensor(img, dtype=torch.float).permute((2, 0, 1))
        mask = torch.tensor(mask, dtype=torch.float).permute((2, 0, 1))

        return [img, mask]
    
# Mini-batch size per split and the worker count handed to every DataLoader.
train_batch_size, val_batch_size, test_batch_size = 8, 4, 4
num_workers = 2

# Root folder that contains the train/, val/ and test/ splits.
# Alternatives that were used on other machines:
# main_dir = '/media/external_3TB/3TB/rasekh/fazel/KVASIR/'
# main_dir = '/content/drive/My Drive/KVASIR/'
# main_dir = 'KVASIR/'
main_dir = './'


def _sorted_jpg_paths(relative_dir):
    """Return the sorted paths listed under ``main_dir + relative_dir`` with ``contains="jpg"``."""
    return sorted(paths.list_files(main_dir + relative_dir, contains="jpg"))


# Images and masks are paired by position in these sorted lists.
train_images = _sorted_jpg_paths('train/images/')
val_images = _sorted_jpg_paths('val/images/')
test_images = _sorted_jpg_paths('test/images/')

train_masks = _sorted_jpg_paths('train/masks/')
val_masks = _sorted_jpg_paths('val/masks/')
test_masks = _sorted_jpg_paths('test/masks/')

# Training-time augmentation pipeline.
# The transforms are passed as a list rather than a set literal: a set has no
# defined iteration order, so the order in which the augmentations were applied
# could differ from run to run.
augmentations = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=(-90, 90)),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.GaussianBlur(p=0.5),
])

# One dataset per split; only the training split is augmented.
dataset = {
    'train': EndoscopyDataset(train_images, train_masks, augmentations),
    'val': EndoscopyDataset(val_images, val_masks, None),
    'test': EndoscopyDataset(test_images, test_masks, None),
}

# Only the training loader is shuffled. Validation and test keep a fixed order
# so evaluation sees identical batches on every epoch/run; the validation score
# is what decides which checkpoint gets saved, so it should not vary with the
# random batch composition.
dataloader = {
    'train': DataLoader(dataset['train'], batch_size=train_batch_size, shuffle=True, num_workers=num_workers),
    'val': DataLoader(dataset['val'], batch_size=val_batch_size, shuffle=False, num_workers=num_workers),
    'test': DataLoader(dataset['test'], batch_size=test_batch_size, shuffle=False, num_workers=num_workers),
}

# Sanity check: draw one random training sample (augmentations included),
# print tensor shapes and value ranges, and show the image next to its mask.
sample_index = random.randrange(len(dataset['train']))
image, mask = dataset['train'][sample_index]
print(image.shape, image.min(), image.max())
print(mask.shape, mask.min(), mask.max())
visualize(
    original_image=image,
    # The keyword is rendered as the panel title; it was misspelled "grund".
    ground_truth_mask=mask,
)

In [None]:
# Load the `unet` entry point of the `mateuszbuda/brain-segmentation-pytorch`
# torch.hub repository: 3 input channels, 1 output channel, 32 initial
# features, with pretrained weights requested.
# NOTE(review): presumably this fetches code and weights over the network on
# first use — confirm that is acceptable where the notebook runs.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True)

In [None]:
# Run configuration: `training` gates the training cell, `epochs` bounds it.
training = True
epochs = 400
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model = UNet(init_features=64).to(device)

# Optimisation target.
loss = smp.utils.losses.DiceLoss()

# Metrics reported each epoch, all evaluated at a 0.5 threshold.
metrics = [
    metric_cls(threshold=0.5)
    for metric_cls in (
        smp.utils.metrics.IoU,
        smp.utils.metrics.Fscore,
        smp.utils.metrics.Accuracy,
    )
]

# Adam over every model parameter, as a single parameter group.
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 0.0001},
])

# NOTE(review): this name shadows the `lr_scheduler` module imported from
# torch.optim at the top of the notebook; the fully qualified path below keeps
# the cell re-runnable regardless.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=1,
    T_mult=2,
    eta_min=5e-5,
)

In [None]:
# Settings shared by the training and validation runners.
_shared_epoch_kwargs = dict(loss=loss, metrics=metrics, device=device, verbose=True)

# The training runner additionally receives the optimizer; the validation
# runner is built without one.
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_shared_epoch_kwargs)

valid_epoch = smp.utils.train.ValidEpoch(model, **_shared_epoch_kwargs)

In [None]:
%%time

if training:

    # Best validation IoU seen so far; a checkpoint is written whenever it improves.
    best_iou_score = 0.0
    # Per-epoch log dictionaries, consumed by the plotting cells further down.
    train_logs_list, valid_logs_list = [], []

    for epoch in range(epochs):
        print('\nEpoch: {}'.format(epoch))
        train_logs = train_epoch.run(dataloader['train'])
        valid_logs = valid_epoch.run(dataloader['val'])
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer has been used. The scheduler was configured but never
        # stepped, so the learning rate stayed constant for the whole run.
        lr_scheduler.step()

        if best_iou_score < valid_logs['iou_score']:
            best_iou_score = valid_logs['iou_score']
            # The whole model object is pickled (not just a state_dict); the
            # reload cell below relies on that format.
            torch.save(model, main_dir + 'best_model.pth')
            print('Model saved!')

In [None]:
# Turn the per-epoch log dictionaries into tables: one row per epoch, one
# column per logged value.
train_logs_df, valid_logs_df = (
    pd.DataFrame(epoch_logs) for epoch_logs in (train_logs_list, valid_logs_list)
)

In [None]:
# Dice loss per epoch for the training and validation logs; also saved as a PNG.
fig, ax = plt.subplots(figsize=(6, 6))
for split_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    ax.plot(logs_df.index.tolist(), logs_df['dice_loss'].tolist(), lw=3, label=split_label)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('Dice Loss', fontsize=20)
ax.set_title('Dice Loss Plot', fontsize=20)
ax.legend(loc='best', fontsize=16)
ax.grid()
fig.savefig('dice_loss_plot.png')
plt.show()

In [None]:
# IoU score per epoch for the training and validation logs; also saved as a PNG.
fig, ax = plt.subplots(figsize=(6, 6))
for split_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    ax.plot(logs_df.index.tolist(), logs_df['iou_score'].tolist(), lw=3, label=split_label)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('IoU Score', fontsize=20)
ax.set_title('IoU Score Plot', fontsize=20)
ax.legend(loc='best', fontsize=16)
ax.grid()
fig.savefig('iou_score_plot.png')
plt.show()

In [None]:
# F1 score per epoch for the training and validation logs; also saved as a PNG.
fig, ax = plt.subplots(figsize=(6, 6))
for split_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    ax.plot(logs_df.index.tolist(), logs_df['fscore'].tolist(), lw=3, label=split_label)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('F1 Score', fontsize=20)
ax.set_title('F1 Score Plot', fontsize=20)
ax.legend(loc='best', fontsize=16)
ax.grid()
fig.savefig('fscore_plot.png')
plt.show()

In [None]:
# Accuracy per epoch for the training and validation logs; also saved as a PNG.
fig, ax = plt.subplots(figsize=(6, 6))
for split_label, logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    ax.plot(logs_df.index.tolist(), logs_df['accuracy'].tolist(), lw=3, label=split_label)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('Accuracy Score', fontsize=20)
ax.set_title('Accuracy Score Plot', fontsize=20)
ax.legend(loc='best', fontsize=16)
ax.grid()
fig.savefig('accuracy_plot.png')
plt.show()

In [None]:
# Reload the best checkpoint from the same location the training loop saves it to.
# map_location places the weights on the active device, so a checkpoint that
# was saved on a GPU can also be opened on a CPU-only machine.
# NOTE(review): the file is a fully pickled model object, so only load
# checkpoints from a trusted source; newer PyTorch releases may additionally
# require weights_only=False for this kind of file — confirm against the
# installed version.
model = torch.load(main_dir + 'best_model.pth', map_location=device)

In [None]:
%matplotlib inline

# Make sure the (possibly freshly reloaded) model lives on the same device as
# the batches, and switch it to inference mode.
model.to(device)
model.eval()

# Per-image scores as plain floats, in test-set order.
IOUs = []
F1s = []
Accuracies = []
# Per-image predicted masks, kept on the CPU so they do not pile up in GPU memory.
predictions = []

with torch.no_grad():
    for inputs, labels in dataloader['test']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        pred_mask = model(inputs)

        # Score and display every image of the batch individually. A dedicated
        # index name is used; the loop previously reused `i` for both the batch
        # counter and the position inside the batch.
        for sample_idx in range(len(inputs)):
            test_image = inputs[sample_idx]
            test_mask = labels[sample_idx]
            predMask = pred_mask[sample_idx]

            iou = smp.utils.functional.iou(predMask, test_mask, threshold=0.5)
            IOUs.append(iou.item())

            f1 = smp.utils.functional.f_score(predMask, test_mask, threshold=0.5)
            F1s.append(f1.item())

            accuracy = smp.utils.functional.accuracy(predMask, test_mask, threshold=0.5)
            Accuracies.append(accuracy.item())

            predictions.append(predMask.cpu())

            visualize(
                original_image=test_image.cpu(),
                ground_truth_mask=test_mask.cpu(),
                predicted_mask=predMask.cpu(),
            )

In [None]:
# Mean of every per-image metric over the test set.
for label, scores in (('Test IOU: ', IOUs), ('Test F1: ', F1s), ('Test Accuracy: ', Accuracies)):
    print(label + str(np.mean(scores)))

In [None]:
# Worst values on the first line, best on the second, in IoU / F1 / accuracy order.
for reducer in (np.min, np.max):
    print(*(reducer(scores) for scores in (IOUs, F1s, Accuracies)))

In [None]:
# Concatenate the per-image predictions along their first dimension into one
# NumPy array (presumably (1, H, W) each, giving (N, H, W) — confirm).
# The name is reused for the converted array, so the conversion is guarded:
# running this cell a second time would otherwise pass an ndarray to torch.cat
# and fail.
if not isinstance(predictions, np.ndarray):
    predictions = torch.cat(predictions).cpu().detach().numpy()

In [None]:
# Show the test sample with the highest IoU: input, ground truth and prediction.
# Indexing the dataset with the argmax position relies on the test DataLoader
# not being shuffled, so score order equals dataset order.
index = int(np.argmax(IOUs))
img, mask = dataset['test'][index]
# Re-add the leading channel dimension so all three panels share one layout.
pred_mask = torch.Tensor(predictions[index][np.newaxis, ...])

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    ('Original Image', img),
    ('Ground Truth Mask', mask),
    ('Predicted Mask', pred_mask),
)
for ax, (title, tensor) in zip(axes, panels):
    ax.imshow(np.squeeze(tensor.permute(1, 2, 0)))
    # Titles make the figure readable on its own.
    ax.set_title(title, fontsize=20)
plt.show()

In [None]:
from captum.attr import visualization as viz
from captum.attr import LayerGradCam, FeatureAblation, LayerActivation, LayerAttribution

In [None]:
"""
This wrapper computes the segmentation model output and sums the pixel scores for
all pixels predicted as each class, returning a tensor with a single value for
each class. This makes it easier to attribute with respect to a single output
scalar, as opposed to an individual pixel output attribution.
"""
def agg_segmentation_wrapper(inp):
    """Return, per output channel, the sum of scores over the pixels selected by ``out_max``.

    Args:
        inp: Input batch forwarded to the network.

    Returns:
        The masked model output summed over dimensions 2 and 3.

    NOTE(review): ``fcn`` and ``out_max`` are read from the notebook namespace
    but are not defined in any visible cell. The network trained above is bound
    to ``model`` and its output is used directly as a tensor, not indexed with
    ``['out']`` — confirm where these names are meant to come from.
    """
    model_out = fcn(inp)['out']
    # Creates binary matrix with 1 for original argmax class for each pixel
    # and 0 otherwise. Note that this may change when the input is ablated
    # so we use the original argmax predicted above, out_max.
    selected_inds = torch.zeros_like(model_out[0:1]).scatter_(1, out_max, 1)
    # Keep only the selected scores, then collapse dimensions 2 and 3.
    return (model_out * selected_inds).sum(dim=(2,3))

In [None]:
# Grad-CAM attributor over the `model.encoder1` layer, with the aggregating
# wrapper as the forward function.
lgc = LayerGradCam(agg_segmentation_wrapper, model.encoder1)

In [None]:
# Grad-CAM attribution for output index 6.
# NOTE(review): `normalized_inp` is not defined in any visible cell, and
# target=6 selects output index 6 while the model above is created with
# out_channels=1 — confirm both.
gc_attr = lgc.attribute(normalized_inp, target=6)

In [None]:
# Activation of the same layer, then a side-by-side shape report for the
# input, the layer activation and the Grad-CAM attribution.
la = LayerActivation(agg_segmentation_wrapper, model.encoder1)
activation = la.attribute(normalized_inp)
for shape_label, tensor in (
    ("Input Shape:", normalized_inp),
    ("Layer Activation Shape:", activation),
    ("Layer GradCAM Shape:", gc_attr),
):
    print(shape_label, tensor.shape)

In [None]:
# First attribution map as a channel-last NumPy array, drawn with both signs.
gc_attr_hwc = gc_attr[0].cpu().permute(1, 2, 0).detach().numpy()
viz.visualize_image_attr(gc_attr_hwc, sign="all")

In [None]:
# Interpolate the layer attribution to the trailing dimensions of the input
# (everything after the first two).
input_spatial_shape = normalized_inp.shape[2:]
upsampled_gc_attr = LayerAttribution.interpolate(gc_attr, input_spatial_shape)
print("Upsampled Shape:", upsampled_gc_attr.shape)

In [None]:
# Original image next to two blended heat maps of the upsampled attribution.
# NOTE(review): `preproc_img` is not defined in any visible cell — confirm.
upsampled_attr_hwc = upsampled_gc_attr[0].cpu().permute(1, 2, 0).detach().numpy()
original_hwc = preproc_img.permute(1, 2, 0).numpy()
viz.visualize_image_attr_multiple(
    upsampled_attr_hwc,
    original_image=original_hwc,
    signs=["all", "positive", "negative"],
    methods=["original_image", "blended_heat_map", "blended_heat_map"],
)

In [None]:
# Zero out every pixel whose `out_max` value is 19 and display what is left.
# NOTE(review): index 19 is presumably a class label carried over from another
# example (see the variable name); `out_max` and `preproc_img` are not defined
# in any visible cell — confirm.
keep_mask = (out_max != 19).float()[0].cpu()
img_without_train = keep_mask * preproc_img
plt.imshow(img_without_train.permute(1, 2, 0))

In [None]:
# Feature-ablation attributor driven by the same aggregating wrapper.
fa = FeatureAblation(agg_segmentation_wrapper)
# `out_max` is passed as the feature mask, with perturbations_per_eval=2 and
# attribution taken for target index 6.
# NOTE(review): `normalized_inp` and `out_max` are not defined in any visible
# cell, and the model above has a single output channel — confirm.
fa_attr = fa.attribute(normalized_inp, feature_mask=out_max, perturbations_per_eval=2, target=6)

In [None]:
# First feature-ablation map as a channel-last NumPy array, drawn with both signs.
fa_attr_hwc = fa_attr[0].cpu().detach().permute(1, 2, 0).numpy()
viz.visualize_image_attr(fa_attr_hwc, sign="all")

In [None]:
# Zero the attribution wherever `out_max` equals 6 and keep it everywhere else.
# NOTE(review): `out_max` is not defined in any visible cell — confirm.
fa_attr_without_max = (1 - (out_max == 6).float())[0] * fa_attr

In [None]:
# Masked feature-ablation map as a channel-last NumPy array, drawn with both signs.
fa_attr_without_max_hwc = fa_attr_without_max[0].cpu().detach().permute(1, 2, 0).numpy()
viz.visualize_image_attr(fa_attr_without_max_hwc, sign="all")