Interpretability of lung nodules segmentation

kvasir.ipynb 22KB

```python
# Install the extra dependencies and fetch the dataset archive.
!pip install imutils
!pip install segmentation_models_pytorch
!pip install captum
!pip install albumentations
!pip install gdown

import gdown

url = 'https://drive.google.com/uc?id=1Jw1kwRLrXbE1OLGvAiwpz9gXTuEOy4Mc'
output = 'data.zip'
gdown.download(url, output)
# url = 'https://drive.google.com/uc?id=1Jw1kwRLrXbE1OLGvAiwpz9gXTuEOy4Mc'
# output = 'best_model.pth'
# gdown.download(url, output)
!unzip data.zip
```

```python
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict
import shutil
import time
import copy
import math
import random
from imutils import paths
from collections import OrderedDict

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpy import unravel_index

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

print(torch.cuda.is_available())

from torchvision import transforms
from torchvision import datasets

from PIL import Image
import albumentations as A

# from torchsummary import summary
import segmentation_models_pytorch as smp
import captum
```

```python
# One-off split of the Kvasir images/masks into train/val/test folders (already applied to data.zip).
# a = np.random.permutation(range(1, 1001))
# train, val, test = a[:800], a[800:900], a[900:]
# for i in train:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/train/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/train/masks/' + file_name)
# for i in val:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/val/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/val/masks/' + file_name)
# for i in test:
#     file_name = '{:04d}.jpg'.format(i)
#     os.rename('/home/fazel/code/KVASIR/images/' + file_name, '/home/fazel/code/KVASIR/test/images/' + file_name)
#     os.rename('/home/fazel/code/KVASIR/masks/' + file_name, '/home/fazel/code/KVASIR/test/masks/' + file_name)
```
```python
def visualize(**images):
    n_images = len(images)
    f, axarr = plt.subplots(1, n_images, figsize=(4 * n_images, 4))
    for idx, (name, image) in enumerate(images.items()):
        if image.shape[0] == 3 or image.shape[0] == 2:
            axarr[idx].imshow(np.squeeze(image.permute(1, 2, 0)))
        else:
            axarr[idx].imshow(np.squeeze(image))
        axarr[idx].set_title(name.replace('_', ' ').title(), fontsize=20)
    plt.show()


class EndoscopyDataset(Dataset):
    def __init__(self, images, masks, augmentations=None):
        self.input_images = images
        self.target_masks = masks
        self.augmentations = augmentations

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.input_images[idx])).convert('RGB')
        mask = Image.open(os.path.join(self.target_masks[idx])).convert('RGB')
        img = transforms.Compose([
            transforms.Resize((400, 400), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])(img)
        mask = transforms.Compose([
            transforms.Resize((400, 400), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])(mask)
        # Albumentations expects HWC numpy arrays.
        img = img.permute((1, 2, 0)).cpu().detach().numpy()
        mask = mask.permute((1, 2, 0)).cpu().detach().numpy()

        if self.augmentations:
            augmented = self.augmentations(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        img = torch.tensor(img, dtype=torch.float).permute((2, 0, 1))
        mask = torch.tensor(mask, dtype=torch.float).permute((2, 0, 1))

        return [img, mask]


train_batch_size = 8
val_batch_size = 4
test_batch_size = 4
num_workers = 2

# main_dir = '/media/external_3TB/3TB/rasekh/fazel/KVASIR/'
# main_dir = '/content/drive/My Drive/KVASIR/'
# main_dir = 'KVASIR/'
main_dir = './'

train_images = sorted(list(paths.list_files(main_dir + 'train/images/', contains="jpg")))
val_images = sorted(list(paths.list_files(main_dir + 'val/images/', contains="jpg")))
test_images = sorted(list(paths.list_files(main_dir + 'test/images/', contains="jpg")))

train_masks = sorted(list(paths.list_files(main_dir + 'train/masks/', contains="jpg")))
val_masks = sorted(list(paths.list_files(main_dir + 'val/masks/', contains="jpg")))
test_masks = sorted(list(paths.list_files(main_dir + 'test/masks/', contains="jpg")))

# A.Compose expects an ordered list of transforms (the original passed a set).
augmentations = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=(-90, 90)),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.GaussianBlur(p=0.5),
])

dataset = {
    'train': EndoscopyDataset(train_images, train_masks, augmentations),
    'val': EndoscopyDataset(val_images, val_masks, None),
    'test': EndoscopyDataset(test_images, test_masks, None),
}

dataloader = {
    'train': DataLoader(dataset['train'], batch_size=train_batch_size, shuffle=True, num_workers=num_workers),
    'val': DataLoader(dataset['val'], batch_size=val_batch_size, shuffle=True, num_workers=num_workers),
    'test': DataLoader(dataset['test'], batch_size=test_batch_size, shuffle=False, num_workers=num_workers),
}

image, mask = dataset['train'][random.randint(0, len(dataset['train']) - 1)]
print(image.shape, image.min(), image.max())
print(mask.shape, mask.min(), mask.max())
visualize(
    original_image=image,
    ground_truth_mask=mask,
)
```
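The masks are stored as JPEGs, so compression can leave pixel values strictly between 0 and 1 after `Grayscale()`/`ToTensor()`. A small hedged check (a sketch, not part of the original notebook) of how binary the loaded masks really are:

```python
# Sketch: estimate the fraction of mask pixels that are neither clearly background nor foreground.
_, sample_mask = dataset['train'][0]
soft_fraction = ((sample_mask > 0.05) & (sample_mask < 0.95)).float().mean().item()
print(f'fraction of non-binary mask pixels: {soft_fraction:.4f}')
# If this fraction is non-negligible, thresholding the mask (e.g. mask = (mask > 0.5).float())
# before computing the Dice loss is a common option.
```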
mask\n)","metadata":{"id":"dOKGH2G-oUAL","execution":{"iopub.status.busy":"2021-11-19T08:12:40.930576Z","iopub.execute_input":"2021-11-19T08:12:40.931148Z","iopub.status.idle":"2021-11-19T08:12:41.511706Z","shell.execute_reply.started":"2021-11-19T08:12:40.93111Z","shell.execute_reply":"2021-11-19T08:12:41.509263Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True)","metadata":{"execution":{"iopub.status.busy":"2021-11-19T08:12:50.564963Z","iopub.execute_input":"2021-11-19T08:12:50.565779Z","iopub.status.idle":"2021-11-19T08:12:56.213428Z","shell.execute_reply.started":"2021-11-19T08:12:50.565732Z","shell.execute_reply":"2021-11-19T08:12:56.212475Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"training = True\nepochs = 400\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n# model = UNet(init_features=64).to(device)\n\nloss = smp.utils.losses.DiceLoss()\n\nmetrics = [\n smp.utils.metrics.IoU(threshold=0.5),\n smp.utils.metrics.Fscore(threshold=0.5),\n smp.utils.metrics.Accuracy(threshold=0.5)\n]\n\noptimizer = torch.optim.Adam([ \n dict(params=model.parameters(), lr=0.0001),\n])\n\nlr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n optimizer, T_0=1, T_mult=2, eta_min=5e-5,\n)","metadata":{"id":"fIlVc7OapTuX","execution":{"iopub.status.busy":"2021-11-19T08:12:56.215424Z","iopub.execute_input":"2021-11-19T08:12:56.215714Z","iopub.status.idle":"2021-11-19T08:12:56.290211Z","shell.execute_reply.started":"2021-11-19T08:12:56.215675Z","shell.execute_reply":"2021-11-19T08:12:56.289446Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_epoch = smp.utils.train.TrainEpoch(\n model, \n loss=loss, \n metrics=metrics, \n optimizer=optimizer,\n device=device,\n verbose=True,\n)\n\nvalid_epoch = smp.utils.train.ValidEpoch(\n model, \n loss=loss, \n metrics=metrics, \n device=device,\n verbose=True,\n)","metadata":{"id":"hyb0lsmCpejc","execution":{"iopub.status.busy":"2021-11-19T08:12:59.932601Z","iopub.execute_input":"2021-11-19T08:12:59.933311Z","iopub.status.idle":"2021-11-19T08:13:04.723152Z","shell.execute_reply.started":"2021-11-19T08:12:59.933271Z","shell.execute_reply":"2021-11-19T08:13:04.722365Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"%%time\n\nif training:\n\n best_iou_score = 0.0\n train_logs_list, valid_logs_list = [], []\n\n for i in range(0, epochs):\n print('\\nEpoch: {}'.format(i))\n train_logs = train_epoch.run(dataloader['train'])\n valid_logs = valid_epoch.run(dataloader['val'])\n train_logs_list.append(train_logs)\n valid_logs_list.append(valid_logs)\n\n if best_iou_score < valid_logs['iou_score']:\n best_iou_score = valid_logs['iou_score']\n torch.save(model, main_dir + 'best_model.pth')\n print('Model saved!')","metadata":{"id":"CEVEzEUYpswo","outputId":"79025bfb-75ad-4336-9551-f36dea5218a7","scrolled":true,"execution":{"iopub.status.busy":"2021-11-19T08:13:04.724717Z","iopub.execute_input":"2021-11-19T08:13:04.724973Z","iopub.status.idle":"2021-11-19T11:32:41.794187Z","shell.execute_reply.started":"2021-11-19T08:13:04.724938Z","shell.execute_reply":"2021-11-19T11:32:41.793182Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"train_logs_df = pd.DataFrame(train_logs_list)\nvalid_logs_df = 
```python
train_logs_df = pd.DataFrame(train_logs_list)
valid_logs_df = pd.DataFrame(valid_logs_list)
```

```python
plt.figure(figsize=(6, 6))
plt.plot(train_logs_df.index.tolist(), train_logs_df.dice_loss.tolist(), lw=3, label='Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.dice_loss.tolist(), lw=3, label='Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Dice Loss', fontsize=20)
plt.title('Dice Loss Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('dice_loss_plot.png')
plt.show()
```

```python
plt.figure(figsize=(6, 6))
plt.plot(train_logs_df.index.tolist(), train_logs_df.iou_score.tolist(), lw=3, label='Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.iou_score.tolist(), lw=3, label='Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.title('IoU Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('iou_score_plot.png')
plt.show()
```

```python
plt.figure(figsize=(6, 6))
plt.plot(train_logs_df.index.tolist(), train_logs_df.fscore.tolist(), lw=3, label='Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.fscore.tolist(), lw=3, label='Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('F1 Score', fontsize=20)
plt.title('F1 Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('fscore_plot.png')
plt.show()
```

```python
plt.figure(figsize=(6, 6))
plt.plot(train_logs_df.index.tolist(), train_logs_df.accuracy.tolist(), lw=3, label='Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.accuracy.tolist(), lw=3, label='Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Accuracy', fontsize=20)
plt.title('Accuracy Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('accuracy_plot.png')
plt.show()
```
```python
model = torch.load('./best_model.pth')
```

```python
%matplotlib inline

model.eval()

IOUs = []
F1s = []
Accuracies = []
predictions = []

with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloader['test']):
        inputs = inputs.to(device)
        labels = labels.to(device)

        pred_mask = model(inputs)

        for j in range(len(inputs)):  # renamed from i to avoid shadowing the batch index
            test_image = inputs[j]
            test_mask = labels[j]
            predMask = pred_mask[j]

            iou = smp.utils.functional.iou(predMask, test_mask, threshold=0.5)
            IOUs.append(iou.cpu().detach())

            f1 = smp.utils.functional.f_score(predMask, test_mask, threshold=0.5)
            F1s.append(f1.cpu().detach())

            accuracy = smp.utils.functional.accuracy(predMask, test_mask, threshold=0.5)
            Accuracies.append(accuracy.cpu().detach())

            predictions.append(predMask)

            visualize(
                original_image=test_image.cpu(),
                ground_truth_mask=test_mask.cpu(),
                predicted_mask=predMask.cpu(),
            )
```

```python
print('Test IOU: ' + str(np.mean(IOUs)))
print('Test F1: ' + str(np.mean(F1s)))
print('Test Accuracy: ' + str(np.mean(Accuracies)))
```

```python
print(np.min(IOUs), np.min(F1s), np.min(Accuracies))
print(np.max(IOUs), np.max(F1s), np.max(Accuracies))
```

```python
predictions = torch.cat(predictions).cpu().detach().numpy()
```
```python
# Visualize the test image with the highest IoU (the test loader is not shuffled,
# so prediction indices line up with dataset['test']).
index = np.argmax(IOUs)
img, mask = dataset['test'][index]
pred_mask = predictions[index]
pred_mask = torch.Tensor(pred_mask.reshape((1, pred_mask.shape[0], pred_mask.shape[1]))).cpu().detach()
f, axarr = plt.subplots(1, 3, figsize=(12, 4))
axarr[0].imshow(np.squeeze(img.permute(1, 2, 0)))
axarr[1].imshow(np.squeeze(mask.permute(1, 2, 0)))
axarr[2].imshow(np.squeeze(pred_mask.permute(1, 2, 0)))
plt.show()
```

```python
from captum.attr import visualization as viz
from captum.attr import LayerGradCam, FeatureAblation, LayerActivation, LayerAttribution
```
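The Captum cells below are adapted from Captum's semantic-segmentation tutorial and refer to `normalized_inp`, `preproc_img`, and `out_max`, which the notebook never defines. A minimal sketch of how they can be set up for this single-channel U-Net, reusing the test image selected above (the variable definitions here are assumptions, not part of the original notebook):

```python
# Sketch (assumption): reuse the best test image from above as the attribution input.
# The model was trained on un-normalized [0, 1] tensors, so no extra preprocessing is applied;
# the names are kept only for consistency with the tutorial-style cells that follow.
preproc_img = img                              # (3, 400, 400) CPU tensor in [0, 1]
normalized_inp = img.unsqueeze(0).to(device)   # (1, 3, 400, 400) batch for the model

with torch.no_grad():
    out = model(normalized_inp)                # (1, 1, 400, 400) sigmoid scores
out_max = (out > 0.5).long()                   # binary prediction map: 1 = polyp, 0 = background
```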
```python
"""
This wrapper computes the segmentation model output and sums the pixel scores for
all pixels predicted as each class, returning a tensor with a single value for
each class. This makes it easier to attribute with respect to a single output
scalar, as opposed to an individual pixel output attribution.

Adapted from Captum's semantic-segmentation tutorial: this U-Net has a single
sigmoid output channel, so the score is summed over the pixels originally
predicted as polyp (out_max, computed above).
"""
def agg_segmentation_wrapper(inp):
    model_out = model(inp)  # (N, 1, H, W) sigmoid scores
    # Use the original prediction (out_max) rather than re-thresholding here,
    # since the predicted region may change when the input is ablated.
    return (model_out * out_max.float()).sum(dim=(2, 3))  # (N, 1)
```

```python
lgc = LayerGradCam(agg_segmentation_wrapper, model.encoder1)
```

```python
# target=0 selects the wrapper's single (polyp) output
# (the original target=6 was a class index left over from the Captum tutorial).
gc_attr = lgc.attribute(normalized_inp, target=0)
```

```python
la = LayerActivation(agg_segmentation_wrapper, model.encoder1)
activation = la.attribute(normalized_inp)
print("Input Shape:", normalized_inp.shape)
print("Layer Activation Shape:", activation.shape)
print("Layer GradCAM Shape:", gc_attr.shape)
```

```python
viz.visualize_image_attr(gc_attr[0].cpu().permute(1, 2, 0).detach().numpy(), sign="all")
```

```python
upsampled_gc_attr = LayerAttribution.interpolate(gc_attr, normalized_inp.shape[2:])
print("Upsampled Shape:", upsampled_gc_attr.shape)
```

```python
viz.visualize_image_attr_multiple(
    upsampled_gc_attr[0].cpu().permute(1, 2, 0).detach().numpy(),
    original_image=preproc_img.permute(1, 2, 0).numpy(),
    signs=["all", "positive", "negative"],
    methods=["original_image", "blended_heat_map", "blended_heat_map"],
)
```

```python
# Mask out the predicted polyp region (the Captum tutorial removed class 19, the
# "train" class; here the only foreground class is the polyp).
img_without_polyp = (1 - (out_max == 1).float())[0].cpu() * preproc_img
plt.imshow(img_without_polyp.permute(1, 2, 0))
```

```python
fa = FeatureAblation(agg_segmentation_wrapper)
# feature_mask=out_max groups all predicted-polyp pixels into one feature and the
# background into another; target=0 is again the wrapper's single output.
fa_attr = fa.attribute(normalized_inp, feature_mask=out_max, perturbations_per_eval=2, target=0)
```

```python
viz.visualize_image_attr(fa_attr[0].cpu().detach().permute(1, 2, 0).numpy(), sign="all")
```

```python
# Zero out the attribution of the predicted-polyp group to inspect the background contribution
# (the original compared against class index 6 from the Captum tutorial).
fa_attr_without_max = (1 - (out_max == 1).float())[0] * fa_attr
```

```python
viz.visualize_image_attr(fa_attr_without_max[0].cpu().detach().permute(1, 2, 0).numpy(), sign="all")
```
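As a closing qualitative check (a sketch added here, not part of the original notebook), the upsampled Grad-CAM map can be placed next to the predicted mask to see whether the attribution concentrates on the segmented polyp:

```python
# Sketch: input image, predicted mask, and upsampled Grad-CAM side by side.
heat = upsampled_gc_attr[0, 0].cpu().detach().numpy()

f, axarr = plt.subplots(1, 3, figsize=(12, 4))
axarr[0].imshow(preproc_img.permute(1, 2, 0))
axarr[0].set_title('Input image')
axarr[1].imshow(out_max[0, 0].cpu(), cmap='gray')
axarr[1].set_title('Predicted mask')
axarr[2].imshow(heat, cmap='jet')
axarr[2].set_title('Layer Grad-CAM (encoder1)')
plt.show()
```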