# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos | |||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm+all,macos | |||||
### macOS ### | |||||
# General | |||||
.DS_Store | |||||
.AppleDouble | |||||
.LSOverride | |||||
# Icon must end with two \r | |||||
Icon | |||||
# Thumbnails | |||||
._* | |||||
# Files that might appear in the root of a volume | |||||
.DocumentRevisions-V100 | |||||
.fseventsd | |||||
.Spotlight-V100 | |||||
.TemporaryItems | |||||
.Trashes | |||||
.VolumeIcon.icns | |||||
.com.apple.timemachine.donotpresent | |||||
# Directories potentially created on remote AFP share | |||||
.AppleDB | |||||
.AppleDesktop | |||||
Network Trash Folder | |||||
Temporary Items | |||||
.apdisk | |||||
### macOS Patch ### | |||||
# iCloud generated files | |||||
*.icloud | |||||
### PyCharm+all ### | |||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider | |||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | |||||
# User-specific stuff | |||||
.idea/**/workspace.xml | |||||
.idea/**/tasks.xml | |||||
.idea/**/usage.statistics.xml | |||||
.idea/**/dictionaries | |||||
.idea/**/shelf | |||||
# AWS User-specific | |||||
.idea/**/aws.xml | |||||
# Generated files | |||||
.idea/**/contentModel.xml | |||||
# Sensitive or high-churn files | |||||
.idea/**/dataSources/ | |||||
.idea/**/dataSources.ids | |||||
.idea/**/dataSources.local.xml | |||||
.idea/**/sqlDataSources.xml | |||||
.idea/**/dynamic.xml | |||||
.idea/**/uiDesigner.xml | |||||
.idea/**/dbnavigator.xml | |||||
# Gradle | |||||
.idea/**/gradle.xml | |||||
.idea/**/libraries | |||||
# Gradle and Maven with auto-import | |||||
# When using Gradle or Maven with auto-import, you should exclude module files, | |||||
# since they will be recreated, and may cause churn. Uncomment if using | |||||
# auto-import. | |||||
# .idea/artifacts | |||||
# .idea/compiler.xml | |||||
# .idea/jarRepositories.xml | |||||
# .idea/modules.xml | |||||
# .idea/*.iml | |||||
# .idea/modules | |||||
# *.iml | |||||
# *.ipr | |||||
# CMake | |||||
cmake-build-*/ | |||||
# Mongo Explorer plugin | |||||
.idea/**/mongoSettings.xml | |||||
# File-based project format | |||||
*.iws | |||||
# IntelliJ | |||||
out/ | |||||
# mpeltonen/sbt-idea plugin | |||||
.idea_modules/ | |||||
# JIRA plugin | |||||
atlassian-ide-plugin.xml | |||||
# Cursive Clojure plugin | |||||
.idea/replstate.xml | |||||
# SonarLint plugin | |||||
.idea/sonarlint/ | |||||
# Crashlytics plugin (for Android Studio and IntelliJ) | |||||
com_crashlytics_export_strings.xml | |||||
crashlytics.properties | |||||
crashlytics-build.properties | |||||
fabric.properties | |||||
# Editor-based Rest Client | |||||
.idea/httpRequests | |||||
# Android studio 3.1+ serialized cache file | |||||
.idea/caches/build_file_checksums.ser | |||||
### PyCharm+all Patch ### | |||||
# Ignore everything but code style settings and run configurations | |||||
# that are supposed to be shared within teams. | |||||
.idea/* | |||||
!.idea/codeStyles | |||||
!.idea/runConfigurations | |||||
### Python ### | |||||
# Byte-compiled / optimized / DLL files | |||||
__pycache__/ | |||||
*.py[cod] | |||||
*$py.class | |||||
# C extensions | |||||
*.so | |||||
# Distribution / packaging | |||||
.Python | |||||
build/ | |||||
develop-eggs/ | |||||
dist/ | |||||
downloads/ | |||||
eggs/ | |||||
.eggs/ | |||||
lib/ | |||||
lib64/ | |||||
parts/ | |||||
sdist/ | |||||
var/ | |||||
wheels/ | |||||
share/python-wheels/ | |||||
*.egg-info/ | |||||
.installed.cfg | |||||
*.egg | |||||
MANIFEST | |||||
# PyInstaller | |||||
# Usually these files are written by a python script from a template | |||||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | |||||
*.manifest | |||||
*.spec | |||||
# Installer logs | |||||
pip-log.txt | |||||
pip-delete-this-directory.txt | |||||
# Unit test / coverage reports | |||||
htmlcov/ | |||||
.tox/ | |||||
.nox/ | |||||
.coverage | |||||
.coverage.* | |||||
.cache | |||||
nosetests.xml | |||||
coverage.xml | |||||
*.cover | |||||
*.py,cover | |||||
.hypothesis/ | |||||
.pytest_cache/ | |||||
cover/ | |||||
# Translations | |||||
*.mo | |||||
*.pot | |||||
# Django stuff: | |||||
*.log | |||||
local_settings.py | |||||
db.sqlite3 | |||||
db.sqlite3-journal | |||||
# Flask stuff: | |||||
instance/ | |||||
.webassets-cache | |||||
# Scrapy stuff: | |||||
.scrapy | |||||
# Sphinx documentation | |||||
docs/_build/ | |||||
# PyBuilder | |||||
.pybuilder/ | |||||
target/ | |||||
# Jupyter Notebook | |||||
.ipynb_checkpoints | |||||
# IPython | |||||
profile_default/ | |||||
ipython_config.py | |||||
# pyenv | |||||
# For a library or package, you might want to ignore these files since the code is | |||||
# intended to run in multiple environments; otherwise, check them in: | |||||
# .python-version | |||||
# pipenv | |||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | |||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | |||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | |||||
# install all needed dependencies. | |||||
#Pipfile.lock | |||||
# poetry | |||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | |||||
# This is especially recommended for binary packages to ensure reproducibility, and is more | |||||
# commonly ignored for libraries. | |||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | |||||
#poetry.lock | |||||
# pdm | |||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | |||||
#pdm.lock | |||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | |||||
# in version control. | |||||
# https://pdm.fming.dev/#use-with-ide | |||||
.pdm.toml | |||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | |||||
__pypackages__/ | |||||
# Celery stuff | |||||
celerybeat-schedule | |||||
celerybeat.pid | |||||
# SageMath parsed files | |||||
*.sage.py | |||||
# Environments | |||||
.env | |||||
.venv | |||||
env/ | |||||
venv/ | |||||
ENV/ | |||||
env.bak/ | |||||
venv.bak/ | |||||
# Spyder project settings | |||||
.spyderproject | |||||
.spyproject | |||||
# Rope project settings | |||||
.ropeproject | |||||
# mkdocs documentation | |||||
/site | |||||
# mypy | |||||
.mypy_cache/ | |||||
.dmypy.json | |||||
dmypy.json | |||||
# Pyre type checker | |||||
.pyre/ | |||||
# pytype static type analyzer | |||||
.pytype/ | |||||
# Cython debug symbols | |||||
cython_debug/ | |||||
# PyCharm | |||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | |||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | |||||
# and can be added to the global gitignore or merged into this file. For a more nuclear | |||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | |||||
#.idea/ | |||||
### Python Patch ### | |||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration | |||||
poetry.toml | |||||
# ruff | |||||
.ruff_cache/ | |||||
### VisualStudioCode ### | |||||
.vscode/* | |||||
!.vscode/settings.json | |||||
!.vscode/tasks.json | |||||
!.vscode/launch.json | |||||
!.vscode/extensions.json | |||||
!.vscode/*.code-snippets | |||||
# Local History for Visual Studio Code | |||||
.history/ | |||||
# Built Visual Studio Code Extensions | |||||
*.vsix | |||||
### VisualStudioCode Patch ### | |||||
# Ignore all local history of files | |||||
.history | |||||
.ionide | |||||
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos |
# LAP | |||||
An Attention-Based Module for Faithful Interpretation and Knowledge Injection in Convolutional Neural Networks | |||||
In this supplementary material, we provide a PDF with further details on some of the sections of the paper, along with the code of our work.
# Data Preparation | |||||
## Metadata | |||||
First, download the metadata from [here](https://mega.nz/file/MqhXiLgC#NK1vd9ZLGU-3a182x-BXfvGCSuqHgq9hflI6tddh4Wc) | |||||
## RSNA
1. Download the dataset from [here](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data) and extract it.
2. Resize the images to the resolution your model expects by running the RSNA preparation script, as sketched after this list.
3. Extract the RSNA metadata referenced by the configs (e.g. `dataset_metadata/RSNA/DataSeparation_R224` and `data/RSNA/infection_bounding_boxs.tsv`) from [this](#metadata) metadata file.
## CelebA | |||||
1. Download the dataset from [here](https://www.kaggle.com/jessicali9530/celeba-dataset) | |||||
2. Extract the data | |||||
```bash | |||||
unzip -q -j celeba-dataset.zip '**/*.jpg' -d data/celeba | |||||
``` | |||||
3. Extract `celeba.tsv` from [this](#metadata) metadata file to `dataset_metadata/celeba.tsv` | |||||
## ImageNet
1. Download the ImageNet ILSVRC2012 dataset from [here](https://www.image-net.org/challenges/LSVRC/index.php). | |||||
1. Extract `ILSVRC2012_img_train_t3.tar` to `data/imagenet/tars`. | |||||
1. Run the following command to extract the train images per class: | |||||
```bash | |||||
python data_preparation/imagenet/extract_train.py -s data/imagenet/tars/ -n <num-threads> | |||||
``` | |||||
1. Run the following command to extract the validation images per class (replace directory with the directory containing `ILSVRC2012_img_val.tar`, and extract `imagenet_val_maps.pklz` from [this](#metadata) metadata file): | |||||
```bash | |||||
python data_preparation/imagenet/extract_val.py -s <directory> -m <imagenet_val_maps.pklz> | |||||
``` | |||||
1. Download the localization bounding boxes from [here](https://www.kaggle.com/c/imagenet-object-localization-challenge/data). Put `LOC_train_solution.csv` and `LOC_val_solution.csv` in `data/imagenet/extracted`.
# Model checkpoints | |||||
Download all the checkpoints from [here](https://mega.nz/file/16pWWA5L#b1SwLEv-YO5lL9wLVTcgykn2LmUBHm1exPFNB0JBoZo). | |||||
# Run | |||||
## Evaluate the checkpoints | |||||
### RSNA | |||||
#### Vanilla ResNet 18 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_resnet.pth | |||||
``` | |||||
#### Vanilla Inception V3 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_inception.pth | |||||
``` | |||||
#### WS ResNet 18 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_resnet_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_resnet_interpretations | |||||
``` | |||||
#### WS Inception V3 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_inception_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_inception_interpretations | |||||
``` | |||||
#### BB ResNet 18 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_resnet_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_resnet_interpretations | |||||
``` | |||||
#### BB Inception V3 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_inception_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_inception_interpretations
``` | |||||
### CelebA | |||||
#### Vanilla ResNet 18 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py celeba.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_resnet.pth | |||||
``` | |||||
#### Vanilla Inception V3 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py celeba.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_inception.pth | |||||
``` | |||||
#### WS ResNet 18 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_resnet_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_resnet_interpretations | |||||
``` | |||||
#### WS Inception V3 | |||||
##### Evaluate the model performance on the test set | |||||
```bash | |||||
python evaluate.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth | |||||
``` | |||||
##### Generate the LAP interpretations for positive samples | |||||
```bash | |||||
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_inception_interpretations | |||||
``` | |||||
##### Generate the LAP interpretations for negative samples | |||||
```bash | |||||
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_inception_interpretations | |||||
``` | |||||
### ImageNet | |||||
#### WS ResNet 50 (Fine-tuned) | |||||
##### Evaluate the model performance on the validation set | |||||
```bash | |||||
python evaluate.py imagenet.lap_resnet50_ft --device cuda:0 --samples-dir val --final-model-dir checkpoints/imagenet.ws_resnet_ft.pth | |||||
``` |
import argparse | |||||
import glob | |||||
import os | |||||
import tarfile | |||||
from multiprocessing import Pool | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('-s', dest='source', help='Class tars directory', required=True) | |||||
parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted') | |||||
parser.add_argument('-n', dest='num_threads', help='number of threads', default=1, type=int) | |||||
args = parser.parse_args() | |||||
class_tars = glob.glob(os.path.join(args.source, '*.tar')) | |||||
assert len(class_tars) == 1000, f"class_tars length: {len(class_tars)}" | |||||
def extract(class_tar): | |||||
filename = os.path.basename(class_tar) | |||||
class_name = filename.replace('.tar', '') | |||||
print('Extract ' + os.path.basename(class_tar)) | |||||
class_fname = os.path.join(args.target, class_name) | |||||
os.makedirs(class_fname, exist_ok=True) | |||||
with tarfile.open(class_tar) as f: | |||||
f.extractall(class_fname) | |||||
pool = Pool(args.num_threads) | |||||
pool.map(extract, class_tars) | |||||
pool.close() |
"""Prepare the ImageNet dataset""" | |||||
import os | |||||
import argparse | |||||
import tarfile | |||||
import pickle | |||||
import gzip | |||||
_VAL_TAR = 'ILSVRC2012_img_val.tar' | |||||
def parse_args(): | |||||
parser = argparse.ArgumentParser( | |||||
description='Setup the ImageNet dataset.', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('-s', dest='download_dir', required=True, | |||||
help="The directory that contains downloaded tar files") | |||||
parser.add_argument('-m', dest='mapping', required=True, | |||||
help="The mapping file for validation set") | |||||
parser.add_argument('--target-dir', default='data/imagenet/extracted', | |||||
help="The directory to store extracted images") | |||||
args = parser.parse_args() | |||||
return args | |||||
def check_file(filename): | |||||
if not os.path.exists(filename): | |||||
raise ValueError('File not found: '+filename) | |||||
def extract_val(tar_fname, target_dir, val_maps_file): | |||||
os.makedirs(target_dir) | |||||
print('Extracting ' + tar_fname) | |||||
with tarfile.open(tar_fname) as tar: | |||||
tar.extractall(target_dir) | |||||
    # move the extracted images into their class sub-folders according to the mapping file
with gzip.open(val_maps_file, 'rb') as f: | |||||
dirs, mappings = pickle.load(f) | |||||
for d in dirs: | |||||
os.makedirs(os.path.join(target_dir, d)) | |||||
for m in mappings: | |||||
os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0])) | |||||
def main(): | |||||
args = parse_args() | |||||
target_dir = os.path.expanduser(args.target_dir) | |||||
if os.path.exists(target_dir): | |||||
raise ValueError('Target dir ['+target_dir+'] exists. Remove it first') | |||||
download_dir = os.path.expanduser(args.download_dir) | |||||
val_tar_fname = os.path.join(download_dir, _VAL_TAR) | |||||
check_file(val_tar_fname) | |||||
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), args.mapping) | |||||
if __name__ == '__main__': | |||||
main() |
import argparse | |||||
from os import makedirs, path, listdir | |||||
import numpy as np | |||||
import cv2 | |||||
from multiprocessing import Pool | |||||
def resize_image(img_dir: str, img_save_dir: str, res: int) -> None: | |||||
img = cv2.imread(img_dir) | |||||
img = cv2.resize(img, dsize=(res, res)) | |||||
cv2.imwrite(img_save_dir, img) | |||||
if __name__ == '__main__': | |||||
parser = argparse.ArgumentParser() | |||||
    parser.add_argument('kaggle_dataset_dir', type=str, help='Download the dataset from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, extract it, and pass its path as this argument')
    parser.add_argument('resolution', type=int, help='The resolution required for your model: 224 for ResNet, 299 for vanilla Inception, 256 for the modified Inception')
parser.add_argument('cores', type=int, help='The number of cores for multiprocessing.') | |||||
args = parser.parse_args() | |||||
save_dir = f'data/RSNA-Kaggle_R{args.resolution}' | |||||
makedirs(save_dir, exist_ok=True) | |||||
kaggle_dataset_dir = args.kaggle_dataset_dir | |||||
assert path.exists(kaggle_dataset_dir), f'{kaggle_dataset_dir} does not exist!' | |||||
# reading rsna images names | |||||
rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images') | |||||
assert path.exists(rsna_imgs_path), 'Make sure there is a folder named stage_2_train_images in the passed kaggle_directory!' | |||||
imgs_names = np.asarray(listdir(rsna_imgs_path)) | |||||
imgs_src = np.vectorize(lambda x: path.join(rsna_imgs_path, x))(imgs_names) | |||||
imgs_dst = np.vectorize(lambda x: path.join(save_dir, x))(imgs_names) | |||||
pool = Pool(args.cores) | |||||
pool.starmap(resize_image, zip(imgs_src, imgs_dst, np.full((len(imgs_src),), args.resolution))) | |||||
pool.close() | |||||
import traceback | |||||
from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
from torchlap.data.data_loader import DataLoader, RunType | |||||
from torchlap.experiments._model_loading import load_model | |||||
from torchlap.utils.hooker import Hooker | |||||
def main() -> None: | |||||
ep = get_entry(PhaseType.EVAL) | |||||
conf = ep.conf | |||||
model = ep.model | |||||
model = load_model(model, conf) | |||||
model.eval() | |||||
for test_group_info in conf.samples_dir.split(','): | |||||
try: | |||||
print('') | |||||
print('>> Running evaluations for %s' % test_group_info, flush=True) | |||||
test_data_loader = DataLoader(conf, test_group_info, RunType.TEST) | |||||
evaluator = conf.evaluator_cls(model, test_data_loader, conf) | |||||
with Hooker(*conf.hooks_by_phase[conf.phase_type]): | |||||
evaluator.evaluate() | |||||
except Exception: | |||||
print('Problem in %s' % test_group_info, flush=True) | |||||
track = traceback.format_exc() | |||||
print(track, flush=True) | |||||
if __name__ == '__main__': | |||||
main() |
from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
from torchlap.experiments._model_loading import load_model | |||||
from torchlap.interpreting.interpretation_evaluation_runner import InterpretingEvalRunner | |||||
from torchlap.utils.hooker import Hooker | |||||
def main() -> None: | |||||
ep = get_entry(PhaseType.EVALINTERPRET) | |||||
conf = ep.conf | |||||
model = ep.model | |||||
model = load_model(model, conf) | |||||
model.eval() | |||||
runner = InterpretingEvalRunner(conf, model) | |||||
with Hooker(*conf.hooks_by_phase[conf.phase_type]): | |||||
runner.evaluate() | |||||
if __name__ == '__main__': | |||||
main() |
from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
from torchlap.experiments._model_loading import load_model | |||||
from torchlap.interpreting.interpreting_runner import InterpretingRunner | |||||
from torchlap.utils.hooker import Hooker | |||||
def main() -> None: | |||||
ep = get_entry(PhaseType.INTERPRET) | |||||
conf = ep.conf | |||||
model = ep.model | |||||
model = load_model(model, conf) | |||||
model.eval() | |||||
runner = InterpretingRunner(conf, model) | |||||
with Hooker(*conf.hooks_by_phase[conf.phase_type]): | |||||
runner.interpret() | |||||
if __name__ == '__main__': | |||||
main() |
numpy | |||||
pandas | |||||
torch>=1.7.0 | |||||
torchvision>=0.8.1 | |||||
scikit-image>=0.14.3 | |||||
scikit-learn>=0.19.1 | |||||
captum>=0.3.0 | |||||
opencv-python>=4.4.0.46 |
__version__ = 'master' |
from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional | |||||
import os | |||||
import random | |||||
import numpy as np | |||||
from sys import stderr | |||||
from enum import Enum | |||||
import torch | |||||
from ..interpreting.interpreter_maker import InterpretationType | |||||
from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser | |||||
from ..data.batch_choosing.sequential import SequentialBatchChooser | |||||
from ..utils.hooker import HookPair | |||||
if TYPE_CHECKING: | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
from ..data.batch_choosing.batch_chooser import BatchChooser | |||||
from ..data.content_loaders.content_loader import ContentLoader | |||||
class PhaseType(Enum): | |||||
TRAIN = 'train' | |||||
EVAL = 'eval' | |||||
INTERPRET = 'interpret' | |||||
EVALINTERPRET = 'evalinterpret' | |||||
class BaseConfig: | |||||
def __init__(self, | |||||
try_name: str, try_num: int, data_separation: str, | |||||
input_size: int, | |||||
phase_type: PhaseType, | |||||
content_loader_cls: Type['ContentLoader'], | |||||
evaluator_cls: Type['Evaluator']) -> None: | |||||
# -------------------- ESSENTIAL CONFIG -------------------- | |||||
self.data_separation = data_separation | |||||
self.try_name = try_name | |||||
self.try_num = try_num | |||||
self.phase_type = phase_type | |||||
# classes | |||||
self.evaluator_cls = evaluator_cls | |||||
self.content_loader_cls = content_loader_cls | |||||
self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser | |||||
self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser | |||||
# Setups | |||||
self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size | |||||
self.batch_size = 64 | |||||
self.random_seed: int = 17 | |||||
self.samples_dir: Union[str, None] = None | |||||
# Saving and loading directories! | |||||
self.save_dir: Union[str, None] = None | |||||
self.report_dir: Union[str, None] = None | |||||
self.final_model_dir: Union[str, None] = None | |||||
self.epoch: Union[str, None] = None | |||||
# Device setting | |||||
self.dev_name = None | |||||
self.device = None | |||||
# -------------------- ESSENTIAL CONFIG -------------------- | |||||
# -------------------- TRAINING CONFIG -------------------- | |||||
self.pretrained_model_file = None | |||||
self.big_batch_size: int = 1 | |||||
self.max_epochs: int = 10 | |||||
self.iters_per_epoch: Union[int, None] = None | |||||
self.val_iters_per_epoch: Union[int, None] = None | |||||
# cleaning log while training | |||||
self.keep_best_and_last_epochs_only = False | |||||
# auto-stopping the train | |||||
self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100 | |||||
self.different_augmentation_per_batch = True | |||||
self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None | |||||
self.title_of_reference_metric_to_choose_best_epoch = 'Loss' | |||||
self.operator_to_decide_on_improvement_of_val_reference_metric = '<=' | |||||
# for dataloader | |||||
        self.mapped_labels_to_use: Union[List[int], None] = None  # None means all labels; a list of ints restricts the run to those labels
# hooking | |||||
self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = { | |||||
phasetype: [] for phasetype in PhaseType | |||||
} | |||||
# -------------------- TRAINING CONFIG -------------------- | |||||
# ------------------ OPTIMIZATION CONFIG ------------------ | |||||
self.init_lr: float = 1e-4 | |||||
self.lr_decay: float = 1e-6 | |||||
self.freezing_regexes: Union[None, List[str]] = None | |||||
self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = get_adam_creator(self.init_lr, self.lr_decay) | |||||
# ------------------ OPTIMIZATION CONFIG ------------------ | |||||
# ----------------- INTERPRETATION CONFIG ----------------- | |||||
self.save_by_file_name = False | |||||
self.interpretation_method = InterpretationType.GuidedBackprop | |||||
self.class_label_for_interpretation: Union[int, None] = None | |||||
self.interpret_predictions_vs_gt: bool = True | |||||
# saving interpretations: | |||||
self.skip_overlay = False | |||||
self.skip_raw = False | |||||
# FOR INTERPRETATION EVALUATION | |||||
        # A threshold below which interpretation values are cut
        self.cut_threshold: float = 0
        # whether the threshold is applied to the un-normalized interpretation map
        self.global_threshold: bool = False
# whether to use dynamic threshold | |||||
self.dynamic_threshold: bool = False | |||||
self.prediction_key_in_model_output_dict = 'positive_class_probability' | |||||
# for when interpretation masks as continuous objects are compared with GT | |||||
self.acceptable_min_intersection_threshold: float = 0.5 | |||||
# the key for interpretation to be evaluated, None = the first one available | |||||
self.interpretation_tag_to_evaluate = None | |||||
self.n_interpretation_samples: Optional[int] = None | |||||
# ----------------- INTERPRETATION CONFIG ----------------- | |||||
def __repr__(self): | |||||
""" | |||||
Represent config as string | |||||
""" | |||||
return str(self.__dict__) | |||||
def _update_based_on_args(self, args) -> None: | |||||
""" Updates the values based on the received arguments from the input! """ | |||||
args_dict = args.__dict__ | |||||
for arg in args_dict: | |||||
if arg in self.__dict__ and args_dict[arg] is not None: | |||||
setattr(self, arg, args_dict[arg]) | |||||
def update(self, args): | |||||
""" | |||||
updates the config from args | |||||
""" | |||||
# dirty piece of code for repetition | |||||
self._update_based_on_args(args) | |||||
self._update_based_on_args(args) | |||||
self.dev_name, self.device = _get_device(args) | |||||
self._set_random_seeds() | |||||
self._set_save_dir() | |||||
self._update_report_dir() | |||||
def _set_random_seeds(self): | |||||
# Set the seed for hash based operations in python | |||||
os.environ['PYTHONHASHSEED'] = '0' | |||||
torch.manual_seed(self.random_seed) | |||||
torch.cuda.manual_seed_all(self.random_seed) | |||||
np.random.seed(self.random_seed) | |||||
random.seed(self.random_seed) | |||||
if self.phase_type != PhaseType.TRAIN: | |||||
torch.backends.cudnn.deterministic = True | |||||
torch.backends.cudnn.benchmark = False | |||||
else: | |||||
torch.backends.cudnn.deterministic = False | |||||
torch.backends.cudnn.benchmark = True | |||||
def _set_save_dir(self): | |||||
if self.save_dir is None: | |||||
self.save_dir = f"../Results/{self.try_num}_{self.try_name}" | |||||
os.makedirs(self.save_dir, exist_ok=True) | |||||
def _update_report_dir(self): | |||||
if self.report_dir is None: | |||||
self.report_dir = self.save_dir + '/' + self.phase_type.value | |||||
def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str: | |||||
if extra_subdir is None: | |||||
extra_subdir = '' | |||||
else: | |||||
extra_subdir = '/' + extra_subdir | |||||
# Making sure of keeping version information! | |||||
if self.final_model_dir is not None: | |||||
report_dir = '%s/%s%s' % ( | |||||
self.report_dir, | |||||
self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')], | |||||
extra_subdir) | |||||
else: | |||||
epoch = self.epoch | |||||
if epoch is None: | |||||
epoch = 'best' | |||||
report_dir = '%s/epoch,%s%s' % ( | |||||
self.report_dir, | |||||
epoch, extra_subdir) | |||||
# adding sample info | |||||
if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification): | |||||
if ':' in samples_specification: | |||||
base_info, _ = tuple(samples_specification.split(':')) | |||||
else: | |||||
base_info = samples_specification | |||||
if not os.path.isdir(base_info): | |||||
report_dir = report_dir + '/' + \ | |||||
samples_specification.replace(':', '_').replace('../', '') | |||||
else: | |||||
report_dir = report_dir + '/' + samples_specification | |||||
return self.get_unique_save_dir(report_dir) | |||||
@staticmethod | |||||
def get_unique_save_dir(sd): | |||||
if os.path.exists(sd): | |||||
report_num = 2 | |||||
base_sd = sd | |||||
while os.path.exists(sd): | |||||
sd = base_sd + '_%d' % report_num | |||||
report_num += 1 | |||||
return sd | |||||
def get_save_dir_for_sample(self, save_dir, sample_name): | |||||
if self.save_by_file_name and '/' in sample_name: | |||||
return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:] | |||||
else: | |||||
return save_dir + '/' + sample_name | |||||
def _get_device(args): | |||||
""" cuda num, -1 means parallel, -2 means cpu, no cuda means cpu""" | |||||
dev_name = args.device | |||||
if 'cuda' in dev_name: | |||||
gpu_num = 0 | |||||
if ':' in dev_name: | |||||
gpu_num = int(dev_name.split(':')[1]) | |||||
if gpu_num >= torch.cuda.device_count(): | |||||
print('No %s, using CPU instead!' % dev_name, file=stderr) | |||||
dev_name = 'cpu' | |||||
print('@ Running on CPU @', flush=True) | |||||
if dev_name == 'cuda': | |||||
dev_name = None | |||||
the_device = torch.device('cuda') | |||||
print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True) | |||||
elif 'cuda:' in dev_name: | |||||
the_device = torch.device(dev_name) | |||||
print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True) | |||||
else: | |||||
dev_name = 'cpu' | |||||
the_device = torch.device(dev_name) | |||||
print('@ Running on CPU @', flush=True) | |||||
return dev_name, the_device | |||||
def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\ | |||||
-> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]: | |||||
def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer: | |||||
return torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay) | |||||
return create_adam | |||||
def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\ | |||||
-> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]: | |||||
def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer: | |||||
return torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay) | |||||
return create_sgd | |||||
from .base_config import BaseConfig, PhaseType | |||||
from typing import Union | |||||
from torch import nn | |||||
from torchvision import transforms | |||||
from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag | |||||
from ..model_evaluation.binary_evaluator import BinaryEvaluator | |||||
class CelebAConfigs(BaseConfig): | |||||
def __init__(self, | |||||
try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None: | |||||
super().__init__(try_name, try_num, 'default', input_size, phase_type, CelebALoader, BinaryEvaluator) | |||||
# replaced configs! | |||||
self.batch_size = 64 | |||||
self.iters_per_epoch = 200 | |||||
self.tags = [tag.name for tag in CelebATag if tag.name.endswith('Tag')] | |||||
self.main_tag: Union[CelebATag, None] = None | |||||
self.data_root: str = 'data/celeba' | |||||
self.dataset_metadata: str = 'dataset_metadata/celeba.tsv' | |||||
self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS' | |||||
self.operator_to_decide_on_improvement_of_val_reference_metric = '>=' | |||||
self.keep_best_and_last_epochs_only = True | |||||
# augmentation | |||||
self.augmentations_dict = { | |||||
'x': nn.Sequential( | |||||
transforms.RandomRotation(45), | |||||
transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5), | |||||
transforms.RandomPerspective()) | |||||
} | |||||
from .base_config import BaseConfig, PhaseType | |||||
from ..data.content_loaders.imagenet_loader import ImagenetLoader | |||||
from ..model_evaluation.multiclass_evaluator import MulticlassEvaluator | |||||
class ImagenetConfigs(BaseConfig): | |||||
def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None: | |||||
super().__init__(try_name, try_num, 'data/imagenet/extracted', input_size, phase_type, ImagenetLoader, MulticlassEvaluator.standard_creator('categorical_probability', include_top5=True)) | |||||
# replaced configs! | |||||
self.batch_size = 64 | |||||
self.big_batch_size = 4 | |||||
self.max_epochs = 10 | |||||
self.title_of_reference_metric_to_choose_best_epoch = 'Accuracy' | |||||
self.operator_to_decide_on_improvement_of_val_reference_metric = '>=' | |||||
self.keep_best_and_last_epochs_only = True |
from .base_config import BaseConfig, PhaseType | |||||
from typing import Dict, List, Optional | |||||
from torch import nn | |||||
from torchvision import transforms | |||||
from ..utils.random_brightness_augmentation import ClippedBrightnessAugment | |||||
from ..data.content_loaders.rsna_loader import RSNALoader | |||||
from ..model_evaluation.binary_evaluator import BinaryEvaluator | |||||
class RSNAConfigs(BaseConfig): | |||||
def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None: | |||||
super().__init__(try_name, try_num, f'dataset_metadata/RSNA/DataSeparation_R{input_size}', input_size, phase_type, RSNALoader, BinaryEvaluator) | |||||
# replaced configs! | |||||
self.batch_size = 64 | |||||
self.max_epochs = 200 | |||||
self.label_map_dict: Dict[str, int] = { | |||||
'healthy': 0, 'pneumonia': 1 | |||||
} | |||||
self.augmentations_dict = { | |||||
'x': nn.Sequential( | |||||
transforms.RandomRotation(45), | |||||
transforms.RandomAffine(0, shear=0.4), | |||||
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)), | |||||
ClippedBrightnessAugment(0.5, 1.5, 0, 1)), | |||||
'infection': nn.Sequential( | |||||
transforms.RandomRotation(45), | |||||
transforms.RandomAffine(0, shear=0.4), | |||||
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))), | |||||
'interpretations': nn.Sequential( | |||||
transforms.RandomRotation(45), | |||||
transforms.RandomAffine(0, shear=0.4), | |||||
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))), | |||||
} | |||||
self.title_of_reference_metric_to_choose_best_epoch = 'BAcc' | |||||
self.operator_to_decide_on_improvement_of_val_reference_metric = '>=' | |||||
self.keep_best_and_last_epochs_only = True | |||||
self.infection_map_size = self.input_size | |||||
self.receptive_field_radii: List[int] = [] | |||||
self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv' |
from abc import ABC, abstractmethod | |||||
from functools import partial | |||||
from typing import Dict, List, Tuple | |||||
import torch | |||||
from torch import nn | |||||
from ..configs.base_config import BaseConfig, PhaseType | |||||
class AuxLoss(ABC): | |||||
def __init__(self, title: str, model: nn.Module, | |||||
layer_by_name: Dict[str, nn.Module], | |||||
loss_weight: float): | |||||
"""This class is used for calculating loss based on the middle layers of the network. | |||||
Args: | |||||
            title (str): A unique title; the required inputs are kept under this title in the output dictionary, and the loss term is added as title_loss.
            model (nn.Module): The base model.
            layer_by_name (Dict[str, nn.Module]): A dictionary mapping layer names to layers. The names used for one instance of this class should be unique.
            loss_weight (float): The weight of the loss in the final loss. The loss term is automatically added to the model's loss in the model_output dictionary with this factor.
""" | |||||
self._title = title | |||||
self._model = model | |||||
self._loss_weight = loss_weight | |||||
self._layers_names = layer_by_name.keys() | |||||
self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict() | |||||
self.hook_pairs = [ | |||||
(module, partial(self._save_layer_val_hook, name)) | |||||
for name, module in layer_by_name.items() | |||||
] + [(model, self._add_loss)] | |||||
def _save_layer_val_hook(self, layer_name, _, __, layer_out): | |||||
self._layers_outputs_by_name[layer_name] = layer_out | |||||
def configure(self, config: BaseConfig) -> None: | |||||
"""Must be called to add the necessary configs and hooks so the loss will be calculated! | |||||
Args: | |||||
config (BaseConfig): The base config! | |||||
""" | |||||
for phase in PhaseType: | |||||
config.hooks_by_phase[phase] += self.hook_pairs | |||||
def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]): | |||||
""" A hook for calculating and adding loss the the model's output dictionary | |||||
Args: | |||||
_ ([type]): Model, useless! (just the hook format) | |||||
__ ([type]): Model input, also useless! (just the hook format) | |||||
model_out (Dict[str, torch.Tensor]): The output dictionary of the model | |||||
""" | |||||
# popping values of tensors | |||||
layers_values = [] | |||||
for layer_name in self._layers_names: | |||||
assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.' | |||||
layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name))) | |||||
loss = self._calculate_loss(layers_values, model_out) | |||||
# adding loss to output dictionary | |||||
assert f'{self._title}_loss' not in model_out, f'Trying to add {self._title}_loss to model\'s output multiple times!' | |||||
model_out[f'{self._title}_loss'] = loss.clone() | |||||
# adding loss to the main loss | |||||
model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss | |||||
return model_out | |||||
@abstractmethod | |||||
def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor: | |||||
"""Calculates loss based on the output of the layers given in the constructor. | |||||
Args: | |||||
layers_values (List[Tuple[str, torch.Tensor]]): List of layers names and their output value | |||||
model_output (Dict[str, torch.Tensor]): The output dictionary of the model, if something else is required for the loss calculation! or it can be used to add stuff into the output! | |||||
Returns: | |||||
torch.Tensor: The loss | |||||
""" |
from typing import List, Tuple, Dict, Union | |||||
import math | |||||
import torch | |||||
from torch.nn import functional as F | |||||
from ..data.data_loader import DataLoader | |||||
from ..data.dataloader_context import DataloaderContext | |||||
from .aux_loss import AuxLoss | |||||
class BBLoss(AuxLoss): | |||||
def __init__( | |||||
self, title: str, model: torch.nn.Module, | |||||
layer_by_name: Dict[str, torch.nn.Module], | |||||
loss_weight: float, | |||||
inf_mask_kw: str, | |||||
layers_loss_weights: Union[float, List[float]], | |||||
bb_fill_ratio: float): | |||||
super().__init__(title, model, layer_by_name, loss_weight) | |||||
self._inf_mask_kw = inf_mask_kw | |||||
self._bb_fill_ratio = bb_fill_ratio | |||||
if isinstance(layers_loss_weights, list): | |||||
assert len(layers_loss_weights) == len(layer_by_name), f'There should be as many weights as layers, one per layer. Expected {len(layer_by_name)} found {len(layers_loss_weights)}' | |||||
self._layers_loss_weights = layers_loss_weights | |||||
else: | |||||
self._layers_loss_weights = [layers_loss_weights for _ in range(len(layer_by_name))] | |||||
def _calculate_loss( | |||||
self, | |||||
layers_values: List[Tuple[str, torch.Tensor]], | |||||
model_output: Dict[str, torch.Tensor]) -> torch.Tensor: | |||||
# gathering extra information | |||||
dl: DataLoader = DataloaderContext.instance.dataloader | |||||
xray_y = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device) | |||||
infection_mask = dl.get_current_batch_data(keyword=self._inf_mask_kw).to(dl._device) | |||||
xray_y = xray_y.bool() | |||||
infection_mask = infection_mask[xray_y] | |||||
PB, _, H, W = infection_mask.shape | |||||
loss = torch.zeros([], dtype=torch.float32, | |||||
device=xray_y.device, requires_grad=True) | |||||
for li, (ln, lo) in enumerate(layers_values): | |||||
out = F.interpolate(lo, | |||||
size=(H, W), | |||||
mode='bilinear', | |||||
align_corners=True) | |||||
out = out.flatten(start_dim=1) | |||||
neg_out = out[~xray_y] | |||||
pos_out = out[xray_y] | |||||
NB = neg_out.shape[0] | |||||
infection_mask = infection_mask.flatten(start_dim=1) | |||||
pos_out_infection = torch.where(infection_mask < 1, | |||||
torch.zeros_like(pos_out, requires_grad=True), | |||||
pos_out) | |||||
neg_losses = [] | |||||
pos_losses = [] | |||||
if PB > 0: | |||||
bb_area = infection_mask.sum(dim=-1) # B | |||||
k = (self._bb_fill_ratio * bb_area.quantile(q=0.5))\ | |||||
.ceil()\ | |||||
.long() | |||||
# Top positive in bb pixels must be positive | |||||
pos_infection_topk, pos_infection_indices = pos_out_infection\ | |||||
.topk(k, dim=-1, sorted=False) | |||||
pos_infection_batch_index = torch.arange(PB)\ | |||||
.to(pos_infection_indices.device)\ | |||||
.unsqueeze(1)\ | |||||
.repeat_interleave(k, dim=1) | |||||
                pos_weight = infection_mask[pos_infection_batch_index, pos_infection_indices].floor()  # floor() keeps weight 1 only for pixels fully inside a box and zeroes partially-covered ones
pos_losses.append(F.binary_cross_entropy( | |||||
pos_infection_topk, | |||||
torch.ones_like(pos_infection_topk), | |||||
pos_weight | |||||
)) | |||||
if (infection_mask == 0).any(): | |||||
# All positive out bb pixels must be negative | |||||
pos_out_non_infection = pos_out[infection_mask == 0] | |||||
neg_losses.append(F.binary_cross_entropy( | |||||
pos_out_non_infection, | |||||
torch.zeros_like(pos_out_non_infection) | |||||
)) | |||||
if NB > 0: | |||||
if PB > 0: | |||||
# Top negative pixels must be negative | |||||
neg_k = int(math.ceil(PB * k * 1.0 / NB)) | |||||
neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0] | |||||
neg_losses.append(F.binary_cross_entropy( | |||||
neg_out_topk, | |||||
torch.zeros_like(neg_out_topk) | |||||
)) | |||||
else: | |||||
# All negative pixels must be negative | |||||
neg_losses.append(F.binary_cross_entropy( | |||||
neg_out, | |||||
torch.zeros_like(neg_out) | |||||
)) | |||||
losses = [] | |||||
if len(neg_losses) > 0: | |||||
losses.append(torch.stack(neg_losses).mean()) | |||||
if len(pos_losses) > 0: | |||||
losses.append(torch.stack(pos_losses).mean()) | |||||
l_loss = torch.stack(losses).mean() | |||||
model_output[f'{self._title}_{ln}_loss'] = l_loss | |||||
loss = loss + self._layers_loss_weights[li] * l_loss | |||||
return loss |
from typing import Dict, Union, List, Tuple | |||||
import numpy as np | |||||
import torch | |||||
from torch.nn import functional as F | |||||
from .aux_loss import AuxLoss | |||||
from ..data.dataloader_context import DataloaderContext | |||||
from ..data.data_loader import DataLoader | |||||
class PoolConcordanceLossCalculator(AuxLoss): | |||||
def __init__(self, | |||||
title: str, model: torch.nn.Module, | |||||
layer_by_name: Dict[str, torch.nn.Module], | |||||
loss_weight: float, weights: Union[float, List[float]], | |||||
diff_thresholds: Union[float, List[float]], | |||||
labels_by_channel: Dict[int, List[int]]): | |||||
super().__init__(title, model, layer_by_name, loss_weight) | |||||
# at least two pools are needed | |||||
assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!' | |||||
self._title = title | |||||
self._model = model | |||||
self._loss_prefix = f'{title}_JS_loss_' | |||||
self._labels_by_channel = labels_by_channel | |||||
self._pool_names = list(layer_by_name.keys()) | |||||
self._weights = weights if isinstance(weights, list) else \ | |||||
[weights for _ in range(len(layer_by_name) - 1)] | |||||
if isinstance(weights, list): | |||||
assert len(weights) == len(layer_by_name) - 1, 'Weights must have a length of pool_layers -1' | |||||
self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \ | |||||
[diff_thresholds for _ in range(len(layer_by_name) - 1)] | |||||
if isinstance(diff_thresholds, list): | |||||
assert len(diff_thresholds) == len(layer_by_name) - 1, 'Diff thresholds must have a length of pool_layers -1' | |||||
def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor: | |||||
# reading pools from the output and deleting them | |||||
pool_vals = [x[1] for x in layers_values] | |||||
# getting samples labels | |||||
dl: DataLoader = DataloaderContext.instance.dataloader | |||||
labels = dl.get_current_batch_samples_labels() | |||||
labels = torch.from_numpy(labels).to(dl._device) | |||||
# sorting by shape from biggest to smallest | |||||
p_shapes = np.asarray([p.shape[-1] for p in pool_vals]) | |||||
sorted_inds = np.argsort(-1 * p_shapes) | |||||
pool_vals = [pool_vals[i] for i in sorted_inds] | |||||
pool_names = [self._pool_names[i] for i in sorted_inds] | |||||
loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True) | |||||
# for each pair of pools, calculate loss! | |||||
for i in range(len(pool_names) - 1): | |||||
p_loss = self._cal_pool_pair_loss(pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels) | |||||
assert (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') not in model_output, 'Trying to add ' + (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') + ' to model output multiple times' | |||||
model_output[f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'] = p_loss.clone() | |||||
loss = loss + self._weights[i] * p_loss | |||||
return loss | |||||
def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float, labels: torch.Tensor) -> torch.Tensor: | |||||
# down-sampling by max-pool till reaching the same shape as p2! | |||||
if p1.shape[-1] > p2.shape[-1]: | |||||
p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:]) | |||||
# jensen shannon loss, for each channel -> in the related class! | |||||
loss = torch.tensor(0.0, requires_grad=True, device=p1.device) | |||||
for channel, r_labels in self._labels_by_channel.items(): | |||||
on_mask = self._get_inclusion_mask(labels, r_labels) | |||||
ip1 = p1[on_mask, channel, ...] | |||||
ip2 = p2[on_mask, channel, ...] | |||||
if torch.numel(ip1) > 0: | |||||
if ip1.shape != ip2.shape: | |||||
print(f'Problem in shape of concordance loss in {self._title}') | |||||
loss = loss + jensen_shannon_divergence( | |||||
ip1[:, None, ...], ip2[:, None, ...], diff_threshold) | |||||
return loss | |||||
def _get_inclusion_mask( | |||||
self, | |||||
samples_labels: torch.Tensor, | |||||
desired_labels: List[int]) -> torch.Tensor: | |||||
with torch.no_grad(): | |||||
inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0) | |||||
aggregation = torch.sum(inclusion_mask.float(), dim=0) | |||||
return torch.greater(aggregation, 0) | |||||
def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor: | |||||
""" | |||||
    Calculates the Jensen-Shannon loss between two distributions p1 and p2
Args: | |||||
p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1] that is the probabilities for each neuron in the 1st distribution. | |||||
p2 (torch.Tensor): A tensor of the same shape as src, in range [0, 1] that is the probabilities for each neuron in the 2nd distribution. | |||||
diff_threshold (float): Threshold between p1 and p2 to decide whether to consider one pixel in loss | |||||
Returns: | |||||
torch.Tensor: The calculated loss. | |||||
""" | |||||
assert p1.shape == p2.shape, 'The tensors must have the same shape' | |||||
assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)' | |||||
# Reshaping tensors | |||||
p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1) | |||||
p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1) | |||||
# if binary, append class 0! | |||||
if p1.shape[1] == 1: | |||||
p1 = torch.cat([1 - p1, p1], dim=1) | |||||
p2 = torch.cat([1 - p2, p2], dim=1) | |||||
with torch.no_grad(): | |||||
mask = torch.abs(p1 - p2).detach() >= diff_threshold | |||||
mask = mask.max(dim=-1)[0] | |||||
# to make sure error does not result in log(negative)! | |||||
lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1))) | |||||
lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2))) | |||||
loss = 0.5 * ( | |||||
_smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) + | |||||
_smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1)) | |||||
) | |||||
return loss | |||||
def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor: | |||||
if torch.any(mask.bool()): | |||||
return torch.mean(val[mask]) | |||||
else: | |||||
return torch.tensor(0.0, requires_grad=True, device=mask.device) |
from typing import Dict, List, Tuple, Optional | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
from ..data.data_loader import DataLoader | |||||
from ..data.dataloader_context import DataloaderContext | |||||
from .aux_loss import AuxLoss | |||||
class DiscriminativeWeaklySupervisedLoss(AuxLoss): | |||||
def __init__( | |||||
self, title: str, model: nn.Module, | |||||
att_score_layer: nn.Module, | |||||
loss_weight: float, | |||||
neg_ratio: float, pos_ratio_range: Tuple[float, float], | |||||
on_labels_by_channel: Dict[int, List[int]], | |||||
            discr_score_layer: Optional[nn.Module] = None,
w_attention_in_ordering: float = 1, | |||||
w_discr_in_ordering: float = 1): | |||||
""" | |||||
Calculates binary weakly supervised score for an attention layer | |||||
with extra discrimination head. | |||||
Args: | |||||
title (str): The title of the loss, must be unique! | |||||
            model (nn.Module): the base model, so the output can be modified
att_score_layer (torch.nn.Module): A layer that gives attention score (B C ...) | |||||
loss_weight (float): The weight of the loss | |||||
            neg_ratio (float): Top ratio of pixels over which the loss is applied for negative samples, e.g. 0.1.
            pos_ratio_range (Tuple[float, float]): Low and high ratios over which the loss is applied for positive samples, e.g. (0.033, 0.278) as calculated from the distribution of positive bounding boxes.
on_labels_by_channel (Dict[int, List[int]]): The dictionary that specifies the samples related to which labels should be on in each channel. | |||||
w_attention_in_ordering (float): The weight of the attention score used in ordering the pixels. | |||||
w_discr_in_ordering (float): The weight of the reference score used in ordering the pixels. | |||||
discr_score_layer (torch.nn.Module): A layer that gives discriminative score (B C ...) | |||||
""" | |||||
layers = dict( | |||||
att=att_score_layer, | |||||
) | |||||
if discr_score_layer is not None: | |||||
layers['discr'] = discr_score_layer | |||||
super().__init__(title, model, layers, loss_weight) | |||||
self._has_discr = discr_score_layer is not None | |||||
self._neg_ratio = neg_ratio | |||||
self._pos_ratio_range = pos_ratio_range | |||||
self._on_labels_by_channel = on_labels_by_channel | |||||
self._w_attention_in_ordering = w_attention_in_ordering | |||||
self._w_discr_in_ordering = w_discr_in_ordering | |||||
def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor: | |||||
discriminative_scores = None | |||||
probabilities = None | |||||
for ln, lv in layers_values: | |||||
if ln == 'att': | |||||
probabilities = lv | |||||
else: | |||||
discriminative_scores = lv | |||||
dl: DataLoader = DataloaderContext.instance.dataloader | |||||
labels = dl.get_current_batch_samples_labels() | |||||
if discriminative_scores is not None: | |||||
discrimination_loss = self._calculate_discrimination_loss( | |||||
labels, discriminative_scores) | |||||
assert (self._title + '_discr_loss') not in model_out, 'Trying to add ' + (self._title + '_discr_loss') + ' to model output multiple times' | |||||
model_out[self._title + '_discr_loss'] = discrimination_loss.clone() | |||||
else: | |||||
discrimination_loss = torch.zeros([], requires_grad=True, device=labels.device) | |||||
attention_loss = self._calculate_attention_loss( | |||||
labels, probabilities, (discriminative_scores if discriminative_scores is not None else probabilities)) | |||||
assert (self._title + '_ws_loss') not in model_out, 'Trying to add ' + (self._title + '_ws_loss') + ' to model output multiple times' | |||||
model_out[self._title + '_ws_loss'] = attention_loss.clone() | |||||
loss = self._loss_weight * (discrimination_loss + attention_loss) | |||||
return loss | |||||
def _calculate_discrimination_loss( | |||||
self, | |||||
samples_labels: torch.Tensor, | |||||
discrimination_scores: torch.Tensor) -> torch.Tensor: | |||||
losses = [] | |||||
for channel, labels in self._on_labels_by_channel.items(): | |||||
on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device) | |||||
on_ps = discrimination_scores[on_mask, channel, ...] | |||||
off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...] | |||||
if torch.numel(on_ps) > 0: | |||||
losses.append(self._cal_loss(1, True, on_ps)) | |||||
if torch.numel(off_ps) > 0: | |||||
losses.append(self._cal_loss(1, False, off_ps)) | |||||
return torch.mean(torch.stack(losses)) | |||||
def _calculate_attention_loss( | |||||
self, | |||||
samples_labels: torch.Tensor, | |||||
attention_scores: torch.Tensor, | |||||
discrimination_scores: torch.Tensor) -> torch.Tensor: | |||||
losses = [] | |||||
for channel, labels in self._on_labels_by_channel.items(): | |||||
on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device) | |||||
on_atts = attention_scores[on_mask, channel, ...] | |||||
on_discr = discrimination_scores[on_mask, channel, ...].detach() | |||||
off_atts = attention_scores[torch.logical_not(on_mask), channel, ...] | |||||
off_discr = discrimination_scores[torch.logical_not(on_mask), channel, ...].detach() | |||||
neg_losses = [] | |||||
pos_losses = [] | |||||
            # negative loss: the top-scoring pixels of negative samples should be off
if torch.numel(off_atts) > 0 and self._neg_ratio > 0: | |||||
neg_losses.append(self._cal_loss( | |||||
self._neg_ratio, False, off_atts, off_discr, largest=True | |||||
)) | |||||
if torch.numel(on_atts) > 0: | |||||
# Calculate positive top k to be positive | |||||
if self._pos_ratio_range[0] > 0: | |||||
pos_losses.append(self._cal_loss( | |||||
self._pos_ratio_range[0], True, on_atts, on_discr, True | |||||
)) | |||||
# Calculate positive bottom k to be negative | |||||
if self._pos_ratio_range[1] < 1: | |||||
neg_losses.append(self._cal_loss( | |||||
1 - self._pos_ratio_range[1], False, on_atts, on_discr, False | |||||
)) | |||||
if len(neg_losses) > 0: | |||||
losses.append(torch.stack(neg_losses).mean()) | |||||
if len(pos_losses) > 0: | |||||
losses.append(torch.stack(pos_losses).mean()) | |||||
return torch.stack(losses).mean() | |||||
def _get_inclusion_mask( | |||||
self, | |||||
samples_labels: np.ndarray, | |||||
desired_labels: List[int], device: torch.device) -> torch.Tensor: | |||||
with torch.no_grad(): | |||||
samples_labels = torch.from_numpy(samples_labels).to(device) | |||||
inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0) | |||||
aggregation = torch.sum(inclusion_mask.float(), dim=0) | |||||
return torch.greater(aggregation, 0) | |||||
def _cal_loss( | |||||
self, | |||||
ratio: float, positive_label: bool, | |||||
att_scores: torch.Tensor, | |||||
discr_scores: Optional[torch.Tensor] = None, | |||||
largest: bool = True): | |||||
if ratio == 1: | |||||
ps = att_scores | |||||
else: | |||||
k = np.ceil( | |||||
ratio * att_scores.shape[-1] * att_scores.shape[-2]).astype(int) | |||||
ps = self._get_topk(att_scores, discr_scores, k, largest=largest) | |||||
ps = ps.flatten() | |||||
if positive_label: | |||||
gt = torch.ones_like(ps) | |||||
else: | |||||
gt = torch.zeros_like(ps) | |||||
return F.binary_cross_entropy(ps, gt) | |||||
    def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
                  k: int, dim=-1, largest=True, return_indices=False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
scores = self._pixels_scores(att_scores, discr_scores) | |||||
b = att_scores.shape[0] | |||||
top_inds = (scores.flatten(1)).topk(k, dim=dim, largest=largest, sorted=False).indices | |||||
# B K | |||||
ret_val = att_scores.flatten(1)[ | |||||
torch.repeat_interleave( | |||||
torch.arange(b, device=att_scores.device), k).reshape(b, k), | |||||
top_inds] # B K | |||||
if not return_indices: | |||||
return ret_val | |||||
else: | |||||
return ret_val, top_inds | |||||
def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: torch.Tensor) -> torch.Tensor: | |||||
return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores | |||||
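# Illustrative wiring sketch (hypothetical layer attributes; the real configs
# live in the experiment entry points): the loss attaches to a model through
# the AuxLoss base, reading an attention layer and an optional discrimination
# head.
#
#   ws_loss = DiscriminativeWeaklySupervisedLoss(
#       title='att0', model=model,
#       att_score_layer=model.attention_head,         # hypothetical attribute
#       loss_weight=1.0,
#       neg_ratio=0.1,
#       pos_ratio_range=(0.033, 0.278),
#       on_labels_by_channel={0: [1]},                 # channel 0 is on for label 1
#       discr_score_layer=model.discrimination_head)   # hypothetical attribute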
""" | |||||
The BatchChooser | |||||
""" | |||||
from abc import abstractmethod | |||||
from typing import List, TYPE_CHECKING | |||||
import numpy as np | |||||
if TYPE_CHECKING: | |||||
from ..data_loader import DataLoader | |||||
class BatchChooser: | |||||
""" | |||||
The BatchChooser | |||||
""" | |||||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||||
""" Receives as input a data_loader which contains the information about the samples and the desired batch size. """ | |||||
self.data_loader = data_loader | |||||
self.batch_size = batch_size | |||||
self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int) | |||||
self.completed_iteration = False | |||||
self.class_samples_indices = self.data_loader.get_class_sample_indices() | |||||
self._classes_labels: List[int] = list(self.class_samples_indices.keys()) | |||||
def prepare_next_batch(self): | |||||
""" | |||||
Prepares the next batch based on its strategy | |||||
""" | |||||
if self.finished_iteration(): | |||||
self.reset() | |||||
self.current_batch_sample_indices = \ | |||||
self.get_next_batch_sample_indices().astype(int) | |||||
if len(self.current_batch_sample_indices) == 0: | |||||
self.completed_iteration = True | |||||
return | |||||
@abstractmethod | |||||
def get_next_batch_sample_indices(self) -> np.ndarray: | |||||
pass | |||||
def get_current_batch_sample_indices(self) -> np.ndarray: | |||||
""" Returns a list of indices of the samples chosen for the current batch. """ | |||||
return self.current_batch_sample_indices | |||||
def finished_iteration(self): | |||||
""" Returns True if iteration is finished over all the slices of all the samples, False otherwise""" | |||||
return self.completed_iteration | |||||
def reset(self): | |||||
""" Resets the sample iterator. """ | |||||
self.completed_iteration = False | |||||
self.current_batch_sample_indices = np.asarray([], dtype=int) | |||||
def get_current_batch_size(self) -> int: | |||||
return len(self.current_batch_sample_indices) |
from typing import TYPE_CHECKING | |||||
from .batch_chooser import BatchChooser | |||||
import numpy as np | |||||
if TYPE_CHECKING: | |||||
from ..data_loader import DataLoader | |||||
class ClassBalancedShuffledSequentialBatchChooser(BatchChooser): | |||||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||||
super().__init__(data_loader, batch_size) | |||||
self._cursor = 0 | |||||
        # copy the class sample indices so shuffling them does not affect the data loader's own arrays
for k, v in self.class_samples_indices.items(): | |||||
self.class_samples_indices[k] = np.copy(v) | |||||
np.random.shuffle(self.class_samples_indices[k]) | |||||
self.classes_cursors = np.zeros(len(self.class_samples_indices), dtype=int) | |||||
def get_next_batch_sample_indices(self): | |||||
""" Returns a list containing the indices of the samples chosen for the next batch.""" | |||||
n_samples_per_class = self._get_n_samples_per_class() | |||||
def sample_for_each_class(class_index): | |||||
class_samples = self.class_samples_indices[ | |||||
self._classes_labels[class_index]] | |||||
if self.classes_cursors[class_index] + \ | |||||
n_samples_per_class[class_index] <= \ | |||||
len(class_samples): | |||||
ret_val = np.copy(class_samples[ | |||||
self.classes_cursors[class_index]: | |||||
self.classes_cursors[class_index] + n_samples_per_class[class_index] | |||||
]) | |||||
self.classes_cursors[class_index] += n_samples_per_class[class_index] | |||||
else: | |||||
ret_val = np.copy(class_samples)[ | |||||
self.classes_cursors[class_index]:] | |||||
np.random.shuffle(class_samples) | |||||
self.classes_cursors[class_index] = \ | |||||
n_samples_per_class[class_index] - len(ret_val) | |||||
ret_val = np.concatenate(( | |||||
ret_val, | |||||
np.copy(class_samples | |||||
[: n_samples_per_class[class_index] - len(ret_val)] | |||||
)), axis=0) | |||||
if self.classes_cursors[class_index] == len(class_samples): | |||||
np.random.shuffle(class_samples) | |||||
self.classes_cursors[class_index] = 0 | |||||
return ret_val | |||||
chosen_sample_indices = np.concatenate(tuple([ | |||||
sample_for_each_class(ci) for ci in range(len(self._classes_labels)) | |||||
]), axis=0) | |||||
return chosen_sample_indices | |||||
def _get_n_samples_per_class(self): | |||||
# determining the number of samples per class | |||||
n_samples_per_class = np.full(len(self.class_samples_indices), | |||||
int(self.batch_size // len(self.class_samples_indices))) | |||||
        remaining_samples_cnt = self.batch_size - np.sum(n_samples_per_class)
        if remaining_samples_cnt:
            if self._cursor + remaining_samples_cnt >= len(self._classes_labels):
                n_samples_per_class[self._cursor: len(self._classes_labels)] += 1
                remaining_samples_cnt -= (len(self._classes_labels) - self._cursor)
                self._cursor = 0
            if remaining_samples_cnt > 0:
                n_samples_per_class[self._cursor: self._cursor + remaining_samples_cnt] += 1
                self._cursor += remaining_samples_cnt
return n_samples_per_class | |||||
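# Worked example (illustrative): with batch_size = 8 and 3 classes, the base
# share is 8 // 3 = 2 per class and the remainder of 2 is handed out
# round-robin from the rotating cursor, so consecutive batches get
# [3, 3, 2], then [3, 2, 3], then [2, 3, 3], balancing classes over time.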
from typing import TYPE_CHECKING | |||||
from .batch_chooser import BatchChooser | |||||
import numpy as np | |||||
if TYPE_CHECKING: | |||||
from ..data_loader import DataLoader | |||||
class SequentialBatchChooser(BatchChooser): | |||||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||||
super().__init__(data_loader, batch_size) | |||||
self.cursor = 0 | |||||
self.desired_samples_inds = None | |||||
def reset(self): | |||||
super(SequentialBatchChooser, self).reset() | |||||
self.cursor = 0 | |||||
self.desired_samples_inds = None | |||||
def get_next_batch_sample_indices(self): | |||||
""" Returns a list containing the indices of the samples chosen for the next batch.""" | |||||
if self.desired_samples_inds is None: | |||||
self.desired_samples_inds = \ | |||||
np.arange(self.data_loader.get_number_of_samples()) | |||||
next_cursor = min( | |||||
len(self.desired_samples_inds), | |||||
self.cursor + self.batch_size) | |||||
next_sample_inds = self.desired_samples_inds[ | |||||
self.cursor:next_cursor] | |||||
self.cursor = next_cursor | |||||
return next_sample_inds | |||||
from enum import Enum | |||||
import os | |||||
import glob | |||||
from typing import Callable, TYPE_CHECKING, Union | |||||
import imageio | |||||
import cv2 | |||||
import numpy as np | |||||
import pandas as pd | |||||
from .content_loader import ContentLoader | |||||
if TYPE_CHECKING: | |||||
from ...configs.celeba_configs import CelebAConfigs | |||||
class CelebATag(Enum): | |||||
FiveOClockShadowTag = '5_o_Clock_Shadow' | |||||
ArchedEyebrowsTag = 'Arched_Eyebrows' | |||||
AttractiveTag = 'Attractive' | |||||
BagsUnderEyesTag = 'Bags_Under_Eyes' | |||||
BaldTag = 'Bald' | |||||
BangsTag = 'Bangs' | |||||
BigLipsTag = 'Big_Lips' | |||||
BigNoseTag = 'Big_Nose' | |||||
BlackHairTag = 'Black_Hair' | |||||
BlondHairTag = 'Blond_Hair' | |||||
BlurryTag = 'Blurry' | |||||
BrownHairTag = 'Brown_Hair' | |||||
BushyEyebrowsTag = 'Bushy_Eyebrows' | |||||
ChubbyTag = 'Chubby' | |||||
DoubleChinTag = 'Double_Chin' | |||||
EyeglassesTag = 'Eyeglasses' | |||||
GoateeTag = 'Goatee' | |||||
GrayHairTag = 'Gray_Hair' | |||||
HighCheekbonesTag = 'High_Cheekbones' | |||||
MaleTag = 'Male' | |||||
MouthSlightlyOpenTag = 'Mouth_Slightly_Open' | |||||
MustacheTag = 'Mustache' | |||||
NarrowEyesTag = 'Narrow_Eyes' | |||||
NoBeardTag = 'No_Beard' | |||||
OvalFaceTag = 'Oval_Face' | |||||
PaleSkinTag = 'Pale_Skin' | |||||
PointyNoseTag = 'Pointy_Nose' | |||||
RecedingHairlineTag = 'Receding_Hairline' | |||||
RosyCheeksTag = 'Rosy_Cheeks' | |||||
SideburnsTag = 'Sideburns' | |||||
SmilingTag = 'Smiling' | |||||
StraightHairTag = 'Straight_Hair' | |||||
WavyHairTag = 'Wavy_Hair' | |||||
WearingEarringsTag = 'Wearing_Earrings' | |||||
WearingHatTag = 'Wearing_Hat' | |||||
WearingLipstickTag = 'Wearing_Lipstick' | |||||
WearingNecklaceTag = 'Wearing_Necklace' | |||||
WearingNecktieTag = 'Wearing_Necktie' | |||||
YoungTag = 'Young' | |||||
BoundingBoxX = 'x_1' | |||||
BoundingBoxY = 'y_1' | |||||
BoundingBoxW = 'width' | |||||
BoundingBoxH = 'height' | |||||
Partition = 'partition' | |||||
LeftEyeX = 'lefteye_x' | |||||
LeftEyeY = 'lefteye_y' | |||||
RightEyeX = 'righteye_x' | |||||
RightEyeY = 'righteye_y' | |||||
NoseX = 'nose_x' | |||||
NoseY = 'nose_y' | |||||
LeftMouthX = 'leftmouth_x' | |||||
LeftMouthY = 'leftmouth_y' | |||||
RightMouthX = 'rightmouth_x' | |||||
RightMouthY = 'rightmouth_y' | |||||
class SampleField(Enum): | |||||
ImageId = 'image_id' | |||||
Specification = 'specification' | |||||
class CelebALoader(ContentLoader): | |||||
def __init__(self, conf: 'CelebAConfigs', data_specification: str): | |||||
""" read all directories for scans and annotations, which are split by DataSplitter.py | |||||
And then, keeps only those samples which are used for 'usage' """ | |||||
super().__init__(conf, data_specification) | |||||
self.conf = conf | |||||
self._datasep = True | |||||
self._samples_metadata = self._load_metadata(data_specification) | |||||
self._data_root = conf.data_root | |||||
self._warning_count = 0 | |||||
def _load_metadata(self, data_specification: str) -> pd.DataFrame: | |||||
if '/' in data_specification: | |||||
self._datasep = False | |||||
return pd.DataFrame({ | |||||
SampleField.ImageId.value: glob.glob(os.path.join(data_specification, '*.jpg')) | |||||
}) | |||||
metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t') | |||||
metadata = metadata[metadata[SampleField.Specification.value] == data_specification] | |||||
metadata = metadata.drop(SampleField.Specification.value, axis=1) | |||||
metadata = metadata.reset_index().drop('index', axis=1) | |||||
return metadata | |||||
def get_samples_names(self): | |||||
return self._samples_metadata[SampleField.ImageId.value].values | |||||
def get_samples_labels(self): | |||||
r""" Dummy. Because we have multiple labels """ | |||||
return np.ones((len(self._samples_metadata),))\ | |||||
if self.conf.main_tag is None\ | |||||
else self._samples_metadata[self.conf.main_tag.value].values | |||||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||||
self._samples_metadata = self._samples_metadata[np.logical_not(drop_mask)] | |||||
def get_placeholder_name_to_fill_function_dict(self): | |||||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||||
to the functions used for filling them. The functions must receive as input data_loader, | |||||
which is an object of class data_loader that contains information about the current batch | |||||
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen | |||||
elements) and return an array per placeholder name according to the receives batch information. | |||||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||||
they belong to! Some sort of having a mark :))!""" | |||||
return { | |||||
'x': self._get_x, | |||||
**{ | |||||
tag.name: self._generate_tag_getter(tag) for tag in CelebATag | |||||
} | |||||
} | |||||
def _get_x(self, samples_inds: np.ndarray)\ | |||||
-> np.ndarray: | |||||
images = tuple(self._read_image(index) for index in samples_inds) | |||||
images = np.stack(images, axis=0) # B I I 3 | |||||
images = images.transpose((0, 3, 1, 2)) # B 3 I I | |||||
images = images.astype(float) / 255. | |||||
return images | |||||
def _read_image(self, sample_index: int) -> np.ndarray: | |||||
filepath = self._samples_metadata.iloc[sample_index][SampleField.ImageId.value] | |||||
if self._datasep: | |||||
filepath = os.path.join(self._data_root, filepath) | |||||
image = imageio.imread(filepath) # H W 3 | |||||
image = cv2.resize(image, (self.conf.input_size, self.conf.input_size)) # I I 3 | |||||
return image | |||||
    def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray], np.ndarray]:
def get_tag(samples_inds: np.ndarray) -> np.ndarray: | |||||
return self._samples_metadata.iloc[samples_inds][tag.value].values if self._datasep else np.zeros(len(samples_inds)) | |||||
return get_tag |
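# Illustrative usage sketch (hypothetical config values): with a metadata TSV
# given in conf.dataset_metadata, the loader serves the split named by
# data_specification; passing a directory path instead (detected by the '/'
# in it) loads all the *.jpg files inside with dummy tag values.
#
#   loader = CelebALoader(conf, 'train')     # rows with specification == 'train'
#   fillers = loader.get_placeholder_name_to_fill_function_dict()
#   x = fillers['x'](np.array([0, 1]))       # float images, B 3 I I, in [0, 1]
#   smiling = fillers['SmilingTag'](np.array([0, 1]))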
""" | |||||
A class for loading one content type needed for the run. | |||||
""" | |||||
from abc import abstractmethod, ABC | |||||
from typing import TYPE_CHECKING, Dict, Callable, Union, List | |||||
import numpy as np | |||||
if TYPE_CHECKING: | |||||
from .. import BaseConfig | |||||
LoadFunction = Callable[[np.ndarray], np.ndarray]
class ContentLoader(ABC): | |||||
""" A class for loading one content type needed for the run. """ | |||||
def __init__(self, conf: 'BaseConfig', data_specification: str): | |||||
""" Receives as input conf which is a dictionary containing configurations. | |||||
This dictionary can also be used to pass fixed addresses for easing the usage! | |||||
prefix_name is the str which all the variables that must be filled with this class have | |||||
this prefix in their names so it would be clear that this class should fill it. | |||||
data_specification is the string that specifies data and where it is! e.g. train, test, val""" | |||||
self.conf = conf | |||||
self.data_specification = data_specification | |||||
self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None | |||||
@abstractmethod | |||||
def get_samples_names(self): | |||||
""" Returns a list containing names of all the samples of the content loader, | |||||
each sample must owns a unique ID, and this function returns all this IDs. | |||||
The order of the list must always be the same during one run. | |||||
For example, this function can return an ID column of a table for TableLoader | |||||
or the dir of images as ID for ImageLoader""" | |||||
@abstractmethod | |||||
def get_samples_labels(self): | |||||
""" Returns list of labels of the whole samples. | |||||
The order of the list must always be the same during one run.""" | |||||
@abstractmethod | |||||
def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]: | |||||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||||
to the functions used for filling them. The functions must receive as input | |||||
batch_samples_inds and batch_samples_elements_inds which defines the current batch, | |||||
and return an array per placeholder name according to the receives batch information. | |||||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||||
they belong to! Some sort of having a mark :))!""" | |||||
def fill_placeholders(self, | |||||
keys: List[str], | |||||
samples_inds: np.ndarray)\ | |||||
-> Dict[str, Union[None, np.ndarray]]: | |||||
""" Receives as input placeholders which is a dictionary of the placeholders' | |||||
names to a torch tensor for filling it to feed the model with and | |||||
samples_inds and samples_elements_inds, which contain | |||||
information about the current batch. Fills placeholders based on the | |||||
function dictionary received in get_placeholder_name_to_fill_function_dict.""" | |||||
if self._fill_func_by_var_name is None: | |||||
self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict() | |||||
# Filling all the placeholders in the received dictionary! | |||||
placeholders = dict() | |||||
for placeholder_name in keys: | |||||
if placeholder_name in self._fill_func_by_var_name: | |||||
placeholder_v = self._fill_func_by_var_name[placeholder_name]( | |||||
samples_inds) | |||||
if placeholder_v is None: | |||||
raise Exception('None value for key %s' % placeholder_name) | |||||
placeholders[placeholder_name] = placeholder_v | |||||
else: | |||||
raise Exception(f'Unknown key for content loader: {placeholder_name}') | |||||
return placeholders | |||||
def get_batch_true_interpretations(self, samples_inds: np.ndarray, | |||||
samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\ | |||||
-> np.ndarray: | |||||
""" | |||||
Receives the indices of samples and their elements (if elemented) in the current batch, | |||||
returns the true interpretations as expected (Bounding box or a boolean mask) | |||||
""" | |||||
raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.') | |||||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||||
""" | |||||
        Receives a boolean drop mask and drops the samples whose mask is True.
        From then on, sample indices refer to the new (filtered) set of samples.
""" | |||||
raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.') |
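# Minimal illustrative subclass (not part of the project) showing the contract:
# names and labels keep a fixed order, and each placeholder name maps to a
# filler that receives the batch sample indices.
#
#   class ToyLoader(ContentLoader):
#       def __init__(self, conf, data_specification):
#           super().__init__(conf, data_specification)
#           self._x = np.arange(12, dtype=float).reshape(4, 3)
#           self._y = np.array([0, 1, 0, 1])
#       def get_samples_names(self):
#           return np.array(['s0', 's1', 's2', 's3'])
#       def get_samples_labels(self):
#           return self._y
#       def get_placeholder_name_to_fill_function_dict(self):
#           return {'toy_x': lambda inds: self._x[inds],
#                   'toy_y': lambda inds: self._y[inds]}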
import os | |||||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
import pickle | |||||
from PIL import Image | |||||
from enum import Enum | |||||
import pandas as pd | |||||
import numpy as np | |||||
import torch | |||||
from torchvision import transforms | |||||
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS | |||||
from .content_loader import ContentLoader | |||||
from ...utils.bb_generator import generate_bb_map | |||||
if TYPE_CHECKING: | |||||
from ...configs.imagenet_configs import ImagenetConfigs | |||||
def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose: | |||||
if data_specification == 'train': | |||||
return transforms.Compose([ | |||||
transforms.RandomResizedCrop(input_size), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
]) | |||||
if data_specification in ['test', 'val']: | |||||
return transforms.Compose([ | |||||
transforms.Resize(256), | |||||
transforms.CenterCrop(input_size), | |||||
transforms.ToTensor(), | |||||
]) | |||||
raise ValueError('Unknown data specification: {}'.format(data_specification)) | |||||
def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]: | |||||
cache_path = os.path.join('.cache/', cache_name) | |||||
print(cache_path) | |||||
if os.path.isfile(cache_path): | |||||
print('Loading cached samples from {}'.format(cache_path)) | |||||
with open(cache_path, 'rb') as f: | |||||
return pickle.load(f) | |||||
print('Creating cache for {}'.format(cache_name)) | |||||
os.makedirs(os.path.dirname(cache_path), exist_ok=True) | |||||
_, class_to_idx = find_classes(imagenet_dir) | |||||
samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS) | |||||
with open(cache_path, 'wb') as f: | |||||
pickle.dump(samples, f) | |||||
return samples | |||||
class BBoxField(Enum): | |||||
ImageId = 'ImageId' | |||||
PredictionString = 'PredictionString' | |||||
def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame: | |||||
return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv')) | |||||
class ImagenetLoader(ContentLoader): | |||||
def __init__(self, conf: 'ImagenetConfigs', data_specification: str): | |||||
super().__init__(conf, data_specification) | |||||
imagenet_dir = os.path.join(conf.data_separation, data_specification) | |||||
self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir) | |||||
self.__transform = _get_image_transform(data_specification, conf.input_size) | |||||
self.__bboxes = load_bbox_df(conf.data_separation, data_specification) | |||||
self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {} | |||||
def get_samples_names(self): | |||||
        ''' Sample names must be unique; the image paths are used here, since
        either the names or the paths would do. '''
return np.array([path for path, _ in self.__samples]) | |||||
def get_samples_labels(self): | |||||
return np.array([label for _, label in self.__samples]) | |||||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||||
self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop] | |||||
def get_placeholder_name_to_fill_function_dict(self): | |||||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||||
to the functions used for filling them. The functions must receive as input data_loader, | |||||
which is an object of class data_loader that contains information about the current batch | |||||
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen | |||||
elements) and return an array per placeholder name according to the receives batch information. | |||||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||||
they belong to! Some sort of having a mark :))!""" | |||||
return { | |||||
'x': self.__get_x, | |||||
'y': self.__get_y, | |||||
'bbox': self.__get_bbox, | |||||
'interpretations': self.__get_bbox, | |||||
} | |||||
def __get_x(self, samples_inds: np.ndarray)\ | |||||
-> np.ndarray: | |||||
def get_image(index) -> torch.Tensor: | |||||
path, _ = self.__samples[index] | |||||
sample = default_loader(path) | |||||
self.__handle_random_state(index, 'x') | |||||
return self.__transform(sample) | |||||
return torch.stack( | |||||
tuple([get_image(index) for index in samples_inds]), | |||||
dim=0) | |||||
def __get_y(self, samples_inds: np.ndarray)\ | |||||
-> np.ndarray: | |||||
return self.get_samples_labels()[samples_inds] | |||||
def __handle_random_state(self, idx: int, label: str) -> None: | |||||
if idx not in self.__rng_states or self.__rng_states[idx][0] == label: | |||||
self.__rng_states[idx] = (label, torch.get_rng_state()) | |||||
else: | |||||
torch.set_rng_state(self.__rng_states[idx][1]) | |||||
def __get_bbox(self, sample_inds: np.ndarray)\ | |||||
-> np.ndarray: | |||||
def extract_bb(prediction_string: str) -> np.ndarray: | |||||
splitted = prediction_string.split() | |||||
n = len(splitted) // 5 | |||||
return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\ | |||||
.reshape((-1, 2, 2)) | |||||
def make_bb(index) -> torch.Tensor: | |||||
path, _ = self.__samples[index] | |||||
image_id = os.path.basename(path).split('.')[0] | |||||
image_size = np.array(Image.open(path).size)[::-1] # 2 | |||||
bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value] | |||||
if len(bboxes) == 0: | |||||
                return torch.full((1, self.conf.input_size, self.conf.input_size), float('nan'))
bboxes = bboxes.values[0] | |||||
bboxes = extract_bb(bboxes)[..., ::-1] / image_size # N 2 2 | |||||
start_points = bboxes[:, 0] # N 2 | |||||
end_points = bboxes[:, 1] # N 2 | |||||
bb_map = generate_bb_map(start_points, end_points, tuple(image_size)) | |||||
bb_map = Image.fromarray(bb_map) | |||||
self.__handle_random_state(index, 'bb') | |||||
return self.__transform(bb_map) | |||||
return torch.stack([make_bb(i) for i in sample_inds], dim=0) |
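# Illustrative sketch of the RNG trick behind __handle_random_state above:
# saving the RNG state before transforming the image and restoring it before
# transforming the bounding-box map makes the two random crops/flips identical.
#
#   state = torch.get_rng_state()
#   img_t = transform(img)        # consumes randomness
#   torch.set_rng_state(state)
#   bb_t = transform(bb_map)      # replays the same randomness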
from typing import Dict, List, TYPE_CHECKING, Optional, Tuple | |||||
from os import path, listdir | |||||
from sys import stderr | |||||
from functools import partial | |||||
import cv2 | |||||
import numpy as np | |||||
import pandas as pd | |||||
from .content_loader import ContentLoader | |||||
from ...utils.bb_generator import generate_bb_map | |||||
if TYPE_CHECKING: | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
class RSNALoader(ContentLoader): | |||||
def __init__(self, conf: 'RSNAConfigs', data_specification: str): | |||||
""" read all directories for scans and annotations, which are split by DataSplitter.py | |||||
And then, keeps only those samples which are used for 'usage' """ | |||||
super().__init__(conf, data_specification) | |||||
self.conf = conf | |||||
self.img_size = conf.input_size | |||||
self.infection_map_size = conf.infection_map_size | |||||
self.samples_info = None | |||||
self.load_samples(data_specification) | |||||
self.loaded_images = None | |||||
        self._infections_bbs: Optional[Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]] = None
self._load_infections() | |||||
def load_samples(self, data_specification): | |||||
if ':' in data_specification: | |||||
data_specification, filter_str = data_specification.split(':') | |||||
else: | |||||
filter_str = None | |||||
if data_specification in ['train', 'val', 'test']: | |||||
self.samples_info = \ | |||||
pd.read_csv('%s/%s.txt' % | |||||
( | |||||
self.conf.data_separation, data_specification), sep='\t', header=0) | |||||
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):]) | |||||
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):]) | |||||
elif path.isfile(data_specification) and path.exists(data_specification): | |||||
self.samples_info = pd.read_csv(data_specification, sep='\t', header=0) | |||||
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):]) | |||||
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):]) | |||||
elif path.isdir(data_specification) and path.exists(data_specification): | |||||
samples_names = np.asarray( | |||||
[data_specification + '/' + x for x in listdir(data_specification)]) | |||||
print('%d samples discovered in %s' % (len(samples_names), data_specification)) | |||||
if filter_str is not None: | |||||
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x) | |||||
samples_names = samples_names[pass_filter_str(samples_names)] | |||||
print('%d samples remained after filtering' % len(samples_names)) | |||||
self.samples_info = pd.DataFrame({ | |||||
'label': np.full(len(samples_names), -1, dtype=int), | |||||
'sample': samples_names, | |||||
                'view': np.full(len(samples_names), 'Unknown', dtype=object),
                'dataset': np.full(len(samples_names), 'unknown', dtype=object),
                'Interpretation_dir': np.full(len(samples_names), '', dtype=object),
}) | |||||
return | |||||
else: | |||||
            raise NotImplementedError('Unsupported data specification: %s' % data_specification)
org_cnt = len(self.samples_info) | |||||
# filtering by filter string! | |||||
if filter_str is not None: | |||||
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x) | |||||
self.samples_info = self.samples_info[pass_filter_str(self.samples_info['sample'].values)] | |||||
print( | |||||
'%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification), | |||||
flush=True) | |||||
org_cnt = len(self.samples_info) | |||||
# label mapping/ordering! | |||||
label_mapping_dict = self.conf.label_map_dict | |||||
v_map_label = np.vectorize(lambda l: label_mapping_dict.get(l.lower(), -1)) | |||||
self.samples_info['label'] = v_map_label(self.samples_info['label'].values) | |||||
# filtering unmapped labels | |||||
self.samples_info = self.samples_info[self.samples_info['label'].values != -1] | |||||
print('%d out of %d samples remained after label mapping!' % | |||||
(len(self.samples_info), org_cnt), flush=True) | |||||
# counting the labels! | |||||
flattened_labels = self.samples_info['label'].values | |||||
u_labels = np.unique(flattened_labels) | |||||
labels_cnt = np.zeros(len(u_labels)) | |||||
np.add.at(labels_cnt, np.searchsorted(u_labels, flattened_labels, side='left'), 1) | |||||
print('[%s]' % ', '.join(['class-%s: %d' % (str(u_labels[i]), labels_cnt[i]) for i in range(len(labels_cnt))]), | |||||
flush=True) | |||||
def _load_infections(self) -> None: | |||||
""" | |||||
If a file address is given as infections_bb_dir in configs | |||||
(containing imgs and their bounding boxes as a str) | |||||
the information related to bounding boxes will be taken, | |||||
otherwise it is expected from the data separation to have a column | |||||
indicating the address of the infection mask | |||||
""" | |||||
if self.conf.infections_bb_dir is not None: | |||||
bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0) | |||||
imgs_dirs = np.vectorize(self._get_id_by_path)(bbs_info['img_dir'].values) | |||||
imgs_bbs = [ | |||||
[] if '-' not in str(im_bbs) else [ | |||||
tuple([np.clip(float(x), 0, 1) for x in im_bb.split('-')]) | |||||
for im_bb in im_bbs.split(',') | |||||
] | |||||
for im_bbs in bbs_info['img_bb'].values] | |||||
# reformatting to tuple of list of starts and list of ends | |||||
imgs_bbs_starts = [[ | |||||
(r1, c1) for r1, c1, _, _ in img_bbs] | |||||
for img_bbs in imgs_bbs] | |||||
imgs_bbs_ends = [[ | |||||
(r2, c2) for _, _, r2, c2 in img_bbs] | |||||
for img_bbs in imgs_bbs] | |||||
imgs_bbs = [ | |||||
(imgs_bbs_starts[i], imgs_bbs_ends[i]) | |||||
for i in range(len(imgs_bbs_starts))] | |||||
self._infections_bbs = dict(zip(imgs_dirs, imgs_bbs)) | |||||
def get_samples_names(self): | |||||
return self.samples_info['sample'].values | |||||
def get_samples_labels(self): | |||||
return self.samples_info['label'].values | |||||
    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """ Receives a boolean drop mask and drops the samples whose mask is True. """
self.samples_info = self.samples_info[np.logical_not(drop_mask)] | |||||
def get_placeholder_name_to_fill_function_dict(self): | |||||
ret_val = { | |||||
'x': self.get_batch_scaled_images, | |||||
'y': self.get_batch_label, | |||||
'infection': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if | |||||
self._infections_bbs is not None else | |||||
self.get_batch_bbs_map_from_mask_file, | |||||
'interpretations': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if | |||||
self._infections_bbs is not None else | |||||
self.get_batch_bbs_map_from_mask_file, | |||||
'has_bb': self.get_batch_has_bb_mask_from_bb_file if | |||||
self._infections_bbs is not None else | |||||
self.get_batch_has_bb_mask_from_mask_file | |||||
} | |||||
        # if receptive field radii are given, bounding boxes must have been provided as a file
if len(self.conf.receptive_field_radii) > 0: | |||||
assert self._infections_bbs is not None, \ | |||||
"When having receptive field radii, " \ | |||||
"you should provide bounding boxes as a file in config.infections_bb_dir" | |||||
ret_val.update({ | |||||
f'ex_infection_{r}': partial(self.get_batch_extended_bbs_map_from_bbs_file, r) | |||||
for r in self.conf.receptive_field_radii | |||||
}) | |||||
return ret_val | |||||
def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \ | |||||
-> np.ndarray: | |||||
assert 'Interpretation_dir' in self.samples_info.columns, \ | |||||
            'If you have not specified infections_bb_dir in your config, ' \
            'you need to have an Interpretation_dir column in your data separation'
def has_inf(si): | |||||
if self.samples_info.iloc[si]['label'] == 0: | |||||
return True | |||||
int_dir = self.samples_info.iloc[si]['Interpretation_dir'] | |||||
return str(int_dir) != 'nan' | |||||
return np.vectorize(has_inf)(samples_inds) | |||||
def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \ | |||||
-> np.ndarray: | |||||
def has_inf(si): | |||||
if self.samples_info.iloc[si]['label'] == 0: | |||||
return True | |||||
sample_key = self.samples_info.iloc[si]['sample'] | |||||
sample_key = self._get_id_by_path(sample_key) | |||||
return sample_key in self._infections_bbs and \ | |||||
len(self._infections_bbs[sample_key][0]) > 0 | |||||
return np.vectorize(has_inf)(samples_inds) | |||||
def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray)\ | |||||
-> np.ndarray: | |||||
def read_interpretation(im_ind): | |||||
# for healthy, return full 0 | |||||
if self.samples_info.iloc[im_ind]['label'] == 0: | |||||
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)
int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir'] | |||||
# if it does not exist, return full -1 | |||||
if str(int_dir) == 'nan': | |||||
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)
if 'npy' in int_dir: | |||||
interpretation = np.load(int_dir) | |||||
else: | |||||
interpretation = np.load(int_dir)['arr_0'] | |||||
interpretation = interpretation.astype(float) | |||||
if interpretation.shape != (self.infection_map_size, self.infection_map_size): | |||||
interpretation = (cv2.resize(np.round(255 * interpretation, 0), | |||||
dsize=(self.infection_map_size, self.infection_map_size)) >= 128).astype(float) | |||||
return interpretation[np.newaxis, :, :] | |||||
batch_interpretations = np.stack(tuple([read_interpretation(si) | |||||
for si in samples_inds]), axis=0) | |||||
return batch_interpretations | |||||
@staticmethod | |||||
def _get_id_by_path(path: str) -> str: | |||||
return path[path.index('png_images/'):] | |||||
def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray: | |||||
def make_map(im_ind): | |||||
if self.samples_info.iloc[im_ind]['label'] == 0: | |||||
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)
sample_key = self.samples_info.iloc[im_ind]['sample'] | |||||
sample_key = self._get_id_by_path(sample_key) | |||||
if sample_key in self._infections_bbs and \ | |||||
len(self._infections_bbs[sample_key][0]) > 0: | |||||
bbs_info = self._infections_bbs[sample_key] | |||||
start_points, end_points = bbs_info | |||||
mask = generate_bb_map(start_points, end_points, | |||||
(self.infection_map_size, self.infection_map_size), radius) | |||||
return mask[np.newaxis, :, :] | |||||
else: | |||||
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)
batch_interpretations = np.stack(tuple([make_map(si) | |||||
for si in samples_inds]), axis=0) | |||||
return batch_interpretations | |||||
def get_batch_scaled_images(self, | |||||
samples_inds: np.ndarray) \ | |||||
-> np.ndarray: | |||||
def read_img(im_ind): | |||||
im = cv2.imread(self.samples_info.iloc[im_ind]['sample']) | |||||
if im is None: | |||||
print(self.samples_info.iloc[im_ind]['sample'] + ' is missing!', file=stderr) | |||||
raise Exception('Missing image') | |||||
            if im.ndim == 3:
ret_val = im[:, :, 0] | |||||
else: | |||||
ret_val = im | |||||
if ret_val.shape != (self.img_size, self.img_size): | |||||
ret_val = cv2.resize(ret_val, dsize=(self.img_size, self.img_size)) | |||||
return ret_val[np.newaxis, :, :] | |||||
batch_imgs = np.stack(tuple([read_img(si) | |||||
for si in samples_inds]), axis=0) | |||||
batch_imgs = batch_imgs.astype(np.float32) / 255 | |||||
return batch_imgs | |||||
def get_batch_label(self, | |||||
samples_inds: np.ndarray, | |||||
) \ | |||||
-> np.ndarray: | |||||
return self.samples_info['label'].iloc[samples_inds].values.astype(np.float32) | |||||
def get_batch_true_interpretations(self, samples_inds: np.ndarray) \ | |||||
-> np.ndarray: | |||||
def read_interpretation(im_ind): | |||||
# for healthy, return full 0 | |||||
if self.samples_info.iloc[im_ind]['label'] == 0: | |||||
                return np.full((1, self.img_size, self.img_size), 0, dtype=float)
int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir'] | |||||
if str(int_dir) == 'nan': | |||||
                return np.full((1, self.img_size, self.img_size), np.nan, dtype=float)
if 'npy' in int_dir: | |||||
interpretation = np.load(int_dir) | |||||
else: | |||||
interpretation = np.load(int_dir)['arr_0'].astype(int) | |||||
if interpretation.shape != (self.img_size, self.img_size): | |||||
interpretation = (cv2.resize(np.round(255 * interpretation, 0).astype(np.uint8), | |||||
dsize=(self.img_size, self.img_size)) >= 128).astype(float) | |||||
return interpretation[np.newaxis, :, :] | |||||
batch_interpretations = np.stack(tuple([read_interpretation(si) | |||||
for si in samples_inds]), axis=0) | |||||
return batch_interpretations |
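# Illustrative sketch (format inferred from _load_infections above): each row
# of the infections_bb_dir file stores the boxes of one image as a
# comma-separated list of 'r1-c1-r2-c2' fractions, parsed into (starts, ends):
#
#   im_bbs = '0.1-0.2-0.5-0.6,0.3-0.3-0.9-0.8'
#   boxes = [tuple(float(x) for x in bb.split('-')) for bb in im_bbs.split(',')]
#   starts = [(r1, c1) for r1, c1, _, _ in boxes]   # [(0.1, 0.2), (0.3, 0.3)]
#   ends = [(r2, c2) for _, _, r2, c2 in boxes]     # [(0.5, 0.6), (0.9, 0.8)]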
""" A class for loading the whole data needed for the run! """ | |||||
from typing import Dict, List, Union, TYPE_CHECKING | |||||
import random | |||||
import numpy as np | |||||
import torch | |||||
from .batch_choosing.batch_chooser import BatchChooser | |||||
from .content_loaders.content_loader import ContentLoader | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class RunType: | |||||
TRAIN = 'train' | |||||
VAL = 'val' | |||||
TEST = 'test' | |||||
class DataLoader(): | |||||
""" A class for loading the whole data needed for the run! """ | |||||
def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: RunType): | |||||
""" Conf is a dictionary containing configurations, | |||||
sample specification is a string specifying the samples e.g. address of the CTs | |||||
run type is one of the strings train/val/test specifying the mode that the data_loader | |||||
will be used.""" | |||||
self.conf = conf | |||||
self._device = conf.device | |||||
self.data_specification = data_specification | |||||
self.run_type = run_type | |||||
self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification) | |||||
self.samples_names: np.ndarray = self._content_loader.get_samples_names() | |||||
self.samples_labels = self._content_loader.get_samples_labels() | |||||
keep_mask = self.get_samples_keep_mask() | |||||
if keep_mask is not None: | |||||
drop_mask = np.logical_not(keep_mask) | |||||
self._content_loader.drop_samples(drop_mask) | |||||
self.samples_names = self.samples_names[keep_mask] | |||||
self.samples_labels = self.samples_labels[keep_mask] | |||||
self.class_samples_indices = dict() | |||||
print('%d samples' % len(self.samples_names), flush=True) | |||||
for i in range(len(self.samples_names)): | |||||
if self.samples_labels[i] not in self.class_samples_indices: | |||||
self.class_samples_indices[self.samples_labels[i]] = [] | |||||
self.class_samples_indices[self.samples_labels[i]].append(i) | |||||
for c in self.class_samples_indices.keys(): | |||||
self.class_samples_indices[c] = np.asarray(self.class_samples_indices[c]) | |||||
        # for augmentations: only applied in the training phase
self._different_augmentation_per_batch = conf.different_augmentation_per_batch | |||||
self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None | |||||
if run_type == RunType.TRAIN: | |||||
self._augmentations_dict = conf.augmentations_dict | |||||
# for preventing reprocessing | |||||
self._processed_data_dictionary: Dict[str, torch.Tensor] = dict() | |||||
# for preserving same augmentations in one batch | |||||
self._transformations_seeds_dict: Dict[torch.nn.Module, Union[int, np.ndarray]] = dict() | |||||
if run_type == RunType.TRAIN: | |||||
self._batch_chooser: BatchChooser = self.conf.train_batch_chooser_cls(self, conf.batch_size) | |||||
else: | |||||
self._batch_chooser: BatchChooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size) | |||||
def get_samples_names(self) -> np.ndarray: | |||||
""" Returns the names of the samples based on the first content loader. | |||||
IMPORTANT: These all must be the same for all of the content loaders.""" | |||||
return self.samples_names | |||||
def get_samples_labels(self): | |||||
""" Returns the labels of the samples in a numpy array. """ | |||||
return self.samples_labels | |||||
def get_number_of_samples(self): | |||||
""" Returns the number of samples loaded. """ | |||||
return len(self.samples_names) | |||||
def get_class_sample_indices(self): | |||||
""" Returns a dictionary, containing lists of samples indices belonging to each class label.""" | |||||
return self.class_samples_indices | |||||
def prepare_next_batch(self): | |||||
# Resetting information | |||||
self._processed_data_dictionary = dict() | |||||
self._transformations_seeds_dict = dict() | |||||
self._batch_chooser.prepare_next_batch() | |||||
def finished_iteration(self) -> bool: | |||||
return self._batch_chooser.finished_iteration() | |||||
def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None: | |||||
""" Receives as input a dictionary of placeholders and fills them using all the content loaders.""" | |||||
missed_placeholders_names: List[str] | |||||
if len(self._processed_data_dictionary) == 0: | |||||
missed_placeholders_names = list(placeholders_dict.keys()) | |||||
else: | |||||
missed_placeholders_names = [] | |||||
for k, v in placeholders_dict.items(): | |||||
if k in self._processed_data_dictionary: | |||||
placeholders_dict[k] = self._processed_data_dictionary[k] | |||||
else: | |||||
missed_placeholders_names.append(k) | |||||
if len(missed_placeholders_names) == 0: | |||||
return | |||||
new_batch_info = self._content_loader.fill_placeholders( | |||||
missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices()) | |||||
# filling the unfilled ones | |||||
if len(missed_placeholders_names) > 0: | |||||
# filling all missed keys! | |||||
with torch.no_grad(): | |||||
for k in missed_placeholders_names: | |||||
self._fill_placeholder(placeholders_dict[k], new_batch_info[k]) | |||||
if self._augmentations_dict is not None and k in self._augmentations_dict: | |||||
placeholders_dict[k] = self._apply_augmentation(k, placeholders_dict[k]) | |||||
# Keeping a copy of data | |||||
self._processed_data_dictionary[k] = placeholders_dict[k] | |||||
def get_current_batch_data(self, keyword: str) -> torch.Tensor: | |||||
""" Builds placeholder and retrieves the requirement from loaders and returns it! | |||||
Args: | |||||
keyword (str): The keyword to load for the current batch. | |||||
Returns: | |||||
torch.Tensor: The information related to the keyword for the current batch. | |||||
""" | |||||
if keyword in self._processed_data_dictionary: | |||||
return self._processed_data_dictionary[keyword] | |||||
else: | |||||
placeholders_dict = {keyword: create_empty_placeholder(self._device)} | |||||
self.fill_placeholders(placeholders_dict) | |||||
return placeholders_dict[keyword] | |||||
def get_current_batch_sample_indices(self) -> np.ndarray: | |||||
""" Returns a list of indices of the samples chosen for the current batch. """ | |||||
return self._batch_chooser.get_current_batch_sample_indices() | |||||
def get_max_class_samples_num(self): | |||||
return max([len(x) for x in self.class_samples_indices.values()]) | |||||
def get_classes_num(self): | |||||
return len(self.class_samples_indices) | |||||
def get_samples_keep_mask(self) -> Union[None, np.ndarray]: | |||||
if self.conf.mapped_labels_to_use is None: | |||||
return None | |||||
        labels_to_use = set(self.conf.mapped_labels_to_use)
        keep_mask = np.vectorize(lambda x: x in labels_to_use)(self.samples_labels)
print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.') | |||||
return keep_mask | |||||
def get_current_batch_size(self) -> int: | |||||
return self._batch_chooser.get_current_batch_size() | |||||
def get_current_batch_samples_names(self) -> np.ndarray: | |||||
return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()] | |||||
def get_current_batch_samples_labels(self) -> np.ndarray: | |||||
return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()] | |||||
    def get_current_batch_samples_interpretations(self) -> torch.Tensor:
return self.get_current_batch_data('interpretations') | |||||
@staticmethod | |||||
def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray): | |||||
""" Fills the torch placeholder with the given numpy value, | |||||
If shape mismatches, resizes the placeholder so the data would fit. """ | |||||
# Resize if the shape mismatches | |||||
if list(placeholder.shape) != list(val.shape): | |||||
placeholder.resize_(*tuple(val.shape)) | |||||
# feeding the value | |||||
placeholder.copy_(torch.Tensor(val)) | |||||
def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor: | |||||
""" This function would be called for the variables we are sure they need augmentation | |||||
(they are presented in augmentation dictionary), applies the specified augmentation, | |||||
makes sure all augmentations be the same on the same batch elements, | |||||
and returns the augmented data""" | |||||
def run_single_aug(aug, inp): | |||||
if type(aug) in self._transformations_seeds_dict: | |||||
seed = self._transformations_seeds_dict[type(aug)] | |||||
else: | |||||
# setting a seed | |||||
seed = np.random.randint(2147483647) # make a seed with numpy generator | |||||
self._transformations_seeds_dict[type(aug)] = seed | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
torch.manual_seed(seed) | |||||
if self._different_augmentation_per_batch: | |||||
ret_val = torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0) | |||||
else: | |||||
ret_val = aug(inp) | |||||
return ret_val | |||||
field_transformation = self._augmentations_dict[var_name] | |||||
if not isinstance(field_transformation, torch.nn.Sequential): | |||||
return run_single_aug(field_transformation, var_val) | |||||
else: | |||||
aug_val = var_val | |||||
for single_transform in field_transformation.children(): | |||||
aug_val = run_single_aug( | |||||
single_transform, | |||||
aug_val) | |||||
return aug_val | |||||
def create_empty_placeholder(device: torch.device) -> torch.Tensor: | |||||
""" Create empty placeholder with shape (1) in the given device. | |||||
Args: | |||||
device (torch.device): Device to create the placeholder on. | |||||
Returns: | |||||
torch.Tensor: Empty placeholder. | |||||
""" | |||||
return torch.zeros(1, dtype=torch.float32, | |||||
device=device, | |||||
requires_grad=False) |
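# Illustrative check (standalone): placeholders start with shape (1,) and are
# resized in place by _fill_placeholder to fit whatever batch arrives.
#
#   ph = create_empty_placeholder(torch.device('cpu'))
#   DataLoader._fill_placeholder(ph, np.zeros((4, 3, 224, 224)))
#   ph.shape   # -> torch.Size([4, 3, 224, 224])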
""" | |||||
Runs flow of data, from loading to forwarding the model. | |||||
""" | |||||
from typing import Iterator, Union, TypeVar, Generic, Dict | |||||
import torch | |||||
from ..data.data_loader import DataLoader, create_empty_placeholder | |||||
from ..data.dataloader_context import DataloaderContext | |||||
from ..models.model import Model | |||||
TReturn = TypeVar('TReturn') | |||||
class DataFlow(Generic[TReturn]): | |||||
""" | |||||
Runs flow of data, from loading to forwarding the model. | |||||
""" | |||||
def __init__(self, model: 'Model', dataloader: 'DataLoader', | |||||
device: torch.device, print_debug_info:bool = False): | |||||
self._model = model | |||||
self._device = device | |||||
self._dataloader = dataloader | |||||
self._placeholders: Union[None, Dict[str, torch.Tensor]] = None | |||||
self._print_debug_info = print_debug_info | |||||
def iterate(self) -> Iterator[TReturn]: | |||||
""" | |||||
Iterates on data and forwards the model | |||||
""" | |||||
if self._placeholders is None: | |||||
raise Exception("Please use `with` statement before calling iterate method:\n" + | |||||
"with dataflow:\n" + | |||||
" for model_output in dataflow.iterate():\n" + | |||||
" pass\n") | |||||
DataloaderContext.instance.dataloader = self._dataloader | |||||
while True: | |||||
self._dataloader.prepare_next_batch() | |||||
if self._print_debug_info: | |||||
print('> Next Batch:') | |||||
print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names())) | |||||
# check if the iteration is done | |||||
if self._dataloader.finished_iteration(): | |||||
break | |||||
# Filling the placeholders for running the model | |||||
self._dataloader.fill_placeholders(self._placeholders) | |||||
# Running the model | |||||
yield self._final_stage(self._placeholders) | |||||
    def __enter__(self):
        # self._dataloader.reset()
        self._placeholders = self._build_placeholders()
        return self
def __exit__(self, exc_type, exc_value, traceback): | |||||
for name in self._placeholders: | |||||
self._placeholders[name] = None | |||||
self._placeholders = None | |||||
def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn: | |||||
return self._model(**model_input) | |||||
def _build_placeholders(self) -> Dict[str, torch.Tensor]: | |||||
""" Creates placeholders for feeding the model. | |||||
Returns a dictionary containing placeholders. | |||||
The placeholders are created based on the names | |||||
of the input variables of the model's forward method. | |||||
        The variables initialized as None in the signature are assumed to be
        labels used in the training phase only (currently a placeholder is
        created for every argument)."""
        model_forward_args, _, additional_kwargs = self._model.get_forward_required_kws_kwargs_and_defaults()
placeholders = dict() | |||||
for argname in model_forward_args: | |||||
placeholders[argname] = create_empty_placeholder(self._device) | |||||
for argname in additional_kwargs: | |||||
placeholders[argname] = create_empty_placeholder(self._device) | |||||
return placeholders | |||||
from typing import TYPE_CHECKING | |||||
if TYPE_CHECKING: | |||||
from .data_loader import DataLoader | |||||
class classproperty(property): | |||||
def __get__(self, obj, objtype=None): | |||||
return super(classproperty, self).__get__(objtype) | |||||
def __set__(self, obj, value): | |||||
super(classproperty, self).__set__(type(obj), value) | |||||
def __delete__(self, obj): | |||||
super(classproperty, self).__delete__(type(obj)) | |||||
class DataloaderContext: | |||||
_instance: 'DataloaderContext' = None | |||||
def __init__(self) -> None: | |||||
self.dataloader: 'DataLoader' = None | |||||
@classproperty | |||||
def instance(cls) -> 'DataloaderContext': | |||||
if cls._instance is None: | |||||
cls._instance = DataloaderContext() | |||||
return cls._instance |
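# Illustrative note: DataloaderContext is a lazily created process-wide
# singleton; every access to the class property returns the same object, so
# losses and loaders can reach the active dataloader without it being passed
# around explicitly.
#
#   DataloaderContext.instance is DataloaderContext.instance   # -> True
#   DataloaderContext.instance.dataloader   # set by DataFlow.iterate before use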
from typing import Type | |||||
from .entrypoint import BaseEntrypoint | |||||
from ..configs.base_config import PhaseType | |||||
from importlib import import_module | |||||
import sys | |||||
def get_entry(phase_type: PhaseType): | |||||
EntryPoint: Type[BaseEntrypoint] = import_module(f'torchlap.experiments.{sys.argv[1]}').EntryPoint | |||||
return EntryPoint(phase_type) |
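# Illustrative invocation (the runner script name is hypothetical; the module
# path comes from the import above): the first CLI argument selects the
# experiment module under torchlap.experiments.
#
#   python main.py celeba_resnet_org ...
#   # -> import_module('torchlap.experiments.celeba_resnet_org').EntryPoint(phase_type)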
from os import path | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from ..models.model import Model | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
def load_model(the_model: Model, conf: 'BaseConfig') -> Model: | |||||
dev_name = conf.dev_name | |||||
if conf.final_model_dir is not None: | |||||
load_dir = conf.final_model_dir | |||||
if not path.exists(load_dir): | |||||
raise Exception('No path %s' % load_dir) | |||||
else: | |||||
load_dir = conf.save_dir | |||||
print('>>> loading from: ' + load_dir, flush=True) | |||||
if not path.exists(load_dir): | |||||
raise Exception('Problem in loading the model. %s does not exist!' % load_dir) | |||||
map_location = None | |||||
if dev_name == 'cpu': | |||||
map_location = 'cpu' | |||||
elif dev_name is not None and ':' in dev_name: | |||||
map_location = dev_name | |||||
the_model.to(conf.device) | |||||
    if path.isfile(load_dir):
        the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
print('Loaded the model at %s' % load_dir, flush=True) | |||||
else: | |||||
if conf.epoch is not None: | |||||
epoch = int(conf.epoch) | |||||
elif path.exists(load_dir + '/GeneralInfo'): | |||||
checkpoint = torch.load(load_dir + '/GeneralInfo', map_location=map_location) | |||||
epoch = checkpoint['best_val_epoch'] | |||||
print(f'Epoch has not been specified in the config, ' | |||||
f'using {epoch} which is best_val_epoch instead') | |||||
else: | |||||
raise Exception('Either an epoch or a .pt file path (final_model_dir) must be given, as no GeneralInfo checkpoint was found in %s' % load_dir)
if not path.exists('%s/%d' % (load_dir, epoch)): | |||||
raise Exception('Problem in loading the model. %s/%d does not exist!' % | |||||
(load_dir, epoch)) | |||||
checkpoint = torch.load('%s/%d' % (load_dir, epoch), | |||||
map_location=map_location) | |||||
# backward compatibility | |||||
if 'model_state_dict' in checkpoint: | |||||
checkpoint = checkpoint['model_state_dict'] | |||||
the_model.load_state_dict(checkpoint) | |||||
print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True) | |||||
the_model.to(conf.device) | |||||
return the_model |
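# --- Editor's usage sketch (not part of the original sources) ---
# load_model resolves the checkpoint in three ways: a direct .pt file via
# conf.final_model_dir, an explicit conf.epoch inside conf.save_dir, or the
# best_val_epoch recorded in save_dir/GeneralInfo. A typical call:
def _load_model_example(model: Model, conf: 'BaseConfig') -> Model:
    conf.epoch = '3'  # hypothetical epoch; leave None to fall back to GeneralInfo
    return load_model(model, conf)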
from typing import Tuple | |||||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.celeba.single_tag_org_inception import CelebAORGInception | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||||
conf = CelebAConfigs('CelebA_Inception_Org', 8, 299, self.phase_type) | |||||
conf.max_epochs = 12 | |||||
conf.main_tag = CelebATag.SmilingTag | |||||
conf.tags = [conf.main_tag.name] | |||||
model = CelebAORGInception(conf.main_tag.name, 0.4) | |||||
return conf, model | |||||
from typing import Tuple | |||||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.celeba.org_resnet import CelebAORGResNet18 | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||||
conf = CelebAConfigs('CelebA_ResNet_Org', 9, 224, self.phase_type) | |||||
conf.max_epochs = 12 | |||||
conf.main_tag = CelebATag.SmilingTag | |||||
conf.tags = [conf.main_tag.name] | |||||
model = CelebAORGResNet18(conf.main_tag.name) | |||||
return conf, model | |||||
from typing import Tuple | |||||
from collections import OrderedDict | |||||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.celeba.lap_inception import CelebALAPInception | |||||
from ...modules.lap import LAP | |||||
from ...modules.adaptive_lap import AdaptiveLAP | |||||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel): | |||||
return LAP( | |||||
channel, 2, 2, hidden_channels=[8], n_attention_heads=3, | |||||
sigmoid_scale=0.1, discriminative_attention=True) | |||||
def adaptive_lap_factory(channel): | |||||
return AdaptiveLAP(channel, [], 0.1, | |||||
n_attention_heads=3, discriminative_attention=True) | |||||
class EntryPoint(BaseEntrypoint): | |||||
def __init__(self, phase_type) -> None: | |||||
self.active_min_ratio: float = 0.02 | |||||
self.active_max_ratio: float = 0.4 | |||||
self.inactive_ratio: float = 0.01 | |||||
self.common_max_ratio = 0.02 | |||||
super().__init__(phase_type) | |||||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||||
conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type) | |||||
conf.max_epochs = 12 | |||||
conf.main_tag = CelebATag.SmilingTag | |||||
conf.tags = [conf.main_tag.name] | |||||
model = CelebALAPInception(conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory) | |||||
# aux loss for free head | |||||
fws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.025 / 3, 0, | |||||
(0, self.common_max_ratio), | |||||
{2: [0, 1]}, discr_score_layer, | |||||
w_attention_in_ordering=1, w_discr_in_ordering=0) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('FPool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||||
('FPool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||||
('FPool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||||
('FPoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for fws_loss in fws_losses: | |||||
fws_loss.configure(conf) | |||||
# weakly supervised losses based on discriminative head and attention head | |||||
ws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.025 * 2 / 3, | |||||
self.inactive_ratio, | |||||
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]}, | |||||
discr_score_layer, | |||||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||||
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||||
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||||
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for ws_loss in ws_losses: | |||||
ws_loss.configure(conf) | |||||
# concordance loss for attention | |||||
concordance_loss = PoolConcordanceLossCalculator( | |||||
'AC', model, OrderedDict([ | |||||
('att2', model.maxpool2.attention_layer), | |||||
('att6', model.Mixed_6a.pool.attention_layer), | |||||
('att7', model.Mixed_7a.pool.attention_layer), | |||||
('attG', model.avgpool[0].attention_layer), | |||||
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [0], 1: [1]}) | |||||
concordance_loss.configure(conf) | |||||
# concordance loss for discrimination head | |||||
concordance_loss2 = PoolConcordanceLossCalculator( | |||||
'DC', model, OrderedDict([ | |||||
('D-att2', model.maxpool2.discrimination_layer), | |||||
('D-att6', model.Mixed_6a.pool.discrimination_layer), | |||||
('D-att7', model.Mixed_7a.pool.discrimination_layer), | |||||
('D-attG', model.avgpool[0].discrimination_layer), | |||||
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [0], 1: [1]}) | |||||
concordance_loss2.configure(conf) | |||||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.maxpool2.attention_layer, | |||||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(conf) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1), | |||||
'foretell_pool2', | |||||
'foretell_pool6', | |||||
'foretell_pool7', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(conf) | |||||
return conf, model | |||||
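# --- Editor's worked example (not part of the original sources) ---
# The OutputModifier lambda above turns an attention map into a hard label:
# shift by 0.5, keep the positive evidence, sum it per channel, and pick the
# stronger of the first two channels.
def _foretell_reduction_example() -> None:
    import torch
    x = torch.tensor([[[[0.9, 0.1], [0.6, 0.2]],    # channel 0: two cells above 0.5
                       [[0.7, 0.0], [0.0, 0.0]]]])  # channel 1: one cell above 0.5
    pred = (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1)
    print(pred)  # tensor([0]) -- channel 0 carries more above-threshold mass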
from typing import Tuple | |||||
from collections import OrderedDict | |||||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.celeba.lap_resnet import CelebALAPResNet18 | |||||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
class EntryPoint(BaseEntrypoint): | |||||
def __init__(self, phase_type) -> None: | |||||
self.active_min_ratio: float = 0.02 | |||||
self.active_max_ratio: float = 0.4 | |||||
self.inactive_ratio: float = 0.01 | |||||
self.common_max_ratio = 0.02 | |||||
super().__init__(phase_type) | |||||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||||
conf = CelebAConfigs('CelebA_ResNet_WS', 11, 224, self.phase_type) | |||||
conf.max_epochs = 12 | |||||
conf.main_tag = CelebATag.SmilingTag | |||||
conf.tags = [conf.main_tag.name] | |||||
model = CelebALAPResNet18(conf.main_tag.name, sigmoid_scale=0.1) | |||||
# aux loss for free head | |||||
fws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.025 / 3, 0, | |||||
(0, self.common_max_ratio), | |||||
{2: [0, 1]}, discr_score_layer, | |||||
w_attention_in_ordering=1, w_discr_in_ordering=0) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('Fatt2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||||
('Fatt3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||||
('Fatt4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||||
('Fatt5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for fws_loss in fws_losses: | |||||
fws_loss.configure(conf) | |||||
# weakly supervised losses based on discriminative head and attention head | |||||
ws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.025 * 2 / 3, | |||||
self.inactive_ratio, | |||||
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]}, | |||||
discr_score_layer, | |||||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||||
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||||
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||||
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for ws_loss in ws_losses: | |||||
ws_loss.configure(conf) | |||||
# concordance loss for attention | |||||
concordance_loss = PoolConcordanceLossCalculator( | |||||
'AC', model, OrderedDict([ | |||||
('att2', model.layer2[0].pool.attention_layer), | |||||
('att3', model.layer3[0].pool.attention_layer), | |||||
('att4', model.layer4[0].pool.attention_layer), | |||||
('att5', model.avgpool[0].attention_layer), | |||||
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [0], 1: [1]}) | |||||
concordance_loss.configure(conf) | |||||
# concordance loss for discrimination head | |||||
concordance_loss2 = PoolConcordanceLossCalculator( | |||||
'DC', model, OrderedDict([ | |||||
('att2', model.layer2[0].pool.discrimination_layer), | |||||
('att3', model.layer3[0].pool.discrimination_layer), | |||||
('att4', model.layer4[0].pool.discrimination_layer), | |||||
('att5', model.avgpool[0].discrimination_layer), | |||||
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [0], 1: [1]}) | |||||
concordance_loss2.configure(conf) | |||||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(conf) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1), | |||||
'foretell_pool2', | |||||
'foretell_pool3', | |||||
'foretell_pool4', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(conf) | |||||
return conf, model | |||||
import sys | |||||
from typing import Tuple | |||||
import argparse | |||||
import os | |||||
from abc import ABC, abstractmethod | |||||
from ..models.model import Model | |||||
from ..configs.base_config import BaseConfig, PhaseType | |||||
class BaseEntrypoint(ABC): | |||||
"""Base class for all entrypoints. | |||||
""" | |||||
description = '' | |||||
def __init__(self, phase_type: PhaseType) -> None: | |||||
super().__init__() | |||||
self.phase_type = phase_type | |||||
self.conf, self.model = self._get_conf_model() | |||||
self.parser = self._create_parser(sys.argv[0], sys.argv[1]) | |||||
self.conf.update(self.parser.parse_args(sys.argv[2:])) | |||||
def _create_parser(self, command: str, entrypoint_name: str) -> argparse.ArgumentParser: | |||||
parser = argparse.ArgumentParser( | |||||
prog=f'{os.path.basename(command)} {entrypoint_name}', | |||||
description=self.description or None | |||||
) | |||||
parser.add_argument('--device', default='cpu', type=str,
help='The device to run the analysis on: cpu, cuda:# (a specific gpu), or cuda (all gpus).')
parser.add_argument('--samples-dir', default=None, type=str,
help='The directory or data group to be evaluated. A data group can be train/test/val; a directory must contain 0 and 1 subdirectories. Append :FilterName to restrict evaluation to samples whose path contains /FilterName/.')
parser.add_argument('--final-model-dir', default=None, type=str,
help='The directory to load the model from; calculated automatically when not given.')
parser.add_argument('--save-dir', default=None, type=str,
help='The directory to save the model to; calculated automatically when not given.')
parser.add_argument('--report-dir', default=None, type=str,
help='The directory to save per-slice, per-sample reports in.')
parser.add_argument('--epoch', default=None, type=str, | |||||
help='The epoch to load.') | |||||
parser.add_argument('--try-name', default=None, type=str,
help='A run name describing what this run does.')
parser.add_argument('--try-num', default=None, type=int, | |||||
help='The try number to load') | |||||
parser.add_argument('--data-separation', default=None, type=str, | |||||
help='The data_separation to be used.') | |||||
parser.add_argument('--batch-size', default=None, type=int, | |||||
help='The batch size to be used.') | |||||
parser.add_argument('--pretrained-model-file', default=None, type=str, | |||||
help='Address of .pt pretrained model') | |||||
parser.add_argument('--max-epochs', default=None, type=int, | |||||
help='The maximum epochs of training!') | |||||
parser.add_argument('--big-batch-size', default=None, type=int,
help='The big batch size (iterations per optimization step)!')
parser.add_argument('--iters-per-epoch', default=None, type=int, | |||||
help='The number of big batches per epoch!') | |||||
parser.add_argument('--interpretation-method', default=None, type=str, | |||||
help='The method used for interpreting the results!') | |||||
parser.add_argument('--cut-threshold', default=None, type=float, | |||||
help='The threshold for cutting interpretations!') | |||||
parser.add_argument('--global-threshold', action='store_true',
help='Whether the given cut threshold is applied to the global values rather than the relative ones.')
parser.add_argument('--dynamic-threshold', action='store_true', | |||||
help='Whether to use dynamic threshold in interpretation!') | |||||
parser.add_argument('--class-label-for-interpretation', default=None, type=int,
help="The class label to explain. None means explaining the model's own decision.")
parser.add_argument('--interpret-predictions-vs-gt', default='1', type=(lambda x: x == '1'),
help='Only considered when --class-label-for-interpretation is None. If 1, interpretations are computed for the predicted label; otherwise, for the ground truth.')
parser.add_argument('--mapped-labels-to-use', default=None, type=(lambda x: [int(y) for y in x.split(',')]),
help='Restrict the analysis to these labels only (comma separated); default is all labels.')
parser.add_argument('--skip-overlay', action='store_true', default=None, | |||||
help='Passing this flag prevents the interpretation phase from storing overlay images.') | |||||
parser.add_argument('--skip-raw', action='store_true', default=None, | |||||
help='Passing this flag prevents the interpretation phase from storing the raw interpretation values.') | |||||
parser.add_argument('--overlay-only', action='store_true',
help='Passing this flag makes the interpretation phase store only overlay images '
'and not the `.npy` files.')
parser.add_argument('--save-by-file-name', action='store_true', | |||||
help='Saves sample-specific files by their file names only, not the whole path.') | |||||
parser.add_argument('--n-interpretation-samples', default=None, type=int, | |||||
help='The max number of samples to be interpreted.') | |||||
parser.add_argument('--interpretation-tag-to-evaluate', default=None, type=str,
help='The tag whose interpretation is used in the evaluation phase; by default, the first tag is used.')
return parser | |||||
@abstractmethod | |||||
def _get_conf_model(self) -> Tuple[BaseConfig, Model]: | |||||
pass |
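# --- Editor's sketch (not part of the original sources) ---
# Every experiment file above follows the same contract: subclass
# BaseEntrypoint and return a (config, model) pair from _get_conf_model.
# MyConfigs and MyModel below are placeholder names.
class _ExampleEntryPoint(BaseEntrypoint):
    description = 'Minimal illustrative entrypoint.'

    def _get_conf_model(self) -> Tuple[BaseConfig, Model]:
        conf = ...   # e.g. MyConfigs('MyExperiment', 1, 224, self.phase_type)
        model = ...  # e.g. MyModel()
        return conf, model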
from typing import Tuple | |||||
from torchlap.configs.imagenet_configs import ImagenetConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]: | |||||
config = ImagenetConfigs('ImagenetFT', 2, 224, self.phase_type) | |||||
model = ImagenetLAPResNet50(sigmoid_scale=0.1) | |||||
config.freezing_regexes = [ | |||||
r'(?!^layer4\..*$)(?!^fc\..*$)(^.*$)' | |||||
] | |||||
return config, model |
from typing import Tuple | |||||
from torchlap.configs.imagenet_configs import ImagenetConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]: | |||||
config = ImagenetConfigs('ImagenetNFT', 1, 224, self.phase_type) | |||||
model = ImagenetLAPResNet50(sigmoid_scale=0.1) | |||||
config.freezing_regexes = [ | |||||
r'(?!^.*\.pool\..*$)(^.*$)' | |||||
] | |||||
return config, model |
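# --- Editor's worked example (not part of the original sources) ---
# The negative-lookahead regex above matches every parameter name except
# those containing '.pool.', so only the LAP pooling modules stay trainable.
def _freezing_regex_example() -> None:
    import re
    pattern = r'(?!^.*\.pool\..*$)(^.*$)'
    names = ['layer2.0.pool.attention_layer.0.weight', 'layer1.0.conv1.weight']
    frozen = [n for n in names if re.match(pattern, n)]
    print(frozen)  # ['layer1.0.conv1.weight'] -- pool parameters are left unfrozen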
from collections import OrderedDict | |||||
from typing import Tuple | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.lap_inception import RSNALAPInception | |||||
from ...modules.lap import LAP | |||||
from ...modules.adaptive_lap import AdaptiveLAP | |||||
from ...criteria.bb_supervised import BBLoss | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel): | |||||
return LAP( | |||||
channel, 2, 2, hidden_channels=[8], | |||||
sigmoid_scale=0.1, discriminative_attention=False) | |||||
def adaptive_lap_factory(channel): | |||||
return AdaptiveLAP(channel, [], 0.1, discriminative_attention=False) | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type) | |||||
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory) | |||||
bb_losses = BBLoss( | |||||
'BB', model, { | |||||
'Pool2': model.maxpool2.attention_layer, | |||||
'Pool6': model.Mixed_6a.pool.attention_layer, | |||||
'Pool7': model.Mixed_7a.pool.attention_layer, | |||||
'PoolG': model.avgpool[0].attention_layer, | |||||
} | |||||
, 1, | |||||
'interpretations', 1, 0.5) | |||||
bb_losses.configure(conf) | |||||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.maxpool2.attention_layer, | |||||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(conf) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: x.flatten(1).max(dim=-1).values, | |||||
'foretell_pool2', | |||||
'foretell_pool6', | |||||
'foretell_pool7', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(conf) | |||||
return conf, model | |||||
from collections import OrderedDict | |||||
from typing import Tuple | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...criteria.bb_supervised import BBLoss | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
config = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type) | |||||
model = RSNALAPResNet18( | |||||
pool_factory=get_pool_factory(discriminative_attention=False), | |||||
adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False), | |||||
sigmoid_scale=0.1, | |||||
) | |||||
bb_losses = BBLoss( | |||||
'BB', model, { | |||||
'att2': model.layer2[0].pool.attention_layer, | |||||
'att3': model.layer3[0].pool.attention_layer, | |||||
'att4': model.layer4[0].pool.attention_layer, | |||||
'att5': model.avgpool[0].attention_layer, | |||||
} | |||||
, 1, | |||||
'interpretations', 1, 0.5) | |||||
bb_losses.configure(config) | |||||
config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(config) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: x.flatten(1).max(dim=-1).values, | |||||
'foretell_pool2', | |||||
'foretell_pool3', | |||||
'foretell_pool4', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(config) | |||||
return config, model |
from typing import Tuple | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.org_inception import RSNAORGInception | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
config = RSNAConfigs('RSNA_Inception_Org', 2, 299, self.phase_type) | |||||
model = RSNAORGInception(0.4) | |||||
return config, model | |||||
from typing import Tuple | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.org_resnet import RSNAORGResNet18 | |||||
class EntryPoint(BaseEntrypoint): | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
config = RSNAConfigs('RSNA_ResNet_Org', 1, 224, self.phase_type) | |||||
model = RSNAORGResNet18() | |||||
return config, model |
from typing import Tuple | |||||
from collections import OrderedDict | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.lap_inception import RSNALAPInception | |||||
from ...modules.lap import LAP | |||||
from ...modules.adaptive_lap import AdaptiveLAP | |||||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel): | |||||
return LAP( | |||||
channel, 2, 2, hidden_channels=[8], | |||||
sigmoid_scale=0.1, discriminative_attention=True) | |||||
def adaptive_lap_factory(channel): | |||||
return AdaptiveLAP(channel, sigmoid_scale=0.1, discriminative_attention=True) | |||||
class EntryPoint(BaseEntrypoint): | |||||
def __init__(self, phase_type) -> None: | |||||
self.active_min_ratio: float = 0.1 | |||||
self.active_max_ratio: float = 0.5 | |||||
self.inactive_ratio: float = 0.1 | |||||
super().__init__(phase_type) | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
conf = RSNAConfigs('RSNA_Inception_WS', 4, 256, self.phase_type) | |||||
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory) | |||||
# weakly supervised losses based on discriminative head and attention head | |||||
ws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.25, | |||||
self.inactive_ratio, | |||||
(self.active_min_ratio, self.active_max_ratio), {0: [1]}, | |||||
discr_score_layer, | |||||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||||
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||||
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||||
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for ws_loss in ws_losses: | |||||
ws_loss.configure(conf) | |||||
# concordance loss for attention | |||||
concordance_loss = PoolConcordanceLossCalculator( | |||||
'AC', model, OrderedDict([ | |||||
('att2', model.maxpool2.attention_layer), | |||||
('att6', model.Mixed_6a.pool.attention_layer), | |||||
('att7', model.Mixed_7a.pool.attention_layer), | |||||
('attG', model.avgpool[0].attention_layer), | |||||
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1, | |||||
labels_by_channel={0: [1]}) | |||||
concordance_loss.configure(conf) | |||||
# concordance loss for discrimination head | |||||
concordance_loss2 = PoolConcordanceLossCalculator( | |||||
'DC', model, OrderedDict([ | |||||
('D-att2', model.maxpool2.discrimination_layer), | |||||
('D-att6', model.Mixed_6a.pool.discrimination_layer), | |||||
('D-att7', model.Mixed_7a.pool.discrimination_layer), | |||||
('D-attG', model.avgpool[0].discrimination_layer), | |||||
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [1]}) | |||||
concordance_loss2.configure(conf) | |||||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.maxpool2.attention_layer, | |||||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(conf) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: x.flatten(1).max(dim=-1).values, | |||||
'foretell_pool2', | |||||
'foretell_pool6', | |||||
'foretell_pool7', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(conf) | |||||
return conf, model | |||||
from collections import OrderedDict | |||||
from typing import Tuple | |||||
from ...utils.output_modifier import OutputModifier | |||||
from ...utils.aux_output import AuxOutput | |||||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
from ...configs.rsna_configs import RSNAConfigs | |||||
from ...models.model import Model | |||||
from ..entrypoint import BaseEntrypoint | |||||
from ...models.rsna.lap_resnet import RSNALAPResNet18 | |||||
class EntryPoint(BaseEntrypoint): | |||||
def __init__(self, phase_type) -> None: | |||||
self.active_min_ratio: float = 0.1 | |||||
self.active_max_ratio: float = 0.5 | |||||
self.inactive_ratio: float = 0.1 | |||||
super().__init__(phase_type) | |||||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||||
config = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type) | |||||
model = RSNALAPResNet18() | |||||
# weakly supervised losses based on discriminative head and attention head | |||||
ws_losses = [ | |||||
DiscriminativeWeaklySupervisedLoss( | |||||
title, model, att_score_layer, 0.25, | |||||
self.inactive_ratio, | |||||
(self.active_min_ratio, self.active_max_ratio), {0: [1]}, | |||||
discr_score_layer, | |||||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||||
for title, att_score_layer, discr_score_layer in [ | |||||
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||||
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||||
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||||
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||||
] | |||||
] | |||||
for ws_loss in ws_losses: | |||||
ws_loss.configure(config) | |||||
# concordance loss for attention | |||||
concordance_loss = PoolConcordanceLossCalculator( | |||||
'AC', model, OrderedDict([ | |||||
('att2', model.layer2[0].pool.attention_layer), | |||||
('att3', model.layer3[0].pool.attention_layer), | |||||
('att4', model.layer4[0].pool.attention_layer), | |||||
('att5', model.avgpool[0].attention_layer), | |||||
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1, | |||||
labels_by_channel={0: [1]}) | |||||
concordance_loss.configure(config) | |||||
# concordance loss for discrimination head | |||||
concordance_loss2 = PoolConcordanceLossCalculator( | |||||
'DC', model, OrderedDict([ | |||||
('att2', model.layer2[0].pool.discrimination_layer), | |||||
('att3', model.layer3[0].pool.discrimination_layer), | |||||
('att4', model.layer4[0].pool.discrimination_layer), | |||||
('att5', model.avgpool[0].discrimination_layer), | |||||
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0, | |||||
labels_by_channel={0: [1]}) | |||||
concordance_loss2.configure(config) | |||||
config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||||
'b': BinaryEvaluator, | |||||
'l': LossEvaluator, | |||||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||||
})) | |||||
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||||
################################### | |||||
########### Foreteller ############ | |||||
################################### | |||||
aux = AuxOutput(model, dict( | |||||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||||
foretell_avgpool=model.avgpool[0].attention_layer, | |||||
)) | |||||
aux.configure(config) | |||||
output_modifier = OutputModifier(model, | |||||
lambda x: x.flatten(1).max(dim=-1).values, | |||||
'foretell_pool2', | |||||
'foretell_pool3', | |||||
'foretell_pool4', | |||||
'foretell_avgpool', | |||||
) | |||||
output_modifier.configure(config) | |||||
return config, model |
""" | |||||
Interpretation Modules | |||||
""" | |||||
from .interpretable import InterpretableModel, CamInterpretableModel | |||||
from .interpreter import Interpreter | |||||
from .interpretable_wrapper import InterpretableWrapper | |||||
from .interpreter_maker import create_interpreter, InterpretationType |
from abc import ABC, abstractmethod | |||||
from collections import OrderedDict as OrdDict | |||||
from typing import Dict, List, Tuple, Union | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from .interpreter import Interpreter | |||||
from .interpretable import InterpretableModel | |||||
from ..modules import LAP | |||||
from ..utils.hooker import Hooker, Hook | |||||
AttentionDict = Dict[torch.Size, torch.Tensor] | |||||
class AttentionInterpretableModel(InterpretableModel, ABC): | |||||
@property | |||||
@abstractmethod | |||||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||||
""" | |||||
List of attention groups | |||||
""" | |||||
class AttentionInterpreter(Interpreter): | |||||
def __init__(self, model: AttentionInterpretableModel): | |||||
assert isinstance(model, AttentionInterpretableModel),\ | |||||
f"Expected AttentionInterpretableModel, got {type(model)}" | |||||
super().__init__(model) | |||||
layer_hooks = sum([ | |||||
[ | |||||
(layer.attention_layer, self._generate_forward_hook(group_name, layer)) | |||||
for layer in group_layers | |||||
] for group_name, group_layers in model.attention_layers.items() | |||||
], []) | |||||
self._hooker = Hooker(*layer_hooks) | |||||
self._attention_groups: Dict[str, AttentionDict] = { | |||||
group_name: {} for group_name in model.attention_layers} | |||||
self._impact = 1. | |||||
self._impact_decay = .8 | |||||
self._input_size: torch.Size = None | |||||
self._device = None | |||||
def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook: | |||||
def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None: | |||||
processed = out.clone() | |||||
shape = processed.shape[2:] | |||||
if shape not in self._attention_groups[group_name]: | |||||
self._attention_groups[group_name][shape] = torch.zeros_like( | |||||
processed) | |||||
self._attention_groups[group_name][shape] += processed | |||||
return forward_hook | |||||
def _reset(self) -> None: | |||||
self._attention_groups = { | |||||
group_name: {} for group_name in self._model.attention_layers} | |||||
self._impact = 1. | |||||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||||
""" Process current attention level. | |||||
Args: | |||||
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1) | |||||
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2) | |||||
Returns: | |||||
torch.Tensor: New attention map. (B, 1, H2, W2) | |||||
""" | |||||
self._impact *= self._impact_decay | |||||
# smallest attention layer | |||||
if result is None: | |||||
return attention | |||||
# attention is larger than result | |||||
scale = torch.ceil( | |||||
torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:]) | |||||
).int().tolist() | |||||
# find maximum of attention in each kernel | |||||
max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale) | |||||
max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest') | |||||
is_any_active = max_map > 0 | |||||
# interpolate result to attention size | |||||
result = F.interpolate(result, attention.shape[2:], mode='nearest') | |||||
# # maximum between result and attention | |||||
# impacted = torch.max(result, max_map * self._impact) | |||||
impacted = torch.max(result, attention * self._impact) | |||||
# when attention is zero, we use zero | |||||
impacted[attention == 0] = 0 | |||||
# impacted = torch.where(attention > 0, impacted, 0) | |||||
# when result is zero, we use zero as well
impacted[result == 0] = 0 | |||||
# impacted = torch.where(result > 0, impacted, 0) | |||||
# if max_map is zero, we use result, else we use impacted | |||||
result[is_any_active] = impacted[is_any_active] | |||||
# result = torch.where(is_any_active, impacted, result) | |||||
# mask = (max_map > 0) & (result > 0) | |||||
# result[mask] = impacted[mask] | |||||
# result[mask & (attention == 0)] = 0 | |||||
return result | |||||
def _process_attentions(self) -> Dict[str, torch.Tensor]: | |||||
sorted_attentions = { | |||||
group_name: OrdDict(sorted(group_attention.items(), | |||||
key=lambda pair: pair[0])) | |||||
for group_name, group_attention in self._attention_groups.items() | |||||
} | |||||
results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions} | |||||
for group_name, group_attention in sorted_attentions.items(): | |||||
self._impact = 1.0 | |||||
# from smallest to largest | |||||
for attention in group_attention.values(): | |||||
attention = (attention - 0.5).relu() | |||||
results[group_name] = self._process_attention(results[group_name], attention) | |||||
interpretations = {} | |||||
for group_name, group_result in results.items(): | |||||
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True) | |||||
interpretations[group_name] = group_result | |||||
# Alternative (disabled): split multi-channel results into per-channel
# entries named f"{group_name}_{i}".
return interpretations | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
self._reset() | |||||
self._batch_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||||
self._input_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:] | |||||
self._device = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].device | |||||
with self._hooker: | |||||
self._model(**inputs) | |||||
return self._process_attentions() |
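# --- Editor's worked example (not part of the original sources) ---
# _process_attention fuses attention maps coarse-to-fine: the coarse result is
# upsampled with nearest neighbours, combined with the impact-decayed finer
# map, and zeroed wherever either level is inactive. A simplified standalone
# sketch of that fusion step (omitting the max_map bookkeeping):
def _cascade_fusion_example() -> None:
    import torch
    import torch.nn.functional as F
    coarse = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])   # 2x2, one active cell
    fine = torch.tensor([[[[0.8, 0.6, 0.0, 0.0],
                           [0.2, 0.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0, 0.9]]]])       # 4x4
    up = F.interpolate(coarse, fine.shape[2:], mode='nearest')
    fused = torch.max(up, fine * 0.8)                     # 0.8 is the impact decay
    fused[fine == 0] = 0                                  # inactive at the fine level
    fused[up == 0] = 0                                    # inactive at the coarse level
    print(fused)  # only fine activations under the active coarse cell survive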
from typing import Dict | |||||
from collections import OrderedDict as OrdDict | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from .attention_interpreter import AttentionInterpreter | |||||
AttentionDict = Dict[torch.Size, torch.Tensor] | |||||
class AttentionInterpreterSmoothIntegrator(AttentionInterpreter): | |||||
def _process_attentions(self) -> Dict[str, torch.Tensor]: | |||||
sorted_attentions = { | |||||
group_name: OrdDict(sorted(group_attention.items(), | |||||
key=lambda pair: pair[0])) | |||||
for group_name, group_attention in self._attention_groups.items() | |||||
} | |||||
results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions} | |||||
for group_name, group_attention in sorted_attentions.items(): | |||||
self._impact = 1.0 | |||||
# from smallest to largest | |||||
sum_ = 0 | |||||
for attention in group_attention.values(): | |||||
if torch.is_tensor(sum_): | |||||
sum_ = F.interpolate(sum_, size=attention.shape[-2:]) | |||||
sum_ += attention | |||||
attention = (attention - 0.5).relu() | |||||
results[group_name] = self._process_attention(results[group_name], attention) | |||||
results[group_name] = torch.where(results[group_name] > 0, sum_, sum_ / (2 * len(group_attention))) | |||||
interpretations = {} | |||||
for group_name, group_result in results.items(): | |||||
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True) | |||||
interpretations[group_name] = group_result | |||||
# Alternative (disabled): per-channel splitting, as noted in
# AttentionInterpreter above.
return interpretations |
import torch | |||||
import torch.nn.functional as F | |||||
from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel | |||||
class AttentionSumInterpreter(AttentionInterpreter): | |||||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||||
""" Process current attention level. | |||||
Args: | |||||
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1) | |||||
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2) | |||||
Returns: | |||||
torch.Tensor: New attention map. (B, 1, H2, W2) | |||||
""" | |||||
self._impact *= self._impact_decay | |||||
# smallest attention layer | |||||
if result is None: | |||||
return attention | |||||
# attention is larger than result | |||||
result = F.interpolate(result, attention.shape[2:], mode='bilinear', align_corners=True) | |||||
return result + attention |
from typing import List, TYPE_CHECKING, Tuple | |||||
import numpy as np | |||||
from skimage import measure | |||||
from .utils import process_interpretations | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
from . import Interpreter | |||||
class BinaryInterpretationEvaluator2D: | |||||
def __init__(self, n_samples: int, config: 'BaseConfig'): | |||||
self._config = config | |||||
self._n_samples = n_samples | |||||
self._min_intersection_threshold = config.acceptable_min_intersection_threshold | |||||
self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
def reset(self): | |||||
self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||||
def update_summaries( | |||||
self, | |||||
m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray, | |||||
net_preds: np.ndarray, ground_truth_labels: np.ndarray, | |||||
batch_inds: np.ndarray, interpreter: 'Interpreter' | |||||
) -> None: | |||||
assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}' | |||||
assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}' | |||||
# skipping samples without interpretations! | |||||
has_interpretation_mask = np.logical_not(np.any(np.isnan(ground_truth_interpretations), axis=(1, 2))) | |||||
m_interpretations = m_interpretations[has_interpretation_mask] | |||||
ground_truths = ground_truth_interpretations[has_interpretation_mask] | |||||
net_preds = net_preds[has_interpretation_mask]
# finding class labels
if len(net_preds.shape) == 1:
net_preds = (net_preds >= 0.5).astype(int)
elif net_preds.shape[1] == 1:
net_preds = (net_preds[:, 0] >= 0.5).astype(int)
else:
net_preds = net_preds.argmax(axis=-1)
ground_truth_labels = ground_truth_labels[has_interpretation_mask]
batch_inds = batch_inds[has_interpretation_mask]
# calculating soundness | |||||
soundnesses = (net_preds == ground_truth_labels).astype(int) | |||||
c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations)) | |||||
b_interpretations = np.stack(tuple([ | |||||
process_interpretations(m_interpretations[ind][None, ...], self._config, interpreter) | |||||
for ind in range(len(m_interpretations)) | |||||
]), axis=0)[:, 0, ...] | |||||
b_interpretations = (b_interpretations > 0).astype(bool) | |||||
ground_truths = (ground_truths >= 0.5).astype(bool)  # making sure values are 0/1 even if a resize has been applied
assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}' | |||||
norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2] | |||||
np.add.at(self._normed_intersection_per_soundness, soundnesses, | |||||
np.sum(b_interpretations & ground_truths, axis=(1, 2)) * 1.0 / norm_factor) | |||||
np.add.at(self._normed_union_per_soundness, soundnesses, | |||||
np.sum(b_interpretations | ground_truths, axis=(1, 2)) * 1.0 / norm_factor) | |||||
for i in range(len(b_interpretations)): | |||||
s = soundnesses[i] | |||||
org_labels = measure.label(ground_truths[i, :, :]) | |||||
check_labels = measure.label(b_interpretations[i, :, :]) | |||||
# keep the top-k interpretation pixels, with k equal to the number of
# active ground-truth pixels, by thresholding at the matching quantile
n_on_gt = np.sum(ground_truths[i])
q = (1 + n_on_gt) * 1.0 / (ground_truths.shape[-1] * ground_truths.shape[-2])
# 1 is added because the thresholding below uses > rather than >=
if q < 1: | |||||
tints = c_interpretations[i] | |||||
th = max(0, np.quantile(tints.reshape(-1), 1 - q)) | |||||
tints = (tints > th) | |||||
else: | |||||
tints = (c_interpretations[i] > 0) | |||||
# TOPK METRICS | |||||
tk_intersection = np.sum(tints & ground_truths[i]) | |||||
tk_union = np.sum(tints | ground_truths[i]) | |||||
self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor | |||||
self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor | |||||
@staticmethod | |||||
def get_titles_of_evaluation_metrics() -> List[str]: | |||||
return ['S-IOU', 'S-TK-IOU', | |||||
'M-IOU', 'M-TK-IOU', | |||||
'A-IOU', 'A-TK-IOU'] | |||||
@staticmethod | |||||
def _get_eval_metrics(normed_intersection, normed_union, title,
tk_normed_intersection, tk_normed_union) -> Tuple[str, str]:
iou = (1e-6 + normed_intersection) / (1e-6 + normed_union) | |||||
tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union) | |||||
return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,) | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
return \ | |||||
list(self._get_eval_metrics( | |||||
self._normed_intersection_per_soundness[1], | |||||
self._normed_union_per_soundness[1], | |||||
'Sounds', | |||||
self._tk_normed_intersection_per_soundness[1], | |||||
self._tk_normed_union_per_soundness[1], | |||||
)) + \ | |||||
list(self._get_eval_metrics( | |||||
self._normed_intersection_per_soundness[0], | |||||
self._normed_union_per_soundness[0], | |||||
'Mistakes', | |||||
self._tk_normed_intersection_per_soundness[0], | |||||
self._tk_normed_union_per_soundness[0], | |||||
)) + \ | |||||
list(self._get_eval_metrics( | |||||
sum(self._normed_intersection_per_soundness), | |||||
sum(self._normed_union_per_soundness), | |||||
'All', | |||||
sum(self._tk_normed_intersection_per_soundness), | |||||
sum(self._tk_normed_union_per_soundness), | |||||
)) | |||||
def print_summaries(self): | |||||
titles = self.get_titles_of_evaluation_metrics() | |||||
vals = self.get_values_of_evaluation_metrics() | |||||
nc = 2 | |||||
for r in range(3): | |||||
print(', '.join(['%s: %s' % (titles[nc * r + i], vals[nc * r + i]) for i in range(nc)])) |
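# --- Editor's worked example (not part of the original sources) ---
# The top-k metrics above binarize an interpretation by keeping roughly as
# many pixels as the ground truth marks, via a quantile threshold:
def _topk_threshold_example() -> None:
    import numpy as np
    interp = np.array([[0.9, 0.8], [0.4, 0.1]])
    n_on_gt = 1                                  # the ground truth marks one pixel
    q = (1 + n_on_gt) / interp.size              # = 0.5
    th = max(0, np.quantile(interp.reshape(-1), 1 - q))
    print(interp > th)  # [[ True  True] [False False]] -- top values kept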
from typing import Dict, Union, Tuple | |||||
import torch | |||||
from . import Interpreter | |||||
from ..data.dataflow import DataFlow | |||||
from ..models.model import ModelIO | |||||
from ..data.data_loader import DataLoader | |||||
class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]): | |||||
def __init__(self, interpreter: Interpreter, | |||||
dataloader: DataLoader, | |||||
phase: str, | |||||
device: torch.device, | |||||
dtype: torch.dtype, | |||||
print_debug_info: bool, | |||||
label_for_interpreter: Union[int, None], | |||||
give_gt_as_label: bool): | |||||
super().__init__(interpreter._model, dataloader, phase, device, | |||||
dtype, print_debug_info=print_debug_info) | |||||
self._interpreter = interpreter | |||||
self._label_for_interpreter = label_for_interpreter | |||||
self._give_gt_as_label = give_gt_as_label | |||||
self._sample_labels = dataloader.get_samples_labels() | |||||
@staticmethod | |||||
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
return { | |||||
name: o.detach() for name, o in output.items() | |||||
} | |||||
def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]: | |||||
if self._give_gt_as_label: | |||||
return model_input, self._detach_output(self._interpreter.interpret( | |||||
torch.Tensor(self._sample_labels[ | |||||
self._dataloader.get_current_batch_sample_indices()]), | |||||
**model_input)) | |||||
return model_input, self._detach_output(self._interpreter.interpret( | |||||
self._label_for_interpreter, **model_input | |||||
)) |
""" | |||||
DeepLift Interpretation
""" | |||||
from typing import Dict, Union | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from captum import attr | |||||
from . import Interpreter, InterpretableModel, InterpretableWrapper | |||||
class DeepLift(Interpreter): # pylint: disable=too-few-public-methods | |||||
""" | |||||
Produces DeepLift attributions
""" | |||||
def __init__(self, model: InterpretableModel): | |||||
super().__init__(model) | |||||
self._model_wrapper = InterpretableWrapper(model) | |||||
self._interpreter = attr.DeepLift(self._model_wrapper) | |||||
self.__B = None | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
''' | |||||
interprets given input and target class | |||||
''' | |||||
B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||||
if self.__B is None: | |||||
self.__B = B | |||||
if B < self.__B: | |||||
padding = self.__B - B | |||||
inputs = { | |||||
k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding)) | |||||
for k, v in inputs.items() | |||||
if torch.is_tensor(v) and v.shape[0] == B | |||||
} | |||||
if torch.is_tensor(labels) and labels.shape[0] == B: | |||||
labels = F.pad(labels, (0, padding)) | |||||
labels = self._get_target(labels, **inputs) | |||||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||||
result = self._interpreter.attribute( | |||||
separated_inputs.inputs, | |||||
target=labels, | |||||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||||
result = result[:B] | |||||
return { | |||||
'default': result | |||||
} # TODO: must return and support tuple |
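# --- Editor's sketch (not part of the original sources) ---
# DeepLift above pads a short trailing batch back to the first batch size B,
# presumably so the wrapped attribution always sees a constant batch shape
# (the motivation is an assumption). The pad amount lands on the batch
# dimension because F.pad lists dimensions from last to first:
def _batch_pad_example() -> None:
    import torch
    import torch.nn.functional as F
    v = torch.zeros(3, 2, 4, 4)                  # a trailing batch of 3 samples
    padding = 5 - v.shape[0]                     # grow it back to B = 5
    padded = F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding))
    print(padded.shape)  # torch.Size([5, 2, 4, 4])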
from typing import Dict, Union | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from captum import attr | |||||
from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||||
class GradCam(Interpreter): | |||||
def __init__(self, model: CamInterpretableModel): | |||||
super().__init__(model) | |||||
self.__model_wrapper = InterpretableWrapper(model) | |||||
self.__interpreters = [ | |||||
attr.LayerGradCam(self.__model_wrapper, conv) | |||||
for conv in model.target_conv_layers | |||||
] | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
inp_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:] | |||||
labels = self._get_target(labels, **inputs) | |||||
separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs) | |||||
gradcams = [interpreter.attribute( | |||||
separated_inputs.inputs, | |||||
target=labels, | |||||
additional_forward_args=separated_inputs.additional_inputs) | |||||
for interpreter in self.__interpreters] | |||||
gradcams = [ | |||||
cam if torch.is_tensor(cam) else cam[0] | |||||
for cam in gradcams | |||||
] | |||||
gradcams = torch.stack(gradcams).sum(dim=0) | |||||
gradcams = F.interpolate(gradcams, size=inp_shape, mode='bilinear', align_corners=True) | |||||
return { | |||||
'default': gradcams, | |||||
} # TODO: must return and support tuple |
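# Illustrative sketch of the two post-processing steps interpret() applies,
# on dummy tensors shaped like Captum's LayerGradCam outputs (B, 1, h, w).
# Note that torch.stack requires all target conv layers to yield CAMs of the
# same spatial size; they are summed before being upsampled to the input size.
#   cams = [torch.rand(2, 1, 7, 7), torch.rand(2, 1, 7, 7)]   # one per target layer
#   combined = torch.stack(cams).sum(dim=0)                   # (2, 1, 7, 7)
#   full = F.interpolate(combined, size=(224, 224), mode='bilinear', align_corners=True)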
""" | |||||
Guided Backprop Interpretation | |||||
""" | |||||
from typing import Dict, Union | |||||
import torch | |||||
from captum import attr | |||||
from . import Interpreter, InterpretableModel, InterpretableWrapper | |||||
class GuidedBackprop(Interpreter): # pylint: disable=too-few-public-methods | |||||
""" | |||||
Produces guided backpropagation saliency maps
""" | |||||
def __init__(self, model: InterpretableModel): | |||||
super().__init__(model) | |||||
self._model_wrapper = InterpretableWrapper(model) | |||||
self._interpreter = attr.GuidedBackprop(self._model_wrapper) | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
        '''
        Interprets the given inputs with respect to the target class
        '''
labels = self._get_target(labels, **inputs) | |||||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||||
return { | |||||
'default': self._interpreter.attribute( | |||||
separated_inputs.inputs, | |||||
target=labels, | |||||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||||
} # TODO: must return and support tuple |
""" | |||||
Guided Grad Cam Interpretation | |||||
""" | |||||
from typing import Dict, Union | |||||
import torch | |||||
from captum import attr | |||||
from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||||
class GuidedGradCam(Interpreter): | |||||
""" Produces class activation map """ | |||||
def __init__(self, model: CamInterpretableModel): | |||||
super().__init__(model) | |||||
self._model_wrapper = InterpretableWrapper(model) | |||||
self._interpreters = [ | |||||
attr.GuidedGradCam(self._model_wrapper, conv) | |||||
for conv in model.target_conv_layers | |||||
] | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
""" Interprets given input and target class | |||||
Args: | |||||
y (Union[int, torch.Tensor]): target class | |||||
**inputs (torch.Tensor): model inputs | |||||
Returns: | |||||
Dict[str, torch.Tensor]: Interpretation results | |||||
""" | |||||
labels = self._get_target(labels, **inputs) | |||||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||||
gradcams = [interpreter.attribute( | |||||
separated_inputs.inputs, | |||||
target=labels, | |||||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||||
for interpreter in self._interpreters] | |||||
gradcams = torch.stack(gradcams).sum(dim=0) | |||||
return { | |||||
'default': gradcams, | |||||
} # TODO: must return and support tuple |
from typing import Callable, Optional, Type
import torch | |||||
from ..models.model import ModelIO | |||||
from ..utils.hooker import Hook, Hooker | |||||
from ..data.data_loader import DataLoader | |||||
from ..data.dataloader_context import DataloaderContext | |||||
from .interpreter import Interpreter | |||||
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter | |||||
class ImagenetPredictionInterpreter(Interpreter): | |||||
def __init__(self, model: AttentionInterpretableModel, k: int, base_interpreter: Type[AttentionInterpreter]): | |||||
super().__init__(model) | |||||
self.__base_interpreter = base_interpreter(model) | |||||
self.__k = k | |||||
        self.__topk_classes: Optional[torch.Tensor] = None
self.__base_interpreter._hooker = Hooker( | |||||
(model, self._generate_prediction_hook()), | |||||
*[(attention_layer[-2], hook) | |||||
for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs]) | |||||
@staticmethod | |||||
def standard_factory(k: int, base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']: | |||||
return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter) | |||||
def _generate_prediction_hook(self) -> Hook: | |||||
        def hook(_, __, output: ModelIO):
            # Log each sample's top-k predicted classes and cache their indices
            # for _process_attention below.
            topk = output['categorical_probability'].detach()\
.topk(self.__k, dim=-1) | |||||
dataloader: DataLoader = DataloaderContext.instance.dataloader | |||||
sample_names = dataloader.get_current_batch_samples_names() | |||||
for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values): | |||||
print(f'Top classes of ' | |||||
f'{name}: ' | |||||
f'{top_class.detach().cpu().numpy().tolist()} - ' | |||||
f'{top_prob.cpu().numpy().tolist()}', flush=True) | |||||
self.__topk_classes = topk\ | |||||
.indices\ | |||||
.flatten()\ | |||||
.cpu() | |||||
return hook | |||||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||||
batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k) | |||||
attention = attention[batch_indices, self.__topk_classes] | |||||
attention = attention.reshape(-1, self.__k, *attention.shape[-2:]) | |||||
return self.__base_interpreter._process_attention(result, attention) | |||||
def interpret(self, labels, **inputs): | |||||
return self.__base_interpreter.interpret(labels, **inputs) |
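# Hedged usage sketch: standard_factory builds an InterpreterMaker-style
# callable, so the class can be plugged wherever a model-to-interpreter
# factory is expected; `model` is a hypothetical AttentionInterpretableModel.
#   maker = ImagenetPredictionInterpreter.standard_factory(5)
#   interpreter = maker(model)   # logs and interprets the top-5 classes per sample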
""" | |||||
An Interpretable Pytorch Model | |||||
""" | |||||
from abc import abstractmethod, ABC | |||||
from typing import Iterable, List | |||||
import torch | |||||
from torch import nn | |||||
from ..models.model import Model | |||||
class InterpretableModel(Model, ABC): | |||||
""" | |||||
An Interpretable Pytorch Model | |||||
""" | |||||
@property | |||||
@abstractmethod | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
""" | |||||
Returns: | |||||
Input module for interpretation | |||||
""" | |||||
@abstractmethod | |||||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||||
""" | |||||
A method to get probabilities assigned to all the classes in the model's forward, | |||||
with shape (B, C), in which B is the batch size and C is number of classes | |||||
Args: | |||||
*inputs: Inputs to the model | |||||
**kwargs: Additional arguments | |||||
Returns: | |||||
Tensor of categorical probabilities | |||||
""" | |||||
class CamInterpretableModel(InterpretableModel, ABC): | |||||
""" | |||||
A model interpretable by gradcam | |||||
""" | |||||
@property | |||||
@abstractmethod | |||||
def target_conv_layers(self) -> List[nn.Module]: | |||||
""" | |||||
Returns: | |||||
The convolutional layers to be interpreted. The result of | |||||
the interpretation will be sum of the grad-cam of these layers | |||||
""" |
""" | |||||
An adapter wrapper to adapt our models to Captum interpreters | |||||
""" | |||||
import inspect | |||||
from typing import Iterable, Dict, Tuple | |||||
from dataclasses import dataclass | |||||
from itertools import chain | |||||
import torch | |||||
from torch import nn | |||||
from . import InterpretableModel | |||||
@dataclass | |||||
class InterpreterArgs: | |||||
inputs: Tuple[torch.Tensor, ...] | |||||
additional_inputs: Tuple[torch.Tensor, ...] | |||||
class InterpretableWrapper(nn.Module): | |||||
""" | |||||
An adapter wrapper to adapt our models to Captum interpreters | |||||
""" | |||||
def __init__(self, model: InterpretableModel): | |||||
super().__init__() | |||||
self._model = model | |||||
@property | |||||
def _additional_names(self) -> Iterable[str]: | |||||
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted | |||||
signature = inspect.signature(self._model.forward) | |||||
        # Parameters whose default value is None are considered unused optional
        # inputs and are skipped.
        return [name for name in signature.parameters
                if name not in to_be_interpreted_names
                and signature.parameters[name].default is not None]
def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
""" | |||||
Converts an ordered *args to **kwargs | |||||
""" | |||||
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted | |||||
additional_names = self._additional_names | |||||
inputs = {} | |||||
for i, name in enumerate(chain(to_be_interpreted_names, additional_names)): | |||||
inputs[name] = args[i] | |||||
return inputs | |||||
def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs: | |||||
""" | |||||
Converts a **kwargs to ordered *args | |||||
""" | |||||
return InterpreterArgs( | |||||
tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted | |||||
if name in kwargs), | |||||
tuple(kwargs[name] for name in self._additional_names | |||||
if name in kwargs) | |||||
) | |||||
def forward(self, *args: torch.Tensor) -> torch.Tensor: | |||||
""" | |||||
Forwards the model | |||||
""" | |||||
inputs = self.convert_inputs_to_kwargs(*args) | |||||
return self._model.get_categorical_probabilities(**inputs) |
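# Hedged round-trip sketch: Captum hands the wrapper positional tensors, and the
# wrapper maps them back to the model's keyword arguments. `model` is assumed to
# have a single placeholder named `x` and no additional inputs.
#   wrapper = InterpretableWrapper(model)
#   args = wrapper.convert_inputs_to_args(x=torch.randn(4, 3, 32, 32))
#   args.inputs                  # (tensor,) - placeholders to be interpreted
#   args.additional_inputs       # ()        - passed as additional_forward_args
#   probs = wrapper(*args.inputs, *args.additional_inputs)   # (4, C) probabilities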
from typing import Dict, Union, Tuple | |||||
import torch | |||||
from .interpreter import Interpreter | |||||
from ..data.dataflow import DataFlow | |||||
from ..data.data_loader import DataLoader | |||||
class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]): | |||||
def __init__(self, interpreter: Interpreter, | |||||
dataloader: DataLoader, | |||||
device: torch.device, | |||||
print_debug_info: bool, | |||||
label_for_interpreter: Union[int, None], | |||||
give_gt_as_label: bool): | |||||
super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info) | |||||
self._interpreter = interpreter | |||||
self._label_for_interpreter = label_for_interpreter | |||||
self._give_gt_as_label = give_gt_as_label | |||||
self._sample_labels = dataloader.get_samples_labels() | |||||
@staticmethod | |||||
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
if output is None: | |||||
return None | |||||
return { | |||||
name: o.detach() for name, o in output.items() | |||||
} | |||||
def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: | |||||
if self._give_gt_as_label: | |||||
return model_input, self._detach_output(self._interpreter.interpret( | |||||
torch.Tensor(self._sample_labels[ | |||||
self._dataloader.get_current_batch_sample_indices()]), | |||||
**model_input)) | |||||
return model_input, self._detach_output(self._interpreter.interpret( | |||||
self._label_for_interpreter, **model_input | |||||
)) |
from typing import TYPE_CHECKING | |||||
import traceback | |||||
from time import time | |||||
from ..data.data_loader import DataLoader | |||||
from .interpretable import InterpretableModel | |||||
from .interpreter_maker import create_interpreter | |||||
from .interpreter import Interpreter | |||||
from .interpretation_dataflow import InterpretationDataFlow | |||||
from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class InterpretingEvalRunner: | |||||
def __init__(self, conf: 'BaseConfig', model: InterpretableModel): | |||||
self.conf = conf | |||||
self.model = model | |||||
def evaluate(self): | |||||
interpreter = create_interpreter( | |||||
self.conf.interpretation_method, | |||||
self.model) | |||||
for test_group_info in self.conf.samples_dir.split(','): | |||||
try: | |||||
print('>> Evaluating interpretations for %s' % test_group_info, flush=True) | |||||
t1 = time() | |||||
test_data_loader = \ | |||||
DataLoader(self.conf, test_group_info, 'test') | |||||
evaluator = BinaryInterpretationEvaluator2D( | |||||
test_data_loader.get_number_of_samples(), | |||||
self.conf | |||||
) | |||||
self._eval_interpretations(test_data_loader, interpreter, evaluator) | |||||
evaluator.print_summaries() | |||||
print('Evaluating Interpretations was done in %.2f secs.' % (time() - t1,), flush=True) | |||||
            except Exception:
print('Problem in %s' % test_group_info, flush=True) | |||||
track = traceback.format_exc() | |||||
print(track, flush=True) | |||||
def _eval_interpretations(self, | |||||
data_loader: DataLoader, interpreter: Interpreter, | |||||
evaluator: BinaryInterpretationEvaluator2D) -> None: | |||||
""" CAUTION: Resets the data_loader, | |||||
Iterates over the samples (as much as and how data_loader specifies), | |||||
finds the interpretations based on the specified interpretation method | |||||
and saves the results in the received save_dir!""" | |||||
if not isinstance(self.model, InterpretableModel): | |||||
raise Exception('Model has not implemented the requirements of the InterpretableModel') | |||||
# Running the model in evaluation mode | |||||
self.model.eval() | |||||
# initiating variables for running evaluations | |||||
evaluator.reset() | |||||
label_for_interpreter = self.conf.class_label_for_interpretation | |||||
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt | |||||
dataflow = InterpretationDataFlow(interpreter, | |||||
data_loader, | |||||
self.conf.device, | |||||
False, | |||||
label_for_interpreter, | |||||
give_gt_as_label) | |||||
n_interpreted_samples = 0 | |||||
max_n_samples = self.conf.n_interpretation_samples | |||||
with dataflow: | |||||
for model_input, interpretation_output in dataflow.iterate(): | |||||
n_batch = model_input[ | |||||
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||||
n_samples_to_save = n_batch | |||||
if max_n_samples is not None: | |||||
n_samples_to_save = min( | |||||
max_n_samples - n_interpreted_samples, | |||||
n_batch) | |||||
model_outputs = self.model(**model_input) | |||||
model_preds = model_outputs[ | |||||
self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy() | |||||
if self.conf.interpretation_tag_to_evaluate: | |||||
interpretations = interpretation_output[ | |||||
self.conf.interpretation_tag_to_evaluate | |||||
].detach() | |||||
else: | |||||
interpretations = list(interpretation_output.values())[0].detach() | |||||
                interpretations = interpretations.cpu().numpy()
evaluator.update_summaries( | |||||
interpretations[:n_samples_to_save, 0], | |||||
data_loader.get_current_batch_samples_interpretations()[:n_samples_to_save, 0].cpu().numpy(), | |||||
model_preds[:n_samples_to_save], | |||||
data_loader.get_current_batch_samples_labels()[:n_samples_to_save], | |||||
data_loader.get_current_batch_sample_indices()[:n_samples_to_save], | |||||
interpreter | |||||
) | |||||
del interpretation_output | |||||
del model_outputs | |||||
# if the number of interpreted samples has reached the limit, break | |||||
n_interpreted_samples += n_batch | |||||
if max_n_samples is not None and n_interpreted_samples >= max_n_samples: | |||||
break |
""" | |||||
An abstract class to apply different interpretation algorithms on torch models. | |||||
""" | |||||
from abc import abstractmethod, ABC | |||||
from typing import Dict, List, Union | |||||
import torch | |||||
from torch import nn | |||||
import numpy as np | |||||
from . import InterpretableModel | |||||
class Interpreter(ABC): # pylint: disable=too-few-public-methods | |||||
""" | |||||
An abstract class to apply different interpretation algorithms on torch models. | |||||
""" | |||||
def __init__(self, model: InterpretableModel): | |||||
self._model = model | |||||
def _get_all_leaf_children(self, module: nn.Module = None) -> List[nn.Module]: | |||||
""" | |||||
Gets all leaf modules recursively | |||||
""" | |||||
if module is None: | |||||
module = self._model | |||||
children: List[nn.Module] = [] | |||||
for child in module.children(): | |||||
if len(list(child.children())) == 0: | |||||
children.append(child) | |||||
else: | |||||
children += self._get_all_leaf_children(child) | |||||
return children | |||||
@staticmethod | |||||
def _get_one_hot_output(output: torch.Tensor, | |||||
target: Union[int, torch.Tensor, np.ndarray] = None) -> torch.Tensor: | |||||
batch_size = output.shape[0] | |||||
if target is None: | |||||
target = output.argmax(dim=1) | |||||
one_hot_output = torch.zeros_like(output) | |||||
one_hot_output[torch.arange(batch_size), target] = 1 | |||||
return one_hot_output | |||||
def _get_target(self, target: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Union[int, torch.Tensor]: | |||||
if target is not None: | |||||
return target | |||||
return self._model.get_categorical_probabilities(**inputs).argmax(dim=1) | |||||
@abstractmethod | |||||
def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
''' | |||||
An abstract method to interpret given input and target class | |||||
Returns a dictionary of interpretation results. | |||||
''' | |||||
def dynamic_threshold(self, x: np.ndarray) -> np.ndarray: | |||||
""" | |||||
A function to dynamically threshold one sample. | |||||
""" | |||||
return (x - (x.mean() + x.std())).clip(min=0) |
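# A small numeric illustration of dynamic_threshold (exact values): only scores
# more than one standard deviation above the mean survive.
#   x = np.array([0., 0., 0., 4.])            # mean = 1.0, std = sqrt(3) ~ 1.73
#   # x - (x.mean() + x.std()) ~ [-2.73, -2.73, -2.73, 1.27]
#   # -> clipped to [0.0, 0.0, 0.0, ~1.27]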
""" | |||||
Creates an interpreter object and returns it | |||||
""" | |||||
from typing import Callable | |||||
from . import InterpretableModel, Interpreter | |||||
class InterpretationType: | |||||
GuidedBackprop = 'GuidedBackprop' | |||||
GuidedGradCam = 'GuidedGradCam' | |||||
GradCam = 'GradCam' | |||||
RelCam = 'RelCam' | |||||
DeepLift = 'DeepLift' | |||||
Attention = 'Attention' | |||||
AttentionSmooth = 'AttentionSmooth' | |||||
AttentionSum = 'AttentionSum' | |||||
ImagenetAttention = 'ImagenetAttention' | |||||
GT = 'GT' | |||||
InterpreterMaker = Callable[[InterpretableModel], Interpreter] | |||||
def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter: | |||||
""" | |||||
Creates an interpreter object and returns it | |||||
:param interpretation_method: The method name | |||||
:param model: The model to be interpreted | |||||
:return: an interpreter object that can run the model and interpret its output | |||||
""" | |||||
if interpretation_method == InterpretationType.GuidedBackprop: | |||||
from .guided_backprop import GuidedBackprop as interpreter_maker | |||||
elif interpretation_method == InterpretationType.GuidedGradCam: | |||||
from .guided_gradcam import GuidedGradCam as interpreter_maker | |||||
elif interpretation_method == InterpretationType.GradCam: | |||||
from .gradcam import GradCam as interpreter_maker | |||||
elif interpretation_method == InterpretationType.RelCam: | |||||
from .relcam import RelCamInterpreter as interpreter_maker | |||||
elif interpretation_method == InterpretationType.DeepLift: | |||||
from .deep_lift import DeepLift as interpreter_maker | |||||
elif interpretation_method == InterpretationType.Attention: | |||||
from .attention_interpreter import AttentionInterpreter as interpreter_maker | |||||
elif interpretation_method == InterpretationType.AttentionSmooth: | |||||
from .attention_interpreter_smooth_integrator import AttentionInterpreterSmoothIntegrator as interpreter_maker | |||||
elif interpretation_method == InterpretationType.AttentionSum: | |||||
from .attention_sum_interpreter import AttentionSumInterpreter as interpreter_maker | |||||
elif interpretation_method == InterpretationType.ImagenetAttention: | |||||
from .imagenet_attention_interpreter import ImagenetPredictionInterpreter as interpreter_maker | |||||
else: | |||||
        raise Exception(f'Unknown interpretation method: {interpretation_method}')
return interpreter_maker(model) | |||||
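# Hedged usage sketch: `model` is any InterpretableModel satisfying the chosen
# method's requirements (e.g. CamInterpretableModel for the CAM variants), and
# the method name usually comes from the config's interpretation_method field.
#   interpreter = create_interpreter(InterpretationType.GradCam, model)
#   maps = interpreter.interpret(None, x=batch)   # None targets the argmax class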
from typing import Dict, TYPE_CHECKING | |||||
from os import makedirs, path | |||||
import traceback | |||||
from time import time | |||||
import warnings | |||||
import imageio | |||||
import torch | |||||
import numpy as np | |||||
from ..data.data_loader import DataLoader | |||||
from .interpretable import InterpretableModel | |||||
from .interpreter_maker import create_interpreter | |||||
from .interpreter import Interpreter | |||||
from .interpretation_dataflow import InterpretationDataFlow | |||||
from .utils import overlay_interpretation | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class InterpretingRunner: | |||||
def __init__(self, conf: 'BaseConfig', model: InterpretableModel): | |||||
self.conf = conf | |||||
self.model = model | |||||
def interpret(self): | |||||
interpreter = create_interpreter( | |||||
self.conf.interpretation_method, | |||||
self.model) | |||||
for test_group_info in self.conf.samples_dir.split(','): | |||||
try: | |||||
print('>> Finding interpretations for %s' % test_group_info, flush=True) | |||||
t1 = time() | |||||
labels_to_use = self.conf.mapped_labels_to_use | |||||
if labels_to_use is None: | |||||
labels_to_use = 'All' | |||||
else: | |||||
labels_to_use = ','.join([str(x) for x in self.conf.mapped_labels_to_use]) | |||||
report_dir = self.conf.get_sample_group_specific_report_dir( | |||||
test_group_info, | |||||
extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}') | |||||
makedirs(report_dir, exist_ok=True) | |||||
                # writing the whole config
                with open(report_dir + '/conf_info.txt', 'w') as f:
                    f.write(str(self.conf) + '\n')
test_data_loader = \ | |||||
DataLoader(self.conf, test_group_info, 'test') | |||||
self._interpret(report_dir, test_data_loader, interpreter) | |||||
print('Interpretations were saved in %s' % report_dir) | |||||
print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True) | |||||
            except Exception:
print('Problem in %s' % test_group_info, flush=True) | |||||
track = traceback.format_exc() | |||||
print(track, flush=True) | |||||
def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None: | |||||
""" Iterates over the samples (as much as and how data_loader specifies), | |||||
finds the interpretations based on the specified interpretation method | |||||
and saves the results in the received save_dir!""" | |||||
makedirs(save_dir, exist_ok=True) | |||||
if not isinstance(self.model, InterpretableModel): | |||||
raise Exception('Model has not implemented the requirements of the InterpretableModel') | |||||
# Running the model in evaluation mode | |||||
self.model.eval() | |||||
# initiating variables for running evaluations | |||||
label_for_interpreter = self.conf.class_label_for_interpretation | |||||
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt | |||||
dataflow = InterpretationDataFlow(interpreter, | |||||
data_loader, | |||||
self.conf.device, | |||||
False, | |||||
label_for_interpreter, | |||||
give_gt_as_label) | |||||
n_interpreted_samples = 0 | |||||
max_n_samples = self.conf.n_interpretation_samples | |||||
with dataflow: | |||||
for model_input, interpretation_output in dataflow.iterate(): | |||||
n_batch = model_input[ | |||||
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||||
n_samples_to_save = n_batch | |||||
if max_n_samples is not None: | |||||
n_samples_to_save = min( | |||||
max_n_samples - n_interpreted_samples, | |||||
n_batch) | |||||
self._save_interpretations_of_batch( | |||||
model_input[self.model.ordered_placeholder_names_to_be_interpreted[0]], | |||||
interpretation_output, save_dir, data_loader, n_samples_to_save, | |||||
interpreter) | |||||
del interpretation_output | |||||
# if the number of interpreted samples has reached the limit, break | |||||
n_interpreted_samples += n_batch | |||||
if max_n_samples is not None and n_interpreted_samples >= max_n_samples: | |||||
break | |||||
def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor], | |||||
save_dir: str, data_loader: DataLoader, | |||||
n_samples_to_save: int, interpreter: Interpreter) -> None: | |||||
""" | |||||
Receives the output of the interpreter and saves the interpretations in the received directory | |||||
in a file named as the sample name. | |||||
The behaviour can be overwritten in children if extra stuff are required. | |||||
:param interpreter_output: | |||||
:param save_dir: | |||||
:return: None | |||||
""" | |||||
batch_samples_names = data_loader.get_current_batch_samples_names() | |||||
save_dirs = [self.conf.get_save_dir_for_sample( | |||||
save_dir, batch_samples_names[bi].replace('../', '')) | |||||
for bi in range(len(batch_samples_names))] | |||||
for sd in save_dirs: | |||||
makedirs(path.dirname(sd), exist_ok=True) | |||||
interpreter_output = {name: output.cpu().numpy() for name, output in interpreter_output.items()} | |||||
""" Make inputs grayscale """ | |||||
model_input = model_input.mean(dim=1).cpu().numpy() | |||||
def save(bis): | |||||
for bi in bis: | |||||
for name, output in interpreter_output.items(): | |||||
output = output[bi] | |||||
filename = f'{save_dirs[bi]}_{name}' | |||||
if self.conf.dynamic_threshold: | |||||
output = interpreter.dynamic_threshold(output) | |||||
if not self.conf.skip_raw: | |||||
np.save(filename, output) | |||||
else: | |||||
warnings.warn('Skipping raw interpretations') | |||||
if not self.conf.skip_overlay: | |||||
if output.shape[0] > 3: | |||||
output = output[:3] | |||||
overlayed = overlay_interpretation(model_input[bi][np.newaxis, ...], output, self.conf) | |||||
imageio.imwrite(f'{filename}_overlay.png', overlayed) | |||||
else: | |||||
warnings.warn('Skipping overlayed interpretations') | |||||
save(np.arange(min(len(model_input), n_samples_to_save))) |
from .interpreter import RelCamInterpreter |
from typing import Dict, Union | |||||
import warnings | |||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
import torch | |||||
from ..interpreter import Interpreter | |||||
from ..interpretable import CamInterpretableModel | |||||
from .relprop import RPProvider | |||||
class RelCamInterpreter(Interpreter): | |||||
def __init__(self, model: CamInterpretableModel): | |||||
super().__init__(model) | |||||
self.__targets = model.target_conv_layers | |||||
for name, module in model.named_modules(): | |||||
if not RPProvider.propable(module): | |||||
warnings.warn(f"Module {name} of type {type(module)} is not propable! Hope you know what you are doing!") | |||||
continue | |||||
RPProvider.create(module) | |||||
@staticmethod | |||||
def __normalize(x: torch.Tensor) -> torch.Tensor: | |||||
fx = x.flatten(1) | |||||
minx = fx.min(dim=1).values[:, None, None, None] | |||||
maxx = fx.max(dim=1).values[:, None, None, None] | |||||
return (x - minx) / (1e-6 + maxx - minx) | |||||
def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
x_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape | |||||
with RPProvider.capture(self.__targets) as Rs: | |||||
z = self._model.get_categorical_probabilities(**inputs) | |||||
one_hot_y = self._get_one_hot_output(z, labels) | |||||
RPProvider.get(self._model)(one_hot_y) | |||||
result = list(Rs.values()) | |||||
relcam = torch.stack(result).sum(dim=0)\ | |||||
if len(result) > 1\ | |||||
else result[0] | |||||
relcam = self.__normalize(relcam) | |||||
relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear', align_corners=False) | |||||
return { | |||||
'default': relcam, | |||||
} |
from torch import nn | |||||
import torch | |||||
class Add(nn.Module): | |||||
def forward(self, inputs): | |||||
return torch.add(*inputs) | |||||
class Multiply(nn.Module): | |||||
def forward(self, inputs): | |||||
return torch.mul(*inputs) | |||||
class Cat(nn.Module): | |||||
def forward(self, inputs, dim): | |||||
        self.dim = dim  # remembered so CatRelProp can reuse the concat dim
return torch.cat(inputs, dim) |
from contextlib import ExitStack, contextmanager | |||||
from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar | |||||
from uuid import UUID, uuid4 | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import torchvision | |||||
from . import modules as M | |||||
TModule = TypeVar('TModule', bound=nn.Module) | |||||
TType = TypeVar('TType', bound=type) | |||||
def safe_divide(a, b): | |||||
return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type()) | |||||
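# Behavior check (exact values): zeros in the denominator yield zeros rather
# than inf/nan, since the trailing b.ne(0) mask zeroes those entries out.
#   safe_divide(torch.tensor([1., 1.]), torch.tensor([2., 0.]))
#   # -> tensor([0.5000, 0.0000])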
class classdict(Dict[type, TType], Generic[TType]): | |||||
def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]: | |||||
return next((r for r in __k.mro() if dict.__contains__(self, r)), None) | |||||
def __contains__(self, __k: Type[TModule]) -> bool: | |||||
return self.__nearest_available_resolution(__k) is not None | |||||
def __getitem__(self, __k: Type[TModule]) -> TType: | |||||
r = self.__nearest_available_resolution(__k) | |||||
if r is None: | |||||
raise KeyError(f"{__k} not found!") | |||||
return super().__getitem__(r) | |||||
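# A minimal sketch of classdict's resolution rule: an entry registered for a
# base class also resolves for its subclasses, via the nearest ancestor in the
# method resolution order.
#   d = classdict()
#   d[nn.Module] = 'fallback'
#   d[nn.Conv2d]        # -> 'fallback' (nn.Module is in nn.Conv2d's MRO)
#   nn.Identity in d    # -> True, for the same reason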
class RPProvider: | |||||
__props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict() | |||||
__instances: Dict[TModule, 'RelProp[TModule]'] = {} | |||||
__target_results: Dict[nn.Module, torch.Tensor] = {} | |||||
@classmethod | |||||
def register(cls, *module_types: Type[TModule]): | |||||
def decorator(prop_cls): | |||||
cls.__props.update({ | |||||
module_type: prop_cls | |||||
for module_type in module_types | |||||
}) | |||||
return prop_cls | |||||
return decorator | |||||
@classmethod | |||||
def create(cls, module: TModule): | |||||
cls.__instances[module] = prop = cls.__props[type(module)](module) | |||||
return prop | |||||
@classmethod | |||||
def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None: | |||||
r_weight = torch.mean(R, dim=(2, 3), keepdim=True) | |||||
r_cam = prop.X * r_weight | |||||
r_cam = torch.sum(r_cam, dim=1, keepdim=True) | |||||
cls.__target_results[prop.module] = r_cam | |||||
@classmethod | |||||
@contextmanager | |||||
def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]: | |||||
with ExitStack() as stack: | |||||
cls.__target_results.clear() | |||||
[stack.enter_context(prop.hook_module()) | |||||
for prop in cls.__instances.values()] | |||||
[stack.enter_context(cls.get(target).register_hook(cls.__hook)) | |||||
for target in target_layers] | |||||
yield cls.__target_results | |||||
@classmethod | |||||
def propable(cls, module: TModule) -> bool: | |||||
return type(module) in cls.__props | |||||
@classmethod | |||||
def get(cls, module: TModule) -> 'RelProp[TModule]': | |||||
return cls.__instances[module] | |||||
@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid) | |||||
class RelProp(Generic[TModule]): | |||||
Hook = Callable[['RelProp', torch.Tensor], None] | |||||
def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor): | |||||
        if isinstance(input[0], (list, tuple)):
self.X = [] | |||||
for i in input[0]: | |||||
x = i.detach() | |||||
x.requires_grad = True | |||||
self.X.append(x) | |||||
else: | |||||
self.X = input[0].detach() | |||||
self.X.requires_grad = True | |||||
self.Y = output | |||||
def __init__(self, module: TModule) -> None: | |||||
self.module = module | |||||
self.__hooks: Dict[UUID, RelProp.Hook] = {} | |||||
@contextmanager | |||||
def hook_module(self): | |||||
handle = self.module.register_forward_hook(self.__forward_hook) | |||||
try: | |||||
yield | |||||
finally: | |||||
handle.remove() | |||||
@contextmanager | |||||
def register_hook(self, hook: 'RelProp.Hook'): | |||||
uuid = uuid4() | |||||
self.__hooks[uuid] = hook | |||||
try: | |||||
yield | |||||
finally: | |||||
self.__hooks.pop(uuid) | |||||
def grad(self, Z, X, S): | |||||
C = torch.autograd.grad(Z, X, S, retain_graph=True) | |||||
return C | |||||
def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
R = self.rel(R, alpha=alpha) | |||||
[hook(self, R) for hook in self.__hooks.values()] | |||||
return R | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
return R | |||||
@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize) | |||||
class RelPropSimple(RelProp[TModule]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
Z = self.module.forward(self.X) | |||||
S = safe_divide(R, Z) | |||||
C = self.grad(Z, self.X, S) | |||||
        if not torch.is_tensor(self.X):
            outputs = [self.X[0] * C[0], self.X[1] * C[1]]
        else:
            outputs = self.X * C[0]
return outputs | |||||
@RPProvider.register(nn.Flatten) | |||||
class RelPropFlatten(RelProp[TModule]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
return R.reshape_as(self.X) | |||||
@RPProvider.register(nn.AdaptiveAvgPool2d) | |||||
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
px = torch.clamp(self.X, min=0) | |||||
def f(x1): | |||||
Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size) | |||||
S1 = safe_divide(R, Z1) | |||||
C1 = x1 * self.grad(Z1, x1, S1)[0] | |||||
return C1 | |||||
activator_relevances = f(px) | |||||
out = activator_relevances | |||||
return out | |||||
@RPProvider.register(nn.ZeroPad2d) | |||||
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
Z = self.module.forward(self.X) | |||||
S = safe_divide(R, Z) | |||||
C = self.grad(Z, self.X, S) | |||||
outputs = self.X * C[0] | |||||
return outputs | |||||
@RPProvider.register(M.Multiply) | |||||
class MultiplyRelProp(RelProp[M.Multiply]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
x0 = torch.clamp(self.X[0], min=0) | |||||
x1 = torch.clamp(self.X[1], min=0) | |||||
x = [x0, x1] | |||||
Z = self.module.forward(x) | |||||
S = safe_divide(R, Z) | |||||
C = self.grad(Z, x, S) | |||||
outputs = [] | |||||
outputs.append(x[0] * C[0]) | |||||
outputs.append(x[1] * C[1]) | |||||
return outputs | |||||
@RPProvider.register(M.Cat) | |||||
class CatRelProp(RelProp[M.Cat]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
Z = self.module.forward(self.X, self.module.dim) | |||||
S = safe_divide(R, Z) | |||||
C = self.grad(Z, self.X, S) | |||||
outputs = [] | |||||
for x, c in zip(self.X, C): | |||||
outputs.append(x * c) | |||||
return outputs | |||||
@RPProvider.register(nn.Sequential) | |||||
class SequentialRelProp(RelProp[nn.Sequential]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
for m in reversed(self.module): | |||||
R = RPProvider.get(m)(R, alpha=alpha) | |||||
return R | |||||
@RPProvider.register(nn.BatchNorm2d) | |||||
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
X = self.X | |||||
beta = 1 - alpha | |||||
weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( | |||||
(self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.module.eps).pow(0.5)) | |||||
Z = X * weight + 1e-9 | |||||
S = R / Z | |||||
Ca = S * weight | |||||
R = self.X * (Ca) | |||||
return R | |||||
@RPProvider.register(nn.Linear) | |||||
class LinearRelProp(RelProp[nn.Linear]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
beta = alpha - 1 | |||||
pw = torch.clamp(self.module.weight, min=0) | |||||
nw = torch.clamp(self.module.weight, max=0) | |||||
px = torch.clamp(self.X, min=0) | |||||
nx = torch.clamp(self.X, max=0) | |||||
def f(w1, w2, x1, x2): | |||||
Z1 = F.linear(x1, w1) | |||||
Z2 = F.linear(x2, w2) | |||||
Z = Z1 + Z2 | |||||
S = safe_divide(R, Z) | |||||
C1 = x1 * self.grad(Z1, x1, S)[0] | |||||
C2 = x2 * self.grad(Z2, x2, S)[0] | |||||
return C1 + C2 | |||||
activator_relevances = f(pw, nw, px, nx) | |||||
inhibitor_relevances = f(nw, pw, px, nx) | |||||
out = alpha * activator_relevances - beta*inhibitor_relevances | |||||
return out | |||||
@RPProvider.register(nn.Conv2d) | |||||
class Conv2dRelProp(RelProp[nn.Conv2d]): | |||||
def gradprop2(self, DY, weight): | |||||
Z = self.module.forward(self.X) | |||||
output_padding = self.X.size()[2] - ( | |||||
(Z.size()[2] - 1) * self.module.stride[0] - 2 * self.module.padding[0] + self.module.kernel_size[0]) | |||||
return F.conv_transpose2d(DY, weight, stride=self.module.stride, padding=self.module.padding, output_padding=output_padding) | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
if self.X.shape[1] == 3: | |||||
pw = torch.clamp(self.module.weight, min=0) | |||||
nw = torch.clamp(self.module.weight, max=0) | |||||
X = self.X | |||||
L = self.X * 0 + \ | |||||
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, | |||||
keepdim=True)[0] | |||||
H = self.X * 0 + \ | |||||
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, | |||||
keepdim=True)[0] | |||||
Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride, padding=self.module.padding) - \ | |||||
torch.conv2d(L, pw, bias=None, stride=self.module.stride, padding=self.module.padding) - \ | |||||
torch.conv2d(H, nw, bias=None, stride=self.module.stride, | |||||
padding=self.module.padding) + 1e-9 | |||||
S = R / Za | |||||
C = X * self.gradprop2(S, self.module.weight) - L * \ | |||||
self.gradprop2(S, pw) - H * self.gradprop2(S, nw) | |||||
R = C | |||||
else: | |||||
beta = alpha - 1 | |||||
pw = torch.clamp(self.module.weight, min=0) | |||||
nw = torch.clamp(self.module.weight, max=0) | |||||
px = torch.clamp(self.X, min=0) | |||||
nx = torch.clamp(self.X, max=0) | |||||
def f(w1, w2, x1, x2): | |||||
Z1 = F.conv2d(x1, w1, bias=self.module.bias, stride=self.module.stride, | |||||
padding=self.module.padding, groups=self.module.groups) | |||||
Z2 = F.conv2d(x2, w2, bias=self.module.bias, stride=self.module.stride, | |||||
padding=self.module.padding, groups=self.module.groups) | |||||
Z = Z1 + Z2 | |||||
S = safe_divide(R, Z) | |||||
C1 = x1 * self.grad(Z1, x1, S)[0] | |||||
C2 = x2 * self.grad(Z2, x2, S)[0] | |||||
return C1 + C2 | |||||
activator_relevances = f(pw, nw, px, nx) | |||||
inhibitor_relevances = f(nw, pw, px, nx) | |||||
R = alpha * activator_relevances - beta * inhibitor_relevances | |||||
return R | |||||
@RPProvider.register(nn.ConvTranspose2d) | |||||
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
pw = torch.clamp(self.module.weight, min=0) | |||||
px = torch.clamp(self.X, min=0) | |||||
def f(w1, x1): | |||||
Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride, | |||||
padding=self.module.padding, output_padding=self.module.output_padding) | |||||
S1 = safe_divide(R, Z1) | |||||
C1 = x1 * self.grad(Z1, x1, S1)[0] | |||||
return C1 | |||||
activator_relevances = f(pw, px) | |||||
R = activator_relevances | |||||
return R |
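# Hedged end-to-end sketch of driving the relevance propagation by hand
# (RelCamInterpreter does exactly this behind the Interpreter interface). The
# tiny Sequential is illustrative; every module type it uses has a RelProp rule
# registered above.
#   model = nn.Sequential(
#       nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
#       nn.Flatten(), nn.Linear(4 * 8 * 8, 2))
#   for m in [model, *model]:                # the container and each leaf
#       if RPProvider.propable(m):
#           RPProvider.create(m)
#   with RPProvider.capture([model[0]]) as results:   # target the first conv
#       out = model(torch.randn(1, 3, 8, 8))
#       one_hot = torch.zeros_like(out)
#       one_hot[0, out.argmax(dim=1)] = 1
#       RPProvider.get(model)(one_hot)       # propagate relevance back through the stack
#   cam = results[model[0]]                  # (1, 1, 8, 8) channel-summed relevance map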
from typing import TYPE_CHECKING, Optional | |||||
import numpy as np | |||||
import cv2 | |||||
from skimage.morphology import dilation, erosion | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
from . import Interpreter | |||||
def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray: | |||||
""" | |||||
image (np.ndarray): The image to overlay interpretations over. Of shape C W H. The numbers are assumed to be in range [0, 1] | |||||
int_map (np.ndarray): The map containing interpretation scores. | |||||
config ('Config'): An object containing required information about the configurations. | |||||
Returns: | |||||
(np.ndarray): The overlayed image | |||||
""" | |||||
int_map = process_interpretations(int_map, config) # C W H | |||||
int_map = np.moveaxis(int_map, 0, -1) # W H C | |||||
# if there are more than 3 channels, use the first 3! | |||||
if int_map.shape[-1] > 3: | |||||
int_map = int_map[..., :3] | |||||
# norm by max to make sure! | |||||
int_map = int_map * 1.0 / (1e-6 + np.amax(int_map)) | |||||
alpha = 0.5 | |||||
if np.amin(image) < 0: | |||||
image += 0.5 | |||||
image = np.moveaxis(image, 0, -1) # W H C | |||||
image = image * 255 | |||||
o_color_by_c = { | |||||
1: [255, 0, 0], | |||||
2: [255, 255, 0], | |||||
3: [255, 255, 255], | |||||
} | |||||
    o_color = o_color_by_c.get(int_map.shape[-1], None)
    if o_color is None:
        raise Exception("Trying to overlay more than 3 maps on the image.")
    o_color = np.asarray(o_color)
# if int map has two channels, append one zeros to the end! | |||||
if int_map.shape[-1] == 2: | |||||
int_map = np.concatenate((int_map, np.zeros((int_map.shape[-3], int_map.shape[-2], 1))), axis=-1) | |||||
overlayed = (1 - int_map) * image + \ | |||||
int_map * (1 - alpha) * image + \ | |||||
int_map * alpha * o_color[np.newaxis, np.newaxis, :] | |||||
overlayed = np.round(overlayed).astype(np.uint8) | |||||
return overlayed | |||||
def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray: | |||||
""" | |||||
int_map (np.ndarray): The map of interpretation scores. Is assumed to have a shape of C ... | |||||
config ('Config'): An object containing the information about the required configurations | |||||
Returns: | |||||
np.ndarray: Processed interpretation maps, with numbers in the range of [0, 1] | |||||
""" | |||||
if interpreter and config.dynamic_threshold: | |||||
int_map = interpreter.dynamic_threshold(int_map) | |||||
if not config.global_threshold: | |||||
# Treating map as negative=reverse effect; discarding negatives | |||||
int_map[int_map < 0] = 0 | |||||
qval = np.amax(int_map) | |||||
if qval > 0: | |||||
int_map[int_map > qval] = qval | |||||
int_map = int_map / qval | |||||
# applying cut threshold! | |||||
int_map[int_map <= config.cut_threshold] = config.cut_threshold | |||||
int_map -= config.cut_threshold | |||||
kw = int(np.round(0.1 * int_map.shape[-1])) | |||||
mv = np.amax(int_map) | |||||
int_map /= (mv + 1e-6) | |||||
# dilation and erosion to make continuous objects and erase noises | |||||
for c in range(int_map.shape[0]): | |||||
int_map[c] = dilation(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3))) | |||||
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw))) | |||||
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))) | |||||
int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0 | |||||
return int_map # [0, 1] | |||||
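# Hedged usage sketch: the config only needs the fields read above, so a plain
# namespace stands in for BaseConfig here; shapes follow the docstrings
# (channels first, values in [0, 1]).
#   from types import SimpleNamespace
#   conf = SimpleNamespace(dynamic_threshold=False, global_threshold=False, cut_threshold=0.1)
#   image = np.random.rand(3, 64, 64)     # C W H image in [0, 1]
#   scores = np.random.rand(1, 64, 64)    # one interpretation channel
#   overlaid = overlay_interpretation(image, scores, conf)   # uint8, (64, 64, 3)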
from typing import List, TYPE_CHECKING, Dict | |||||
import numpy as np | |||||
import torch | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class BinaryEvaluator(Evaluator): | |||||
def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||||
super(BinaryEvaluator, self).__init__(model, data_loader, conf) | |||||
# for summarized information | |||||
self.tp: int = 0 | |||||
self.tn: int = 0 | |||||
self.fp: int = 0 | |||||
self.fn: int = 0 | |||||
self.avg_loss: float = 0 | |||||
self.n_received_samples: int = 0 | |||||
def reset(self): | |||||
self.tp = 0 | |||||
self.tn = 0 | |||||
self.fp = 0 | |||||
self.fn = 0 | |||||
self.avg_loss = 0 | |||||
self.n_received_samples = 0 | |||||
@property | |||||
def _prob_key(self) -> str: | |||||
return 'positive_class_probability' | |||||
@property | |||||
def _loss_key(self) -> str: | |||||
return 'loss' | |||||
def _get_current_batch_gt(self) -> np.ndarray: | |||||
return self.data_loader.get_current_batch_samples_labels() | |||||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||||
assert self._prob_key in model_output, \ | |||||
f"model's output must contain {self._prob_key}" | |||||
gt = self._get_current_batch_gt() | |||||
prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int) | |||||
ntp = int(np.sum(np.logical_and(gt == 1, prediction == 1))) | |||||
ntn = int(np.sum(np.logical_and(gt == 0, prediction == 0))) | |||||
nfp = int(np.sum(np.logical_and(gt == 0, prediction == 1))) | |||||
nfn = int(np.sum(np.logical_and(gt == 1, prediction == 0))) | |||||
self.tp += ntp | |||||
self.tn += ntn | |||||
self.fp += nfp | |||||
self.fn += nfn | |||||
new_n = self.n_received_samples + len(gt) | |||||
self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \ | |||||
model_output.get(self._loss_key, 0.0) * (float(len(gt)) / new_n) | |||||
self.n_received_samples = new_n | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N'] | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
p = self.tp + self.fn | |||||
n = self.tn + self.fp | |||||
if p + n > 0: | |||||
accuracy = (self.tp + self.tn) * 100.0 / (n + p) | |||||
else: | |||||
accuracy = -1 | |||||
if p > 0: | |||||
sensitivity = 100.0 * self.tp / p | |||||
else: | |||||
sensitivity = -1 | |||||
if n > 0: | |||||
specificity = 100.0 * self.tn / max(n, 1) | |||||
else: | |||||
specificity = -1 | |||||
if sensitivity > -1 and specificity > -1: | |||||
avg_ss = 0.5 * (sensitivity + specificity) | |||||
elif sensitivity > -1: | |||||
avg_ss = sensitivity | |||||
else: | |||||
avg_ss = specificity | |||||
return ['%.4f' % self.avg_loss, '%.2f' % accuracy, '%.2f' % sensitivity, | |||||
'%.2f' % specificity, '%.2f' % avg_ss, str(self.n_received_samples)] | |||||
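# A worked example of the metrics above (exact arithmetic): with tp=8, fn=2,
# tn=5, fp=5 -> Acc = 13/20 = 65.00, Sens = 8/10 = 80.00, Spec = 5/10 = 50.00,
# BAcc = (80 + 50) / 2 = 65.00; N is the number of samples seen so far.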
from functools import partial | |||||
from typing import List, TYPE_CHECKING, Dict, Callable | |||||
import numpy as np | |||||
from torch import Tensor | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class BinFaithfulnessEvaluator(Evaluator): | |||||
    def __init__(self, kw_prefix: str, gt_kw: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(BinFaithfulnessEvaluator, self).__init__(model, data_loader, conf) | |||||
self._tp_by_kw: Dict[str, int] = dict() | |||||
self._fp_by_kw: Dict[str, int] = dict() | |||||
self._tn_by_kw: Dict[str, int] = dict() | |||||
self._fn_by_kw: Dict[str, int] = dict() | |||||
self._kw_prefix = kw_prefix | |||||
self._gt_kw = gt_kw | |||||
def reset(self): | |||||
self._tp_by_kw: Dict[str, int] = dict() | |||||
self._fp_by_kw: Dict[str, int] = dict() | |||||
self._tn_by_kw: Dict[str, int] = dict() | |||||
self._fn_by_kw: Dict[str, int] = dict() | |||||
def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None: | |||||
gt = (model_output[self._gt_kw].detach().cpu().numpy() >= 0.5).astype(int) | |||||
# looking for prefixes | |||||
for k in model_output.keys(): | |||||
if k.startswith(self._kw_prefix): | |||||
if k not in self._tp_by_kw: | |||||
self._tp_by_kw[k] = 0 | |||||
self._tn_by_kw[k] = 0 | |||||
self._fp_by_kw[k] = 0 | |||||
self._fn_by_kw[k] = 0 | |||||
pred = (model_output[k].cpu().numpy() >= 0.5).astype(int) | |||||
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1)) | |||||
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1)) | |||||
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0)) | |||||
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0)) | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
return [f'Loyalty_{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']] | |||||
def _get_values_of_evaluation_metrics(self, kw) -> List[str]: | |||||
tp = self._tp_by_kw[kw] | |||||
tn = self._tn_by_kw[kw] | |||||
fp = self._fp_by_kw[kw] | |||||
fn = self._fn_by_kw[kw] | |||||
p = tp + fn | |||||
n = tn + fp | |||||
if p + n > 0: | |||||
accuracy = (tp + tn) * 100.0 / (n + p) | |||||
else: | |||||
accuracy = -1 | |||||
if p > 0: | |||||
sensitivity = 100.0 * tp / p | |||||
else: | |||||
sensitivity = -1 | |||||
if n > 0: | |||||
specificity = 100.0 * tn / max(n, 1) | |||||
else: | |||||
specificity = -1 | |||||
if sensitivity > -1 and specificity > -1: | |||||
avg_ss = 0.5 * (sensitivity + specificity) | |||||
elif sensitivity > -1: | |||||
avg_ss = sensitivity | |||||
else: | |||||
avg_ss = specificity | |||||
return ['%.2f' % accuracy, '%.2f' % sensitivity, | |||||
'%.2f' % specificity, '%.2f' % avg_ss] | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
return [ | |||||
v | |||||
for k in self._tp_by_kw.keys() | |||||
for v in self._get_values_of_evaluation_metrics(k)] | |||||
@classmethod | |||||
def standard_creator(cls, prefix_kw: str, pred_kw: str = 'positive_class_probability') -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinFaithfulnessEvaluator']: | |||||
return partial(BinFaithfulnessEvaluator, prefix_kw, pred_kw) | |||||
def print_evaluation_metrics(self, title: str) -> None: | |||||
""" For more readable printing! """ | |||||
print(f'{title}:') | |||||
for k in self._tp_by_kw.keys(): | |||||
print( | |||||
f'\t{k}:: ' + | |||||
', '.join([f'{m_name}: {m_val}' | |||||
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))])) |
from typing import List, TYPE_CHECKING, Dict, Callable | |||||
from functools import partial | |||||
import numpy as np | |||||
from torch import Tensor | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class BinForetellerEvaluator(Evaluator): | |||||
def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||||
super(BinForetellerEvaluator, self).__init__(model, data_loader, conf) | |||||
self._tp_by_kw: Dict[str, int] = dict() | |||||
self._fp_by_kw: Dict[str, int] = dict() | |||||
self._tn_by_kw: Dict[str, int] = dict() | |||||
self._fn_by_kw: Dict[str, int] = dict() | |||||
self._kw_prefix = kw_prefix | |||||
def reset(self): | |||||
self._tp_by_kw: Dict[str, int] = dict() | |||||
self._fp_by_kw: Dict[str, int] = dict() | |||||
self._tn_by_kw: Dict[str, int] = dict() | |||||
self._fn_by_kw: Dict[str, int] = dict() | |||||
def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None: | |||||
gt = self.data_loader.get_current_batch_samples_labels() | |||||
# looking for prefixes | |||||
for k in model_output.keys(): | |||||
if k.startswith(self._kw_prefix): | |||||
if k not in self._tp_by_kw: | |||||
self._tp_by_kw[k] = 0 | |||||
self._tn_by_kw[k] = 0 | |||||
self._fp_by_kw[k] = 0 | |||||
self._fn_by_kw[k] = 0 | |||||
pred = (model_output[k].cpu().numpy() >= 0.5).astype(int) | |||||
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1)) | |||||
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1)) | |||||
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0)) | |||||
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0)) | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
return [f'{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']] | |||||
def _get_values_of_evaluation_metrics(self, kw) -> List[str]: | |||||
tp = self._tp_by_kw[kw] | |||||
tn = self._tn_by_kw[kw] | |||||
fp = self._fp_by_kw[kw] | |||||
fn = self._fn_by_kw[kw] | |||||
p = tp + fn | |||||
n = tn + fp | |||||
if p + n > 0: | |||||
accuracy = (tp + tn) * 100.0 / (n + p) | |||||
else: | |||||
accuracy = -1 | |||||
if p > 0: | |||||
sensitivity = 100.0 * tp / p | |||||
else: | |||||
sensitivity = -1 | |||||
if n > 0: | |||||
specificity = 100.0 * tn / max(n, 1) | |||||
else: | |||||
specificity = -1 | |||||
if sensitivity > -1 and specificity > -1: | |||||
avg_ss = 0.5 * (sensitivity + specificity) | |||||
elif sensitivity > -1: | |||||
avg_ss = sensitivity | |||||
else: | |||||
avg_ss = specificity | |||||
return ['%.2f' % accuracy, '%.2f' % sensitivity, | |||||
'%.2f' % specificity, '%.2f' % avg_ss] | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
return [ | |||||
v | |||||
for k in self._tp_by_kw.keys() | |||||
for v in self._get_values_of_evaluation_metrics(k)] | |||||
@classmethod | |||||
def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']: | |||||
return partial(BinForetellerEvaluator, prefix_kw) | |||||
def print_evaluation_metrics(self, title: str) -> None: | |||||
""" For more readable printing! """ | |||||
print(f'{title}:') | |||||
for k in self._tp_by_kw.keys(): | |||||
print( | |||||
f'\t{k}:: ' + | |||||
', '.join([f'{m_name}: {m_val}' | |||||
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))])) |
from typing import TYPE_CHECKING, Type | |||||
import numpy as np | |||||
from ..data.data_loader import DataLoader | |||||
from .binary_evaluator import BinaryEvaluator | |||||
from ..models.model import Model | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
from ..data.content_loaders.content_loader import ContentLoader | |||||
class TagBinaryEvaluator(BinaryEvaluator): | |||||
    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', tag: str, cl_type: Type['ContentLoader']):
super().__init__(model, data_loader, conf) | |||||
self._tag = tag | |||||
self._content_loader = data_loader.get_content_loader_of_interest(cl_type) | |||||
@property | |||||
def _prob_key(self) -> str: | |||||
return f'{self._tag}_positive_class_probability' | |||||
@property | |||||
def _gt_key(self) -> str: | |||||
return self._tag | |||||
@property | |||||
def _loss_key(self) -> str: | |||||
return f'{self._tag}_loss' | |||||
def _get_current_batch_gt(self) -> np.ndarray: | |||||
return self._content_loader.get_placeholder_name_to_fill_function_dict()[self._gt_key]( | |||||
self.data_loader.get_current_batch_sample_indices(), None | |||||
) |
""" The evaluator """ | |||||
from abc import abstractmethod, ABC | |||||
from typing import List, TYPE_CHECKING, Dict | |||||
import torch | |||||
from tqdm import tqdm | |||||
from ..data.dataflow import DataFlow | |||||
if TYPE_CHECKING: | |||||
from ..models.model import Model | |||||
from ..data.data_loader import DataLoader | |||||
from ..configs.base_config import BaseConfig | |||||
class Evaluator(ABC): | |||||
""" The evaluator """ | |||||
def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'): | |||||
""" Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to | |||||
iterate over them. """ | |||||
self.model = model | |||||
self.data_loader = data_loader | |||||
self.conf = conf | |||||
self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device) | |||||
def evaluate(self, max_iters: int = None): | |||||
""" CAUTION: Resets the data_loader, | |||||
Iterates over the samples (as much as and how data_loader specifies), | |||||
calculates the overall evaluation requirements and prints them. | |||||
Title specified the title of the string for printing evaluation metrics. | |||||
classes_to_use specifies the labels of the samples to do the function on them, | |||||
None means all""" | |||||
# setting in dataflow | |||||
max_iters = float('inf') if max_iters is None else max_iters | |||||
with torch.no_grad(): | |||||
# Running the model in evaluation mode | |||||
self.model.eval() | |||||
# initiating variables for running evaluations | |||||
self.reset() | |||||
with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar: | |||||
for iters, model_output in pbar: | |||||
self.update_summaries_based_on_model_output(model_output) | |||||
del model_output | |||||
pbar.set_description(self.get_evaluation_metrics()) | |||||
if iters + 1 >= max_iters: | |||||
break | |||||
@abstractmethod | |||||
def update_summaries_based_on_model_output( | |||||
self, model_output: Dict[str, torch.Tensor]) -> None: | |||||
""" Updates the inner variables responsible of keeping a summary over data | |||||
based on the new outputs of the model, so when needed evaluation metrics | |||||
can be calculated based on these summaries. Mostly used in train and validation phase | |||||
or eval phase in which we only need evaluation metrics.""" | |||||
@abstractmethod | |||||
def reset(self): | |||||
""" Resets the held information for a new evaluation round!""" | |||||
@abstractmethod | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
""" Returns a list of strings containing the titles of the evaluation metrics""" | |||||
@abstractmethod | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
""" Returns a list of values for the calculated evaluation metrics, | |||||
converted to string with the desired format!""" | |||||
    def get_evaluation_metrics(self) -> str:
return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in | |||||
zip( | |||||
self.get_titles_of_evaluation_metrics(), | |||||
self.get_values_of_evaluation_metrics())]) | |||||
def print_evaluation_metrics(self, title: str) -> None: | |||||
""" Prints the values of the evaluation metrics""" | |||||
print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True) |
from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict | |||||
from collections import OrderedDict | |||||
import torch | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
class LossEvaluator(Evaluator): | |||||
def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||||
super(LossEvaluator, self).__init__(model, data_loader, conf) | |||||
self.phase = conf.phase_type | |||||
# for summarized information | |||||
self.avg_loss: float = 0 | |||||
self._avg_other_losses: OrdDict[str, float] = OrderedDict() | |||||
self.n_received_samples: int = 0 | |||||
self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict() | |||||
def reset(self): | |||||
self.avg_loss = 0 | |||||
self._avg_other_losses = OrderedDict() | |||||
self.n_received_samples = 0 | |||||
self._n_received_samples_other_losses = OrderedDict() | |||||
    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        n_batch = self.data_loader.get_current_batch_size()
        new_n = self.n_received_samples + n_batch
        self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
            model_output.get('loss', 0.0) * (float(n_batch) / new_n)
        # keep the running count up to date; without this, avg_loss would
        # always equal the latest batch's loss
        self.n_received_samples = new_n
        for kw in model_output.keys():
            if 'loss' in kw and kw != 'loss':
                old_avg = self._avg_other_losses.get(kw, 0.0)
                old_n = self._n_received_samples_other_losses.get(kw, 0)
                kw_new_n = old_n + n_batch
                self._avg_other_losses[kw] = \
                    old_avg * (float(old_n) / kw_new_n) + \
                    model_output[kw].detach().cpu() * (float(n_batch) / kw_new_n)
                self._n_received_samples_other_losses[kw] = kw_new_n
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
return ['Loss'] + list(self._avg_other_losses.keys()) | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()] |
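# Illustrative, self-contained check of the streaming-average update used in
# update_summaries_based_on_model_output above: avg_new = avg_old * (n_old / n_new)
# + batch_loss * (n_batch / n_new) reproduces the sample-weighted overall mean.
if __name__ == '__main__':
    batches = [(0.9, 4), (0.5, 8), (0.7, 4)]  # (mean loss of batch, batch size)
    avg, n = 0.0, 0
    for batch_loss, n_batch in batches:
        new_n = n + n_batch
        avg = avg * (n / new_n) + batch_loss * (n_batch / new_n)
        n = new_n
    expected = sum(l * b for l, b in batches) / sum(b for _, b in batches)
    assert abs(avg - expected) < 1e-9
    print(avg)  # 0.65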
from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional | |||||
from functools import partial | |||||
import numpy as np | |||||
import torch | |||||
from sklearn.metrics import confusion_matrix | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
def _precision_score(cm: np.ndarray):
    numerator = np.diag(cm)
    denominator = cm.sum(axis=0)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
def _recall_score(cm: np.ndarray):
    numerator = np.diag(cm)
    denominator = cm.sum(axis=1)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
def _accuracy_score(cm: np.ndarray): | |||||
return np.diag(cm).sum() / cm.sum() | |||||
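# Illustrative arithmetic for the helpers above (rows are ground truth, columns
# are predictions, matching the sklearn confusion_matrix convention used below):
if __name__ == '__main__':
    _demo_cm = np.array([[8., 2.],   # class 0: 8 correct, 2 predicted as class 1
                         [1., 9.]])  # class 1: 1 predicted as class 0, 9 correct
    print(_accuracy_score(_demo_cm))   # (8 + 9) / 20      = 0.85
    print(_precision_score(_demo_cm))  # mean(8/9, 9/11)   ~ 0.8535
    print(_recall_score(_demo_cm))     # mean(8/10, 9/10)  = 0.85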
class MulticlassEvaluator(Evaluator): | |||||
@classmethod | |||||
def standard_creator(cls, class_probability_key: str = 'categorical_probability', | |||||
include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']: | |||||
return partial(MulticlassEvaluator, class_probability_key=class_probability_key, include_top5=include_top5) | |||||
def __init__(self, model: Model, | |||||
data_loader: DataLoader, | |||||
conf: 'BaseConfig', | |||||
class_probability_key: str = 'categorical_probability', | |||||
include_top5: bool = False): | |||||
super().__init__(model, data_loader, conf) | |||||
# for summarized information | |||||
self._avg_loss: float = 0 | |||||
self._n_received_samples: int = 0 | |||||
self._trues = {} | |||||
self._falses = {} | |||||
self._top5_trues: int = 0 | |||||
self._top5_falses: int = 0 | |||||
self._num_iters = 0 | |||||
self._predictions: Dict[int, np.ndarray] = {} | |||||
self._prediction_weights: Dict[int, np.ndarray] = {} | |||||
self._sample_details: Dict[str, Dict[str, Any]] = {} | |||||
self._cm: Optional[np.ndarray] = None | |||||
self._c = None | |||||
self._class_probability_key = class_probability_key | |||||
self._include_top5 = include_top5 | |||||
def reset(self): | |||||
# for summarized information | |||||
self._avg_loss: float = 0 | |||||
self._n_received_samples: int = 0 | |||||
self._trues = {} | |||||
self._falses = {} | |||||
self._top5_trues: int = 0 | |||||
self._top5_falses: int = 0 | |||||
self._num_iters = 0 | |||||
self._cm: Optional[np.ndarray] = None | |||||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||||
assert self._class_probability_key in model_output, \ | |||||
f"model's output dictionary must contain {self._class_probability_key}" | |||||
prediction: np.ndarray = model_output[self._class_probability_key]\ | |||||
.cpu().numpy() # B C | |||||
c = prediction.shape[1] | |||||
ground_truth = self.data_loader.get_current_batch_samples_labels() # B | |||||
if self._include_top5: | |||||
top5_p = prediction.argsort(axis=1)[:, -5:] | |||||
trues = (top5_p == ground_truth[:, None]).any(axis=1).sum() | |||||
self._top5_trues += trues | |||||
self._top5_falses += len(ground_truth) - trues | |||||
prediction = prediction.argmax(axis=1) # B | |||||
for i in range(c): | |||||
nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i))) | |||||
nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i))) | |||||
if i not in self._trues: | |||||
self._trues[i] = 0 | |||||
self._falses[i] = 0 | |||||
self._trues[i] += nt | |||||
self._falses[i] += nf | |||||
self._cm = (0 if self._cm is None else self._cm)\ | |||||
+ confusion_matrix(ground_truth, prediction, | |||||
labels=np.arange(c)).astype(float) | |||||
self._num_iters += 1 | |||||
new_n = self._n_received_samples + len(ground_truth) | |||||
self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \ | |||||
model_output.get('loss', 0.0) * (float(len(ground_truth)) / new_n) | |||||
self._n_received_samples = new_n | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \ | |||||
(['Top5Acc'] if self._include_top5 else []) | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
accuracy = _accuracy_score(self._cm) * 100 | |||||
precision = _precision_score(self._cm) * 100 | |||||
recall = _recall_score(self._cm) * 100 | |||||
top5_acc = self._top5_trues * 100.0 / (self._top5_trues + self._top5_falses)\ | |||||
if self._include_top5 else None | |||||
return [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}'] + \ | |||||
([f'{top5_acc:8.4f}'] if self._include_top5 else []) |
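# Illustrative, self-contained check of the top-5 update above:
# argsort(axis=1)[:, -5:] keeps the indices of the five highest-probability
# classes per row, and a row counts as a "true" when its ground-truth label
# appears among them.
if __name__ == '__main__':
    _rng = np.random.default_rng(0)
    _probs = _rng.random((4, 10))            # B=4 samples, C=10 classes
    _gt = np.array([3, 7, 1, 9])             # B
    _top5 = _probs.argsort(axis=1)[:, -5:]   # B 5
    _trues = (_top5 == _gt[:, None]).any(axis=1).sum()
    print(f'top-5 accuracy: {_trues / len(_gt):.2f}')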
from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type | |||||
from collections import OrderedDict as ODict | |||||
import torch | |||||
from ..data.data_loader import DataLoader | |||||
from ..models.model import Model | |||||
from ..model_evaluation.evaluator import Evaluator | |||||
if TYPE_CHECKING: | |||||
from ..configs.base_config import BaseConfig | |||||
# If you want your model to be analyzed from different points of view by different evaluators, you can use this holder! | |||||
class MultiEvaluatorEvaluator(Evaluator): | |||||
def __init__( | |||||
self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', | |||||
evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]): | |||||
""" | |||||
evaluators_cls_by_name: The key is an arbitrary name to call the instance, | |||||
cls is the constructor | |||||
""" | |||||
super(MultiEvaluatorEvaluator, self).__init__(model, data_loader, conf) | |||||
# Making all evaluators! | |||||
self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict() | |||||
for eval_name, eval_cls in evaluators_cls_by_name.items(): | |||||
self._evaluators_by_name[eval_name] = eval_cls(model, data_loader, conf) | |||||
def reset(self): | |||||
for evaluator in self._evaluators_by_name.values(): | |||||
evaluator.reset() | |||||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||||
titles = [] | |||||
for eval_name, evaluator in self._evaluators_by_name.items(): | |||||
titles += [eval_name + '_' + t for t in evaluator.get_titles_of_evaluation_metrics()] | |||||
return titles | |||||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||||
metrics = [] | |||||
for evaluator in self._evaluators_by_name.values(): | |||||
metrics += evaluator.get_values_of_evaluation_metrics() | |||||
return metrics | |||||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||||
for evaluator in self._evaluators_by_name.values(): | |||||
evaluator.update_summaries_based_on_model_output(model_output) | |||||
@staticmethod | |||||
def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]): | |||||
""" For making a constructor, consistent with the known Evaluator""" | |||||
def typical_maker(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator: | |||||
return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name) | |||||
return typical_maker | |||||
def print_evaluation_metrics(self, title: str) -> None: | |||||
""" For more readable printing! """ | |||||
print(f'{title}:') | |||||
for e_name, e_obj in self._evaluators_by_name.items(): | |||||
e_obj.print_evaluation_metrics('\t' + e_name) |
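# Hypothetical usage sketch: composing a loss evaluator and a multiclass
# evaluator under one holder. `model`, `data_loader` and `conf` come from the
# surrounding training setup, so the lines stay commented out here; the maker
# returned below matches the (model, data_loader, conf) constructor signature
# the rest of the framework expects.
#
# from collections import OrderedDict as ODict
# maker = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(
#     ODict([
#         ('loss', LossEvaluator),
#         ('cls', MulticlassEvaluator.standard_creator(include_top5=True)),
#     ]))
# evaluator = maker(model, data_loader, conf)
# evaluator.evaluate(max_iters=100)
# evaluator.print_evaluation_metrics('validation')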
import typing | |||||
from typing import Dict, Iterable | |||||
from collections import OrderedDict | |||||
import torch | |||||
from torch.nn import functional as F | |||||
import torchvision | |||||
from ..lap_inception import LAPInception | |||||
class CelebALAPInception(LAPInception): | |||||
    def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory) | |||||
self._tag = tag | |||||
@property | |||||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||||
return OrderedDict({ | |||||
f'{self._tag}': True, | |||||
}) | |||||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
# x: B 3 224 224 | |||||
if self.training: | |||||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||||
out, aux = out.flatten(), aux.flatten() # B | |||||
else: | |||||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||||
aux = None | |||||
output = dict() | |||||
output['positive_class_probability'] = out | |||||
if f'{self._tag}' not in gts: | |||||
return output | |||||
gt = gts[f'{self._tag}'] | |||||
r""" Class weighted loss """ | |||||
loss = torch.mean(torch.stack(tuple( | |||||
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique() | |||||
))) | |||||
output['loss'] = loss | |||||
return output | |||||
""" INTERPRETATION """ | |||||
@property | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
""" | |||||
:return: input module for interpretation | |||||
""" | |||||
return ['x'] | |||||
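# Illustrative, self-contained check of the class-weighted loss in forward():
# BCE is computed separately on the positive and the negative subset and the
# two values are averaged, so each class contributes equally regardless of its
# frequency in the batch.
if __name__ == '__main__':
    _out = torch.tensor([0.9, 0.8, 0.2, 0.7])  # predicted P(positive)
    _gt = torch.tensor([1., 1., 0., 0.])
    _loss = torch.mean(torch.stack(tuple(
        F.binary_cross_entropy(_out[_gt == i], _gt[_gt == i]) for i in _gt.unique()
    )))
    print(_loss.item())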
from collections import OrderedDict | |||||
from typing import Dict, List | |||||
import typing | |||||
import torch | |||||
from ...modules import LAP, AdaptiveLAP | |||||
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||||
def pool_factory(channels, sigmoid_scale=1.0): | |||||
return LAP(channels, | |||||
sigmoid_scale=sigmoid_scale, | |||||
n_attention_heads=3, | |||||
discriminative_attention=True) | |||||
def adaptive_pool_factory(channels, sigmoid_scale=1.0): | |||||
return AdaptiveLAP(channels, | |||||
sigmoid_scale=sigmoid_scale, | |||||
n_attention_heads=3, | |||||
discriminative_attention=True) | |||||
class CelebALAPResNet18(LapResNet): | |||||
def __init__(self, tag: str, | |||||
pool_factory: PoolFactory = pool_factory, | |||||
adaptive_factory: PoolFactory = adaptive_pool_factory, | |||||
sigmoid_scale: float = 1.0): | |||||
super().__init__(LapBasicBlock, [2, 2, 2, 2], | |||||
pool_factory=pool_factory, | |||||
adaptive_factory=adaptive_factory, | |||||
sigmoid_scale=sigmoid_scale, | |||||
binary=True) | |||||
self._tag = tag | |||||
@property | |||||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||||
return OrderedDict({ | |||||
f'{self._tag}': True, | |||||
}) | |||||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
y = gts[f'{self._tag}'] | |||||
return super().forward(x, y) | |||||
@property | |||||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||||
res = super().attention_layers | |||||
res.pop('4_overall') | |||||
return res |
from collections import OrderedDict | |||||
from typing import Dict | |||||
import typing | |||||
import torch | |||||
from ..tv_resnet import BasicBlock, ResNet | |||||
class CelebAORGResNet18(ResNet): | |||||
def __init__(self, tag: str): | |||||
super().__init__(BasicBlock, [2, 2, 2, 2], binary=True) | |||||
self._tag = tag | |||||
@property | |||||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||||
return OrderedDict({ | |||||
f'{self._tag}': True, | |||||
}) | |||||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
y = gts[f'{self._tag}'] | |||||
return super().forward(x, y) |
import typing | |||||
from typing import Dict, List, Iterable | |||||
from collections import OrderedDict | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
import torchvision | |||||
from ..tv_inception import Inception3 | |||||
class CelebAORGInception(Inception3): | |||||
def __init__(self, tag: str, aux_weight: float): | |||||
super().__init__(aux_weight, n_classes=1) | |||||
self._tag = tag | |||||
@property | |||||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||||
return OrderedDict({ | |||||
f'{self._tag}': True, | |||||
}) | |||||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
if self.training: | |||||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||||
out, aux = out.flatten(), aux.flatten() # B | |||||
else: | |||||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||||
aux = None | |||||
output = dict() | |||||
output['positive_class_probability'] = out | |||||
if f'{self._tag}' not in gts: | |||||
return output | |||||
gt = gts[f'{self._tag}'] | |||||
r""" Class weighted loss """ | |||||
loss = torch.mean(torch.stack(tuple( | |||||
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique() | |||||
))) | |||||
output['loss'] = loss | |||||
return output | |||||
@property | |||||
def target_conv_layers(self) -> List[nn.Module]: | |||||
return [ | |||||
self.Mixed_7c.branch1x1, | |||||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||||
self.Mixed_7c.branch_pool | |||||
] | |||||
@property | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
return ['x'] | |||||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||||
return torch.stack([1 - p, p], dim=1) |
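# Small illustration of get_categorical_probabilities above: a sigmoid output
# p of shape (B,) becomes a (B, 2) tensor whose columns are
# [P(negative), P(positive)].
if __name__ == '__main__':
    _p = torch.tensor([0.1, 0.75])
    print(torch.stack([1 - _p, _p], dim=1))
    # tensor([[0.9000, 0.1000],
    #         [0.2500, 0.7500]])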
from typing import Dict, List | |||||
import torch | |||||
from torchvision import transforms | |||||
from ...modules import LAP, AdaptiveLAP, ScaledSigmoid | |||||
from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory | |||||
def pool_factory(channels, sigmoid_scale=1.0): | |||||
return LAP(channels, | |||||
hidden_channels=[1000], | |||||
sigmoid_scale=sigmoid_scale, | |||||
hidden_activation=ScaledSigmoid.get_factory(sigmoid_scale)) | |||||
class ImagenetLAPResNet50(LapResNet): | |||||
def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0): | |||||
super().__init__(LapBottleneck, [3, 4, 6, 3], | |||||
pool_factory=pool_factory, | |||||
sigmoid_scale=sigmoid_scale, | |||||
lap_positions=[4]) | |||||
self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) | |||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
x = self.normalize(x) | |||||
return super().forward(x, y) | |||||
@property | |||||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||||
return { | |||||
'2_layer4': super().attention_layers['2_layer4'], | |||||
} |
from typing import List, Optional, Callable, Any, Iterable, Dict | |||||
import torch | |||||
from torch import Tensor, nn | |||||
import torchvision | |||||
from ..modules.lap import LAP | |||||
from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||||
from ..interpreting.interpretable import CamInterpretableModel | |||||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
from ..interpreting.relcam import modules as M | |||||
class InceptionB(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
pool_factory, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super(InceptionB, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
#self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) | |||||
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=1, padding=1) | |||||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) | |||||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) | |||||
#self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) | |||||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=1, padding=1) | |||||
self.pool = pool_factory(384 + 96 + in_channels) | |||||
self.cat = M.Cat() | |||||
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch3x3 = self.branch3x3(x) | |||||
branch3x3dbl = self.branch3x3dbl_1(x) | |||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||||
outputs = [branch3x3, branch3x3dbl, x] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.pool(self.cat(outputs, 1)) | |||||
@RPProvider.register(InceptionB) | |||||
class InceptionBRelProp(RelProp[InceptionB]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
R = RPProvider.get(self.module.pool)(R, alpha=alpha) | |||||
branch3x3, branch3x3dbl, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha) | |||||
return x1 + x2 + x3 | |||||
class InceptionD(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
pool_factory, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super(InceptionD, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||||
#self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) | |||||
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=1, padding=1) | |||||
#self.branch3x3_2_stride = get_pooler(pooler_cls, 320, 2, pooler_hidden_layers) | |||||
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||||
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) | |||||
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) | |||||
#self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) | |||||
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=1, padding=1) | |||||
#self.branch7x7x3_4_stride = get_pooler(pooler_cls, 192, 2, pooler_hidden_layers) | |||||
self.pool = pool_factory(320 + 192 + in_channels) | |||||
self.cat = M.Cat() | |||||
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch3x3 = self.branch3x3_1(x) | |||||
branch3x3 = self.branch3x3_2(branch3x3) | |||||
branch7x7x3 = self.branch7x7x3_1(x) | |||||
branch7x7x3 = self.branch7x7x3_2(branch7x7x3) | |||||
branch7x7x3 = self.branch7x7x3_3(branch7x7x3) | |||||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3) | |||||
outputs = [branch3x3, branch7x7x3, x] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.pool(self.cat(outputs, 1)) | |||||
@RPProvider.register(InceptionD) | |||||
class InceptionDRelProp(RelProp[InceptionD]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
R = RPProvider.get(self.module.pool)(R, alpha=alpha) | |||||
branch3x3, branch7x7x3, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha) | |||||
branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||||
return x1 + x2 + x3 | |||||
def inception_b_maker(pool_factory): | |||||
return ( | |||||
lambda in_channels, conv_block=None: | |||||
InceptionB(in_channels, pool_factory, conv_block)) | |||||
def inception_d_maker(pool_factory): | |||||
return ( | |||||
lambda in_channels, conv_block=None: | |||||
InceptionD(in_channels, pool_factory, conv_block)) | |||||
class InceptionA(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
pool_features: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None, | |||||
) -> None: | |||||
super(InceptionA, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) | |||||
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) | |||||
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) | |||||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) | |||||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) | |||||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) | |||||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||||
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) | |||||
self.cat = M.Cat() | |||||
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch1x1 = self.branch1x1(x) | |||||
branch5x5 = self.branch5x5_1(x) | |||||
branch5x5 = self.branch5x5_2(branch5x5) | |||||
branch3x3dbl = self.branch3x3dbl_1(x) | |||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||||
branch_pool = self.avg_pool(x) | |||||
branch_pool = self.branch_pool(branch_pool) | |||||
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.cat(outputs, 1) | |||||
@RPProvider.register(InceptionA) | |||||
class InceptionARelProp(RelProp[InceptionA]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
branch1x1, branch5x5, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||||
branch5x5 = RPProvider.get(self.module.branch5x5_2)(branch5x5, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch5x5_1)(branch5x5, alpha=alpha) | |||||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||||
return x1 + x2 + x3 + x4 | |||||
class InceptionC(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
channels_7x7: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super(InceptionC, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) | |||||
c7 = channels_7x7 | |||||
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) | |||||
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) | |||||
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) | |||||
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) | |||||
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) | |||||
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) | |||||
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) | |||||
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) | |||||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1) | |||||
self.cat = M.Cat() | |||||
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch1x1 = self.branch1x1(x) | |||||
branch7x7 = self.branch7x7_1(x) | |||||
branch7x7 = self.branch7x7_2(branch7x7) | |||||
branch7x7 = self.branch7x7_3(branch7x7) | |||||
branch7x7dbl = self.branch7x7dbl_1(x) | |||||
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) | |||||
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) | |||||
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) | |||||
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) | |||||
branch_pool = self.avg_pool(x) | |||||
branch_pool = self.branch_pool(branch_pool) | |||||
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.cat(outputs, 1) | |||||
@RPProvider.register(InceptionC) | |||||
class InceptionCRelProp(RelProp[InceptionC]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
branch1x1, branch7x7, branch7x7dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_5)(branch7x7dbl, alpha=alpha) | |||||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_4)(branch7x7dbl, alpha=alpha) | |||||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_3)(branch7x7dbl, alpha=alpha) | |||||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_2)(branch7x7dbl, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch7x7dbl_1)(branch7x7dbl, alpha=alpha) | |||||
branch7x7 = RPProvider.get(self.module.branch7x7_3)(branch7x7, alpha=alpha) | |||||
branch7x7 = RPProvider.get(self.module.branch7x7_2)(branch7x7, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch7x7_1)(branch7x7, alpha=alpha) | |||||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||||
return x1 + x2 + x3 + x4 | |||||
class InceptionE(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super(InceptionE, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) | |||||
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) | |||||
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) | |||||
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) | |||||
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) | |||||
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) | |||||
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) | |||||
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) | |||||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1) | |||||
self.cat1 = M.Cat() | |||||
self.cat2 = M.Cat() | |||||
self.cat3 = M.Cat() | |||||
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch1x1 = self.branch1x1(x) | |||||
branch3x3 = self.branch3x3_1(x) | |||||
branch3x3 = [ | |||||
self.branch3x3_2a(branch3x3), | |||||
self.branch3x3_2b(branch3x3), | |||||
] | |||||
branch3x3 = self.cat1(branch3x3, 1) | |||||
branch3x3dbl = self.branch3x3dbl_1(x) | |||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
branch3x3dbl = [ | |||||
self.branch3x3dbl_3a(branch3x3dbl), | |||||
self.branch3x3dbl_3b(branch3x3dbl), | |||||
] | |||||
branch3x3dbl = self.cat2(branch3x3dbl, 1) | |||||
branch_pool = self.avg_pool(x) | |||||
branch_pool = self.branch_pool(branch_pool) | |||||
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.cat3(outputs, 1) | |||||
@RPProvider.register(InceptionE) | |||||
class InceptionERelProp(RelProp[InceptionE]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
branch1x1, branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat3)(R, alpha=alpha) | |||||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||||
branch3x3dbl_3a, branch3x3dbl_3b = RPProvider.get(self.module.cat2)(branch3x3dbl, alpha=alpha) | |||||
branch3x3dbl_1 = RPProvider.get(self.module.branch3x3dbl_3a)(branch3x3dbl_3a, alpha=alpha) | |||||
branch3x3dbl_2 = RPProvider.get(self.module.branch3x3dbl_3b)(branch3x3dbl_3b, alpha=alpha) | |||||
branch3x3dbl = branch3x3dbl_1 + branch3x3dbl_2 | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||||
branch3x3_2a, branch3x3_2b = RPProvider.get(self.module.cat1)(branch3x3, alpha=alpha) | |||||
branch3x3_1 = RPProvider.get(self.module.branch3x3_2a)(branch3x3_2a, alpha=alpha) | |||||
branch3x3_2 = RPProvider.get(self.module.branch3x3_2b)(branch3x3_2b, alpha=alpha) | |||||
branch3x3 = branch3x3_1 + branch3x3_2 | |||||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||||
return x1 + x2 + x3 + x4 | |||||
class InceptionAux(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
num_classes: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super(InceptionAux, self).__init__() | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3) | |||||
self.conv0 = conv_block(in_channels, 128, kernel_size=1) | |||||
self.conv1 = conv_block(128, 768, kernel_size=5, padding=2) | |||||
self.conv1.stddev = 0.01 # type: ignore[assignment] | |||||
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) | |||||
self.flatten = nn.Flatten() | |||||
self.fc = nn.Linear(768, num_classes) | |||||
self.fc.stddev = 0.001 # type: ignore[assignment] | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
# N x 768 x 17 x 17 | |||||
x = self.avgpool1(x) | |||||
# N x 768 x 5 x 5 | |||||
x = self.conv0(x) | |||||
# N x 128 x 5 x 5 | |||||
x = self.conv1(x) | |||||
# N x 768 x 1 x 1 | |||||
# Adaptive average pooling | |||||
x = self.avgpool2(x) | |||||
# N x 768 x 1 x 1 | |||||
x = self.flatten(x) | |||||
# N x 768 | |||||
x = self.fc(x) | |||||
# N x 1000 | |||||
return x | |||||
@RPProvider.register(InceptionAux) | |||||
class InceptionAuxRelProp(RelProp[InceptionAux]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.flatten)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.avgpool2)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.conv1)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.conv0)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.avgpool1)(R, alpha=alpha) | |||||
return R | |||||
class BasicConv2d(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
out_channels: int, | |||||
**kwargs: Any | |||||
) -> None: | |||||
super(BasicConv2d, self).__init__() | |||||
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) | |||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001) | |||||
self.act = torch.nn.ReLU() | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
x = self.conv(x) | |||||
x = self.bn(x) | |||||
return self.act(x) | |||||
@RPProvider.register(BasicConv2d) | |||||
class BasicConv2dRelProp(RelProp[BasicConv2d]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
R = RPProvider.get(self.module.act)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.bn)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.conv)(R, alpha=alpha) | |||||
return R | |||||
class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3): | |||||
def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory): | |||||
torchvision.models.Inception3.__init__( | |||||
self, | |||||
transform_input=False, init_weights=False, | |||||
inception_blocks=[ | |||||
BasicConv2d, InceptionA, inception_b_maker(pool_factory), | |||||
InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux | |||||
]) | |||||
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1) | |||||
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) | |||||
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) | |||||
self.maxpool1 = pool_factory(64) | |||||
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1,) | |||||
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1) | |||||
self.maxpool2 = pool_factory(192) | |||||
if adaptive_pool_factory is not None: | |||||
self.avgpool = nn.Sequential( | |||||
adaptive_pool_factory(2048)) | |||||
self.fc = nn.Sequential( | |||||
nn.Linear(2048, 1), | |||||
nn.Sigmoid() | |||||
) | |||||
self.AuxLogits.fc = nn.Sequential( | |||||
nn.Linear(768, 1), | |||||
nn.Sigmoid() | |||||
) | |||||
self.aux_weight = aux_weight | |||||
@property | |||||
def target_conv_layers(self) -> List[nn.Module]: | |||||
return [ | |||||
self.Mixed_7c.branch1x1, | |||||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||||
self.Mixed_7c.branch_pool | |||||
] | |||||
@property | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
return ['x'] | |||||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||||
return torch.stack([1 - p, p], dim=1) | |||||
@property | |||||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||||
attention_groups = { | |||||
'0_layer2': [self.maxpool2], | |||||
'1_layer6': [self.Mixed_6a.pool], | |||||
'2_layer7': [self.Mixed_7a.pool], | |||||
'3_avgpool': [self.avgpool[0]], | |||||
'4_all': [ | |||||
self.maxpool2, | |||||
self.Mixed_6a.pool, | |||||
self.Mixed_7a.pool, | |||||
self.avgpool[0]], | |||||
} | |||||
return attention_groups | |||||
@RPProvider.register(LAPInception) | |||||
class Inception3Mo4RelProp(RelProp[LAPInception]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
if RPProvider.get(self.module.fc).Y.shape[1] == 1: | |||||
R = R[:, -1:] | |||||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048 | |||||
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1 | |||||
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1 | |||||
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35 | |||||
R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71 | |||||
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73 | |||||
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73 | |||||
R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147 | |||||
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147 | |||||
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149 | |||||
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299 | |||||
return R |
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar | |||||
from typing_extensions import Protocol | |||||
import warnings | |||||
import torch | |||||
from torch import nn | |||||
from ..modules import LAP, AdaptiveLAP | |||||
from .tv_resnet import BasicBlock, Bottleneck, ResNet | |||||
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple | |||||
from ..interpreting.relcam import modules as M | |||||
from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||||
class PoolFactory(Protocol): | |||||
def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ... | |||||
lap_factory: PoolFactory = \ | |||||
lambda channels, sigmoid_scale=1.0: \ | |||||
LAP(channels, sigmoid_scale=sigmoid_scale) | |||||
adaptive_lap_factory: PoolFactory = \ | |||||
lambda channels, sigmoid_scale=1.0: \ | |||||
AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale) | |||||
class LapBasicBlock(BasicBlock): | |||||
def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1, | |||||
base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0): | |||||
super().__init__(inplanes, planes, stride=stride, downsample=downsample, | |||||
groups=groups, base_width=base_width, dilation=dilation, | |||||
norm_layer=norm_layer) | |||||
self.pool = None | |||||
if stride != 1: | |||||
assert downsample is not None | |||||
self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False) | |||||
self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False) | |||||
self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale) | |||||
self.relu3 = nn.ReLU() | |||||
self.cat = M.Cat() | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
out = self.conv1(x) # B P H W | |||||
out = self.bn1(out) # B P H W | |||||
out = self.relu(out) # B P H W | |||||
if self.downsample is not None: | |||||
x = self.downsample(x) # B P H W | |||||
x = self.relu2(x) | |||||
if self.pool is not None: | |||||
poolin = self.cat([out, x], 1) # B 2P H W | |||||
poolout: torch.Tensor = self.pool(poolin) # B 2P H/S W/S | |||||
out, x = poolout.chunk(2, dim=1) # B P H/S W/S | |||||
out = self.conv2(out) # B P H/S W/S | |||||
out = self.bn2(out) # B P H/S W/S | |||||
out = self.add([out, x]) # B P H/S W/S | |||||
out = self.relu3(out) # B P H/S W/S | |||||
return out | |||||
@RPProvider.register(LapBasicBlock) | |||||
class BlockRelProp(RelProp[LapBasicBlock]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
out = RPProvider.get(self.module.relu3)(R, alpha=alpha) | |||||
out, x = RPProvider.get(self.module.add)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.bn2)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.conv2)(out, alpha=alpha) | |||||
if self.module.pool is not None: | |||||
poolout = torch.cat([out, x], dim=1) | |||||
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha) | |||||
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha) | |||||
if self.module.downsample is not None: | |||||
x = RPProvider.get(self.module.relu2)(x, alpha=alpha) | |||||
x = RPProvider.get(self.module.downsample)(x, alpha=alpha) | |||||
out = RPProvider.get(self.module.relu)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.bn1)(out, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha) | |||||
return x + x1 | |||||
class LapBottleneck(Bottleneck): | |||||
def __init__(self, pool_factory: PoolFactory, | |||||
inplanes: int, | |||||
planes: int, | |||||
stride: int = 1, | |||||
downsample: Optional[nn.Module] = None, | |||||
groups: int = 1, | |||||
base_width: int = 64, | |||||
dilation: int = 1, | |||||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||||
sigmoid_scale: float = 1.0) -> None: | |||||
super().__init__(inplanes, planes, stride=stride, downsample=downsample, | |||||
groups=groups, base_width=base_width, dilation=dilation, | |||||
norm_layer=norm_layer) | |||||
self.pool = None | |||||
if stride != 1: | |||||
assert downsample is not None | |||||
width = int(planes * (base_width / 64.)) * groups | |||||
self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False) | |||||
self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False) | |||||
self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale) | |||||
self.cat = M.Cat() | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
# x # B I H W | |||||
out = self.conv1(x) # B P H W | |||||
out = self.bn1(out) # B P H W | |||||
out = self.relu(out) # B P H W | |||||
out = self.conv2(out) # B P H W | |||||
out = self.bn2(out) # B P H W | |||||
out = self.relu2(out) # B P H W | |||||
if self.downsample is not None: | |||||
x = self.downsample(x) # B 4P H W | |||||
if self.pool is not None: | |||||
poolin = self.cat([out, x], 1) # B 5P H W | |||||
poolout: torch.Tensor = self.pool(poolin) # B 5P H/S W/S | |||||
out, x = poolout.split( # B P H/S W/S | |||||
[out.shape[1], x.shape[1]], # B 4P H/S W/S | |||||
dim=1) | |||||
out = self.conv3(out) # B 4P H/S W/S | |||||
out = self.bn3(out) # B 4P H/S W/S | |||||
out = self.add([out, x]) # B P H/S W/S | |||||
out = self.relu3(out) # B P H/S W/S | |||||
return out | |||||
@RPProvider.register(LapBottleneck) | |||||
class BlockRP(RelProp[LapBottleneck]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
out = RPProvider.get(self.module.relu3)(R, alpha=alpha) | |||||
out, x = RPProvider.get(self.module.add)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.bn3)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.conv3)(out, alpha=alpha) | |||||
if self.module.pool is not None: | |||||
poolout = torch.cat([out, x], dim=1) | |||||
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha) | |||||
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha) | |||||
if self.module.downsample is not None: | |||||
x = RPProvider.get(self.module.downsample)(x, alpha=alpha) | |||||
out = RPProvider.get(self.module.relu2)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.bn2)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.conv2)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.relu)(out, alpha=alpha) | |||||
out = RPProvider.get(self.module.bn1)(out, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha) | |||||
return x + x1 | |||||
TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck) | |||||
class BlockFactory(Generic[TLapBlock]): | |||||
def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None: | |||||
self.expansion = block.expansion | |||||
self._block = block | |||||
self._pool_factory = pool_factory | |||||
self._sigmoid_scale = sigmoid_scale | |||||
def __call__(self, *args, **kwargs) -> TLapBlock: | |||||
return self._block(self._pool_factory, *args, **kwargs, sigmoid_scale=self._sigmoid_scale) | |||||
class Stride(nn.Module): | |||||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
return x[:, :, ::2, ::2] | |||||
@RPProvider.register(Stride) | |||||
class StrideRelProp(RelPropSimple[Stride]): | |||||
pass | |||||
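# Stride subsamples every other row and column, halving the spatial resolution
# with no learnable parameters; illustrative check:
if __name__ == '__main__':
    _x = torch.zeros(1, 3, 8, 8)
    print(Stride()(_x).shape)  # torch.Size([1, 3, 4, 4])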
class LapResNet(ResNet, AttentionInterpretableModel): | |||||
def __init__(self, | |||||
block: Type[TLapBlock], | |||||
layers: List[int], | |||||
pool_factory: PoolFactory = lap_factory, | |||||
lap_positions: List[int] = [2, 3, 4], | |||||
adaptive_factory: PoolFactory = None, | |||||
sigmoid_scale: float = 1.0, | |||||
binary: bool = False): | |||||
super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary) | |||||
if adaptive_factory is not None: | |||||
self.avgpool = nn.Sequential( | |||||
adaptive_factory(512, sigmoid_scale=sigmoid_scale), | |||||
) | |||||
for i in range(2, 5): | |||||
if i not in lap_positions: | |||||
warnings.warn(f'Putting stride on layer {i}') | |||||
getattr(self, 'layer{}'.format(i))[0].pool = Stride() | |||||
@property | |||||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||||
""" | |||||
List of attention groups | |||||
""" | |||||
attention_groups = { | |||||
'0_layer2': [self.layer2[0].pool], | |||||
'1_layer3': [self.layer3[0].pool], | |||||
'2_layer4': [self.layer4[0].pool], | |||||
'3_avgpool': [self.avgpool[0]], | |||||
'4_overall': [ | |||||
self.layer2[0].pool, | |||||
self.layer3[0].pool, | |||||
self.layer4[0].pool, | |||||
self.avgpool[0], | |||||
] | |||||
} | |||||
assert all( | |||||
all( | |||||
isinstance(layer, LAP) | |||||
for layer in attention_group) | |||||
for attention_group in attention_groups.values()), \ | |||||
"Only LAP is supported for this interpretation method" | |||||
return attention_groups | |||||
def lap_resnet18(pool_factory: PoolFactory = lap_factory, | |||||
adaptive_factory: PoolFactory = None, | |||||
sigmoid_scale: float = 1.0, | |||||
lap_positions: List[int] = [2, 3, 4], | |||||
binary: bool = False) -> LapResNet: | |||||
"""Constructs a LAP-ResNet-18 model. | |||||
""" | |||||
return LapResNet(LapBasicBlock, [2, 2, 2, 2], pool_factory=pool_factory, | |||||
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale, | |||||
lap_positions=lap_positions, binary=binary) | |||||
def lap_resnet50(pool_factory: PoolFactory = lap_factory, | |||||
adaptive_factory: PoolFactory = None, | |||||
sigmoid_scale: float = 1.0, | |||||
lap_positions: List[int] = [2, 3, 4], | |||||
binary: bool = False) -> LapResNet: | |||||
"""Constructs a LAP-ResNet-50 model. | |||||
""" | |||||
return LapResNet(LapBottleneck, [3, 4, 6, 3], pool_factory=pool_factory, | |||||
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale, | |||||
lap_positions=lap_positions, binary=binary) |
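# Hypothetical usage sketch: both constructors accept the same knobs. The call
# below assumes the project's modified tv_resnet forward signature (x, y)
# returning a dict, as used by the subclasses elsewhere in this repository.
#
# model = lap_resnet18(binary=True)
# out = model(torch.randn(2, 3, 224, 224), None)
# print(out['positive_class_probability'].shape)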
""" | |||||
The Model base class | |||||
""" | |||||
from collections import OrderedDict | |||||
import inspect | |||||
import typing | |||||
from typing import Any, Dict, List, Tuple | |||||
import re | |||||
from abc import ABC | |||||
import torch | |||||
from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d | |||||
ModelIO = Dict[str, torch.Tensor] | |||||
class Model(torch.nn.Module, ABC): | |||||
""" | |||||
The Model base class | |||||
""" | |||||
_frozen_bns_list: List[torch.nn.Module] = [] | |||||
    def init_weights_from_other_model(self, pretrained_model_dir=None):
        """ Initializes the weights from the checkpoint saved at the given path.
        By default, every parameter whose name and shape match is copied over;
        subclasses can override this for custom loading logic. """
if pretrained_model_dir is None: | |||||
print('The model was not preinitialized.', flush=True) | |||||
return | |||||
other_model_dict = torch.load(pretrained_model_dir) | |||||
own_state = self.state_dict() | |||||
cnt = 0 | |||||
skip_cnt = 0 | |||||
for name, param in other_model_dict.items(): | |||||
if isinstance(param, torch.nn.Parameter): | |||||
# backwards compatibility for serialized parameters | |||||
param = param.data | |||||
if name in own_state and own_state[name].data.shape == param.shape: | |||||
own_state[name].copy_(param) | |||||
cnt += 1 | |||||
else: | |||||
skip_cnt += 1 | |||||
        print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, {skip_cnt} of the given set were skipped.', flush=True)
    def freeze_parameters(self, regex_list: List[str] = None) -> None:
        """
        An auxiliary function for freezing the parameters of the model by name.
        :param regex_list: all the parameters whose full names match at least one
            of the given regexes will be frozen. None means no parameter will be
            frozen.
        :return: None
        """
# One solution for batchnorms is registering a preforward hook | |||||
# to call eval everytime before running them | |||||
# But that causes a problem, we can't dynamically change the frozen batchnorms during training | |||||
# e.g. in layerwise unfreezing. | |||||
# So it's better to set a dynamic attribute that keeps their list, rather than having an init | |||||
# Init causes other problems in multi-inheritance | |||||
if regex_list is None: | |||||
            print('No parameter was frozen!', flush=True)
return | |||||
regex_list = [re.compile(the_pattern) for the_pattern in regex_list] | |||||
def check_freeze_conditions(param_name): | |||||
for the_regex in regex_list: | |||||
if the_regex.fullmatch(param_name) is not None: | |||||
return True | |||||
return False | |||||
frozen_cnt = 0 | |||||
total_cnt = 0 | |||||
frozen_modules_names = dict() | |||||
for n, p in self.named_parameters(): | |||||
total_cnt += 1 | |||||
if check_freeze_conditions(n): | |||||
p.requires_grad = False | |||||
frozen_cnt += 1 | |||||
frozen_modules_names[n[:n.rfind('.')]] = True | |||||
self._frozen_bns_list = [] | |||||
for module_name, module in self.named_modules(): | |||||
if (module_name in frozen_modules_names) and \ | |||||
(isinstance(module, BatchNorm2d) or | |||||
isinstance(module, BatchNorm1d) or | |||||
isinstance(module, BatchNorm3d)): | |||||
self._frozen_bns_list.append(module) | |||||
print('********************* NOTICE *********************') | |||||
print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt)) | |||||
print( | |||||
'%d BatchNorm layers will be frozen by being kept in the val mode, IF YOU HAVE NOT OVERRIDEN IT IN YOUR MAIN MODEL!' % len( | |||||
self._frozen_bns_list)) | |||||
print('***************************************************', flush=True) | |||||
def train(self, mode: bool = True): | |||||
super(Model, self).train(mode) | |||||
for bn in self._frozen_bns_list: | |||||
bn.train(False) | |||||
return self | |||||
@property | |||||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||||
return OrderedDict({}) | |||||
    def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]:
        """Returns the names of the keyword arguments required by the forward function,
        the default values of the ones that have them, and the additional kwargs.
        The second list may be shorter than the first; it holds the defaults of the
        trailing arguments only.
        Returns:
            Tuple[List[str], List[Any], Dict[str, Any]]: the argument names, the default
            values of the trailing arguments that have them, and the dictionary of
            additional kwargs (names mapped to their optionality)
        """
model_argspec = inspect.getfullargspec(self.forward) | |||||
model_forward_args = list(model_argspec.args) | |||||
# skipping self arg | |||||
model_forward_args = model_forward_args[1:] | |||||
args_default_values = model_argspec.defaults | |||||
if args_default_values is None: | |||||
args_default_values = [] | |||||
else: | |||||
args_default_values = list(args_default_values) | |||||
additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {} | |||||
return model_forward_args, args_default_values, additional_kwargs | |||||
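# Hypothetical usage sketch of freeze_parameters: the regexes are full-matched
# against the names from named_parameters(), so a trailing `.*` is needed to
# cover the whole dotted name. Both patterns below are illustrative.
#
# model.freeze_parameters([r'layer1\..*'])   # freeze everything under layer1
# model.freeze_parameters([r'(?!fc).*'])     # freeze all parameters except the fc head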
from typing import Dict, Optional
import torch | |||||
from torch.nn import functional as F | |||||
import torchvision | |||||
from ..lap_inception import LAPInception | |||||
class RSNALAPInception(LAPInception): | |||||
def __init__(self, aux_weight: float, pool_factory, adaptive_pool_factory): | |||||
super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory) | |||||
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||||
if self.training: | |||||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||||
out, aux = out.flatten(), aux.flatten() # B | |||||
else: | |||||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||||
aux = None | |||||
if y is not None: | |||||
main_loss = F.binary_cross_entropy(out, y) | |||||
if aux is not None: | |||||
aux_loss = F.binary_cross_entropy(aux, y) | |||||
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight) | |||||
else: | |||||
loss = main_loss | |||||
return { | |||||
'positive_class_probability': out, | |||||
'loss': loss | |||||
} | |||||
return { | |||||
'positive_class_probability': out | |||||
} |
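# Illustrative arithmetic for the blended loss above, (main + w * aux) / (1 + w):
if __name__ == '__main__':
    _aux_weight = 0.4
    _main_loss, _aux_loss = 0.50, 0.80
    print((_main_loss + _aux_weight * _aux_loss) / (1 + _aux_weight))  # ~0.5857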
from typing import Dict, Optional
import torch | |||||
from ...modules import LAP, AdaptiveLAP | |||||
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||||
def get_pool_factory(discriminative_attention=True): | |||||
def pool_factory(channels, sigmoid_scale=1.0): | |||||
return LAP(channels, | |||||
hidden_channels=[8], | |||||
sigmoid_scale=sigmoid_scale, | |||||
discriminative_attention=discriminative_attention) | |||||
return pool_factory | |||||
def get_adaptive_pool_factory(discriminative_attention=True): | |||||
def adaptive_pool_factory(channels, sigmoid_scale=1.0): | |||||
return AdaptiveLAP(channels, | |||||
sigmoid_scale=sigmoid_scale, | |||||
discriminative_attention=discriminative_attention) | |||||
return adaptive_pool_factory | |||||
class RSNALAPResNet18(LapResNet): | |||||
def __init__(self, pool_factory: PoolFactory = get_pool_factory(), | |||||
adaptive_factory: PoolFactory = get_adaptive_pool_factory(), | |||||
sigmoid_scale: float = 1.0): | |||||
super().__init__(LapBasicBlock, [2, 2, 2, 2], | |||||
pool_factory=pool_factory, | |||||
adaptive_factory=adaptive_factory, | |||||
sigmoid_scale=sigmoid_scale, | |||||
binary=True) | |||||
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||||
        return super().forward(x, y)
from typing import Dict, Optional
import torch | |||||
from torch.nn import functional as F | |||||
import torchvision | |||||
from ..tv_inception import Inception3 | |||||
class RSNAORGInception(Inception3): | |||||
def __init__(self, aux_weight: float): | |||||
super().__init__(aux_weight, n_classes=1) | |||||
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||||
if self.training: | |||||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||||
out, aux = out.flatten(), aux.flatten() # B | |||||
else: | |||||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||||
aux = None | |||||
if y is not None: | |||||
main_loss = F.binary_cross_entropy(out, y) | |||||
if aux is not None: | |||||
aux_loss = F.binary_cross_entropy(aux, y) | |||||
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight) | |||||
else: | |||||
loss = main_loss | |||||
return { | |||||
'positive_class_probability': out, | |||||
'loss': loss | |||||
} | |||||
return { | |||||
'positive_class_probability': out | |||||
        }
from typing import Dict, Optional
import torch | |||||
from ..tv_resnet import BasicBlock, ResNet | |||||
class RSNAORGResNet18(ResNet): | |||||
def __init__(self): | |||||
super().__init__(BasicBlock, [2, 2, 2, 2], binary=True) | |||||
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||||
        return super().forward(x, y)
from typing import Iterable, Optional, Callable, List | |||||
import torch | |||||
from torch import nn, Tensor | |||||
import torchvision | |||||
from ..interpreting.interpretable import CamInterpretableModel | |||||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
from ..interpreting.relcam import modules as M | |||||
from .lap_inception import ( | |||||
InceptionA, | |||||
InceptionC, | |||||
InceptionE, | |||||
BasicConv2d, | |||||
InceptionAux as BaseInceptionAux, | |||||
) | |||||
class InceptionB(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
        super().__init__()
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) | |||||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) | |||||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) | |||||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) | |||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) | |||||
self.cat = M.Cat() | |||||
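        # The concatenated output has 384 + 96 + in_channels channels (e.g. 288 -> 768 in
        # Inception v3), with spatial size halved by the stride-2 branches.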
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch3x3 = self.branch3x3(x) | |||||
branch3x3dbl = self.branch3x3dbl_1(x) | |||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||||
branch_pool = self.maxpool(x) | |||||
outputs = [branch3x3, branch3x3dbl, branch_pool] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.cat(outputs, 1) | |||||
@RPProvider.register(InceptionB) | |||||
class InceptionBRelProp(RelProp[InceptionB]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha) | |||||
return x1 + x2 + x3 | |||||
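# Relevance propagation mirrors the forward pass in reverse: relevance enters at the
# concatenation, is pushed back through each branch independently, and the per-branch
# contributions are summed at the shared input.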
class InceptionD(nn.Module): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
        super().__init__()
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||||
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) | |||||
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||||
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) | |||||
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) | |||||
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) | |||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) | |||||
self.cat = M.Cat() | |||||
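        # The concatenated output has 320 + 192 + in_channels channels (e.g. 768 -> 1280 in
        # Inception v3), again at halved spatial resolution.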
def _forward(self, x: Tensor) -> List[Tensor]: | |||||
branch3x3 = self.branch3x3_1(x) | |||||
branch3x3 = self.branch3x3_2(branch3x3) | |||||
branch7x7x3 = self.branch7x7x3_1(x) | |||||
branch7x7x3 = self.branch7x7x3_2(branch7x7x3) | |||||
branch7x7x3 = self.branch7x7x3_3(branch7x7x3) | |||||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3) | |||||
branch_pool = self.maxpool(x) | |||||
outputs = [branch3x3, branch7x7x3, branch_pool] | |||||
return outputs | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
outputs = self._forward(x) | |||||
return self.cat(outputs, 1) | |||||
@RPProvider.register(InceptionD) | |||||
class InceptionDRelProp(RelProp[InceptionD]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
branch3x3, branch7x7x3, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||||
x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha) | |||||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha) | |||||
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha) | |||||
branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha) | |||||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||||
return x1 + x2 + x3 | |||||
class InceptionAux(BaseInceptionAux): | |||||
def __init__( | |||||
self, | |||||
in_channels: int, | |||||
num_classes: int, | |||||
conv_block: Optional[Callable[..., nn.Module]] = None | |||||
) -> None: | |||||
super().__init__(in_channels, num_classes, conv_block=conv_block) | |||||
if conv_block is None: | |||||
conv_block = BasicConv2d | |||||
self.conv1 = conv_block(128, 768, kernel_size=5) | |||||
self.conv1.stddev = 0.01 # type: ignore[assignment] | |||||
class Inception3(CamInterpretableModel, torchvision.models.Inception3): | |||||
    def __init__(self, aux_weight: float, n_classes: int = 1):
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False, init_weights=False,
            inception_blocks=[
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux,
            ])
self.fc = nn.Sequential( | |||||
nn.Linear(2048, n_classes), | |||||
nn.Sigmoid() | |||||
) | |||||
self.AuxLogits.fc = nn.Sequential( | |||||
nn.Linear(768, n_classes), | |||||
nn.Sigmoid() | |||||
) | |||||
self.aux_weight = aux_weight | |||||
@property | |||||
def target_conv_layers(self) -> List[nn.Module]: | |||||
return [ | |||||
self.Mixed_7c.branch1x1, | |||||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||||
self.Mixed_7c.branch_pool | |||||
] | |||||
@property | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
return ['x'] | |||||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||||
return torch.stack([1 - p, p], dim=1) | |||||
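        # e.g. p = tensor([0.3, 0.9]) yields tensor([[0.7, 0.3], [0.1, 0.9]]); column 1 is
        # the positive class.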
@RPProvider.register(Inception3) | |||||
class Inception3Mo4RelProp(RelProp[Inception3]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
if RPProvider.get(self.module.fc).Y.shape[1] == 1: | |||||
R = R[:, -1:] | |||||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048 | |||||
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1 | |||||
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1 | |||||
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8 | |||||
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17 | |||||
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35 | |||||
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35 | |||||
R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71 | |||||
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73 | |||||
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73 | |||||
R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147 | |||||
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147 | |||||
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149 | |||||
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299 | |||||
        return R
from typing import Dict, Iterable, Type, Any, Callable, Union, List, Optional | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch import Tensor | |||||
from torch.nn import functional as F | |||||
from ..interpreting.interpretable import CamInterpretableModel | |||||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
from ..interpreting.relcam import modules as M | |||||
__all__ = [ | |||||
"ResNet", | |||||
"resnet18", | |||||
"resnet34", | |||||
"resnet50", | |||||
"resnet101", | |||||
"resnet152", | |||||
"resnext50_32x4d", | |||||
"resnext101_32x8d", | |||||
"wide_resnet50_2", | |||||
"wide_resnet101_2", | |||||
] | |||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: | |||||
"""3x3 convolution with padding""" | |||||
return nn.Conv2d( | |||||
in_planes, | |||||
out_planes, | |||||
kernel_size=3, | |||||
stride=stride, | |||||
padding=dilation, | |||||
groups=groups, | |||||
bias=False, | |||||
dilation=dilation, | |||||
) | |||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: | |||||
"""1x1 convolution""" | |||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||||
class BasicBlock(nn.Module): | |||||
expansion: int = 1 | |||||
def __init__( | |||||
self, | |||||
inplanes: int, | |||||
planes: int, | |||||
stride: int = 1, | |||||
downsample: Optional[nn.Module] = None, | |||||
groups: int = 1, | |||||
base_width: int = 64, | |||||
dilation: int = 1, | |||||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||||
) -> None: | |||||
super().__init__() | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
if groups != 1 or base_width != 64: | |||||
raise ValueError("BasicBlock only supports groups=1 and base_width=64") | |||||
if dilation > 1: | |||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||||
self.conv1 = conv3x3(inplanes, planes, stride) | |||||
self.bn1 = norm_layer(planes) | |||||
self.relu = nn.ReLU() | |||||
self.relu2 = nn.ReLU() | |||||
self.conv2 = conv3x3(planes, planes) | |||||
self.bn2 = norm_layer(planes) | |||||
self.downsample = downsample | |||||
self.stride = stride | |||||
self.add = M.Add() | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
if self.downsample is not None: | |||||
x = self.downsample(x) | |||||
out = self.add([out, x]) | |||||
out = self.relu2(out) | |||||
return out | |||||
@RPProvider.register(BasicBlock) | |||||
class BasicBlockRelProp(RelProp[BasicBlock]): | |||||
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
out = RPProvider.get(self.module.relu2)(R, alpha) | |||||
out, x = RPProvider.get(self.module.add)(out, alpha) | |||||
if self.module.downsample is not None: | |||||
x = RPProvider.get(self.module.downsample)(x, alpha) | |||||
out = RPProvider.get(self.module.bn2)(out, alpha) | |||||
out = RPProvider.get(self.module.conv2)(out, alpha) | |||||
out = RPProvider.get(self.module.relu)(out, alpha) | |||||
out = RPProvider.get(self.module.bn1)(out, alpha) | |||||
x1 = RPProvider.get(self.module.conv1)(out, alpha) | |||||
return x + x1 | |||||
class Bottleneck(nn.Module): | |||||
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution
    # (self.conv2), while the original implementation places it at the first 1x1 convolution
    # (self.conv1); see "Deep Residual Learning for Image Recognition",
    # https://arxiv.org/abs/1512.03385. This variant is also known as ResNet V1.5 and
    # improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4 | |||||
def __init__( | |||||
self, | |||||
inplanes: int, | |||||
planes: int, | |||||
stride: int = 1, | |||||
downsample: Optional[nn.Module] = None, | |||||
groups: int = 1, | |||||
base_width: int = 64, | |||||
dilation: int = 1, | |||||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||||
) -> None: | |||||
super().__init__() | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
width = int(planes * (base_width / 64.0)) * groups | |||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||||
self.conv1 = conv1x1(inplanes, width) | |||||
self.bn1 = norm_layer(width) | |||||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||||
self.bn2 = norm_layer(width) | |||||
self.conv3 = conv1x1(width, planes * self.expansion) | |||||
self.bn3 = norm_layer(planes * self.expansion) | |||||
self.relu = nn.ReLU() | |||||
self.relu2 = nn.ReLU() | |||||
self.relu3 = nn.ReLU() | |||||
self.downsample = downsample | |||||
self.stride = stride | |||||
self.add = M.Add() | |||||
def forward(self, x: Tensor) -> Tensor: | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
out = self.relu2(out) | |||||
out = self.conv3(out) | |||||
out = self.bn3(out) | |||||
if self.downsample is not None: | |||||
x = self.downsample(x) | |||||
out = self.add([out, x]) | |||||
out = self.relu3(out) | |||||
return out | |||||
@RPProvider.register(Bottleneck) | |||||
class BottleneckRelProp(RelProp[Bottleneck]): | |||||
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha)
        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha)
out = RPProvider.get(self.module.bn3)(out, alpha) | |||||
out = RPProvider.get(self.module.conv3)(out, alpha) | |||||
out = RPProvider.get(self.module.relu2)(out, alpha) | |||||
out = RPProvider.get(self.module.bn2)(out, alpha) | |||||
out = RPProvider.get(self.module.conv2)(out, alpha) | |||||
out = RPProvider.get(self.module.relu)(out, alpha) | |||||
out = RPProvider.get(self.module.bn1)(out, alpha) | |||||
x1 = RPProvider.get(self.module.conv1)(out, alpha) | |||||
return x + x1 | |||||
class ResNet(CamInterpretableModel): | |||||
def __init__( | |||||
self, | |||||
block: Type[Union[BasicBlock, Bottleneck]], | |||||
layers: List[int], | |||||
num_classes: int = 1000, | |||||
zero_init_residual: bool = False, | |||||
groups: int = 1, | |||||
width_per_group: int = 64, | |||||
replace_stride_with_dilation: Optional[List[bool]] = None, | |||||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||||
binary: bool = False, | |||||
) -> None: | |||||
super().__init__() | |||||
if norm_layer is None: | |||||
norm_layer = nn.BatchNorm2d | |||||
self._norm_layer = norm_layer | |||||
self.inplanes = 64 | |||||
self.dilation = 1 | |||||
if replace_stride_with_dilation is None: | |||||
# each element in the tuple indicates if we should replace | |||||
# the 2x2 stride with a dilated convolution instead | |||||
replace_stride_with_dilation = [False, False, False] | |||||
if len(replace_stride_with_dilation) != 3: | |||||
raise ValueError( | |||||
"replace_stride_with_dilation should be None " | |||||
f"or a 3-element tuple, got {replace_stride_with_dilation}" | |||||
) | |||||
self.groups = groups | |||||
self.base_width = width_per_group | |||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||||
self.bn1 = norm_layer(self.inplanes) | |||||
self.relu = nn.ReLU() | |||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) | |||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) | |||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) | |||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||||
self._binary = binary | |||||
if binary: | |||||
self.fc = nn.Sequential( | |||||
nn.Linear(512 * block.expansion, 1), | |||||
nn.Sigmoid(), | |||||
nn.Flatten(0) | |||||
) | |||||
else: | |||||
self.fc = nn.Linear(512 * block.expansion, num_classes) | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Conv2d): | |||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | |||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||||
nn.init.constant_(m.weight, 1) | |||||
nn.init.constant_(m.bias, 0) | |||||
# Zero-initialize the last BN in each residual branch, | |||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||||
if zero_init_residual: | |||||
for m in self.modules(): | |||||
if isinstance(m, Bottleneck): | |||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] | |||||
elif isinstance(m, BasicBlock): | |||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] | |||||
def _make_layer( | |||||
self, | |||||
block: Type[Union[BasicBlock, Bottleneck]], | |||||
planes: int, | |||||
blocks: int, | |||||
stride: int = 1, | |||||
dilate: bool = False, | |||||
) -> nn.Sequential: | |||||
norm_layer = self._norm_layer | |||||
downsample = None | |||||
previous_dilation = self.dilation | |||||
if dilate: | |||||
self.dilation *= stride | |||||
stride = 1 | |||||
if stride != 1 or self.inplanes != planes * block.expansion: | |||||
downsample = nn.Sequential( | |||||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||||
norm_layer(planes * block.expansion), | |||||
) | |||||
layers = [] | |||||
layers.append( | |||||
block( | |||||
self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer | |||||
) | |||||
) | |||||
self.inplanes = planes * block.expansion | |||||
for _ in range(1, blocks): | |||||
layers.append( | |||||
block( | |||||
self.inplanes, | |||||
planes, | |||||
groups=self.groups, | |||||
base_width=self.base_width, | |||||
dilation=self.dilation, | |||||
norm_layer=norm_layer, | |||||
) | |||||
) | |||||
return nn.Sequential(*layers) | |||||
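    # For example, ResNet-18 uses BasicBlock with layers=[2, 2, 2, 2]; only the first block
    # of layer2-layer4 gets a downsample shortcut (stride 2 and/or a channel change).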
def _forward_impl(self, x: Tensor) -> Tensor: | |||||
# See note [TorchScript super()] | |||||
x = self.conv1(x) | |||||
x = self.bn1(x) | |||||
x = self.relu(x) | |||||
x = self.maxpool(x) | |||||
x = self.layer1(x) | |||||
x = self.layer2(x) | |||||
x = self.layer3(x) | |||||
x = self.layer4(x) | |||||
x = self.avgpool(x) | |||||
x = torch.flatten(x, 1) | |||||
x = self.fc(x) | |||||
return x | |||||
    def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Dict[str, Tensor]:
        p = self._forward_impl(x)
        if self._binary:
            if y is None:
                return dict(positive_class_probability=p)
            return dict(
                positive_class_probability=p,
                loss=F.binary_cross_entropy(p, y),
            )
        if y is None:
            return dict(categorical_probability=p.softmax(dim=1))
        return dict(
            categorical_probability=p.softmax(dim=1),
            loss=F.cross_entropy(p, y.long()),
        )
@property | |||||
def target_conv_layers(self) -> List[nn.Module]: | |||||
""" | |||||
Returns: | |||||
The convolutional layers to be interpreted. The result of | |||||
the interpretation will be sum of the grad-cam of these layers | |||||
""" | |||||
return [self.layer4] | |||||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||||
""" | |||||
A method to get probabilities assigned to all the classes in the model's forward, | |||||
with shape (B, C), in which B is the batch size and C is number of classes | |||||
Args: | |||||
*inputs: Inputs to the model | |||||
**kwargs: Additional arguments | |||||
Returns: | |||||
Tensor of categorical probabilities | |||||
""" | |||||
if self._binary: | |||||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||||
return torch.stack([1 - p, p], dim=1) | |||||
return self.forward(*inputs, **kwargs)['categorical_probability'] | |||||
@property | |||||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||||
""" | |||||
Returns: | |||||
Input module for interpretation | |||||
""" | |||||
return ['x'] | |||||
@RPProvider.register(ResNet) | |||||
class ResNetRelProp(RelProp[ResNet]): | |||||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||||
if RPProvider.get(self.module.fc).Y.ndim == 1: | |||||
R = R[:, -1] | |||||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) | |||||
R = R.reshape_as(RPProvider.get(self.module.avgpool).Y) | |||||
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.layer4)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.layer3)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.layer2)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.layer1)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.maxpool)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.relu)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.bn1)(R, alpha=alpha) | |||||
R = RPProvider.get(self.module.conv1)(R, alpha=alpha) | |||||
return R | |||||
def _resnet( | |||||
arch: str, | |||||
block: Type[Union[BasicBlock, Bottleneck]], | |||||
layers: List[int], | |||||
pretrained: bool, | |||||
progress: bool, | |||||
**kwargs: Any, | |||||
) -> ResNet: | |||||
model = ResNet(block, layers, **kwargs) | |||||
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    #     model.load_state_dict(state_dict)
return model | |||||
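# NOTE: pretrained loading is left disabled above; with `binary=True` the replaced fc head
# (an nn.Sequential) would not match a stock torchvision checkpoint without remapping keys.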
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNet-18 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) | |||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNet-34 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNet-50 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNet-101 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) | |||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNet-152 model from | |||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) | |||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNeXt-50 32x4d model from | |||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
kwargs["groups"] = 32 | |||||
kwargs["width_per_group"] = 4 | |||||
return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""ResNeXt-101 32x8d model from | |||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. | |||||
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
kwargs["groups"] = 32 | |||||
kwargs["width_per_group"] = 8 | |||||
return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) | |||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""Wide ResNet-50-2 model from | |||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. | |||||
    The model is the same as ResNet except that the number of channels in the bottleneck
    is twice as large in every block. The number of channels in the outer 1x1
    convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while the corresponding block in Wide ResNet-50-2 has 2048-1024-2048.
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
kwargs["width_per_group"] = 64 * 2 | |||||
return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||||
r"""Wide ResNet-101-2 model from | |||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. | |||||
    The model is the same as ResNet except that the number of channels in the bottleneck
    is twice as large in every block. The number of channels in the outer 1x1
    convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while the corresponding block in Wide ResNet-50-2 has 2048-1024-2048.
Args: | |||||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||||
progress (bool): If True, displays a progress bar of the download to stderr | |||||
""" | |||||
kwargs["width_per_group"] = 64 * 2 | |||||
return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) |
from .gaussianmax import GaussianMax2d | |||||
from .activated_conv import ActivatedConv2d | |||||
from .scaled_sigmoid import ScaledSigmoid | |||||
from .lap import LAP | |||||
from .adaptive_lap import AdaptiveLAP