| # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos | |||||
| # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm+all,macos | |||||
| ### macOS ### | |||||
| # General | |||||
| .DS_Store | |||||
| .AppleDouble | |||||
| .LSOverride | |||||
| # Icon must end with two \r | |||||
| Icon | |||||
| # Thumbnails | |||||
| ._* | |||||
| # Files that might appear in the root of a volume | |||||
| .DocumentRevisions-V100 | |||||
| .fseventsd | |||||
| .Spotlight-V100 | |||||
| .TemporaryItems | |||||
| .Trashes | |||||
| .VolumeIcon.icns | |||||
| .com.apple.timemachine.donotpresent | |||||
| # Directories potentially created on remote AFP share | |||||
| .AppleDB | |||||
| .AppleDesktop | |||||
| Network Trash Folder | |||||
| Temporary Items | |||||
| .apdisk | |||||
| ### macOS Patch ### | |||||
| # iCloud generated files | |||||
| *.icloud | |||||
| ### PyCharm+all ### | |||||
| # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider | |||||
| # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | |||||
| # User-specific stuff | |||||
| .idea/**/workspace.xml | |||||
| .idea/**/tasks.xml | |||||
| .idea/**/usage.statistics.xml | |||||
| .idea/**/dictionaries | |||||
| .idea/**/shelf | |||||
| # AWS User-specific | |||||
| .idea/**/aws.xml | |||||
| # Generated files | |||||
| .idea/**/contentModel.xml | |||||
| # Sensitive or high-churn files | |||||
| .idea/**/dataSources/ | |||||
| .idea/**/dataSources.ids | |||||
| .idea/**/dataSources.local.xml | |||||
| .idea/**/sqlDataSources.xml | |||||
| .idea/**/dynamic.xml | |||||
| .idea/**/uiDesigner.xml | |||||
| .idea/**/dbnavigator.xml | |||||
| # Gradle | |||||
| .idea/**/gradle.xml | |||||
| .idea/**/libraries | |||||
| # Gradle and Maven with auto-import | |||||
| # When using Gradle or Maven with auto-import, you should exclude module files, | |||||
| # since they will be recreated, and may cause churn. Uncomment if using | |||||
| # auto-import. | |||||
| # .idea/artifacts | |||||
| # .idea/compiler.xml | |||||
| # .idea/jarRepositories.xml | |||||
| # .idea/modules.xml | |||||
| # .idea/*.iml | |||||
| # .idea/modules | |||||
| # *.iml | |||||
| # *.ipr | |||||
| # CMake | |||||
| cmake-build-*/ | |||||
| # Mongo Explorer plugin | |||||
| .idea/**/mongoSettings.xml | |||||
| # File-based project format | |||||
| *.iws | |||||
| # IntelliJ | |||||
| out/ | |||||
| # mpeltonen/sbt-idea plugin | |||||
| .idea_modules/ | |||||
| # JIRA plugin | |||||
| atlassian-ide-plugin.xml | |||||
| # Cursive Clojure plugin | |||||
| .idea/replstate.xml | |||||
| # SonarLint plugin | |||||
| .idea/sonarlint/ | |||||
| # Crashlytics plugin (for Android Studio and IntelliJ) | |||||
| com_crashlytics_export_strings.xml | |||||
| crashlytics.properties | |||||
| crashlytics-build.properties | |||||
| fabric.properties | |||||
| # Editor-based Rest Client | |||||
| .idea/httpRequests | |||||
| # Android studio 3.1+ serialized cache file | |||||
| .idea/caches/build_file_checksums.ser | |||||
| ### PyCharm+all Patch ### | |||||
| # Ignore everything but code style settings and run configurations | |||||
| # that are supposed to be shared within teams. | |||||
| .idea/* | |||||
| !.idea/codeStyles | |||||
| !.idea/runConfigurations | |||||
| ### Python ### | |||||
| # Byte-compiled / optimized / DLL files | |||||
| __pycache__/ | |||||
| *.py[cod] | |||||
| *$py.class | |||||
| # C extensions | |||||
| *.so | |||||
| # Distribution / packaging | |||||
| .Python | |||||
| build/ | |||||
| develop-eggs/ | |||||
| dist/ | |||||
| downloads/ | |||||
| eggs/ | |||||
| .eggs/ | |||||
| lib/ | |||||
| lib64/ | |||||
| parts/ | |||||
| sdist/ | |||||
| var/ | |||||
| wheels/ | |||||
| share/python-wheels/ | |||||
| *.egg-info/ | |||||
| .installed.cfg | |||||
| *.egg | |||||
| MANIFEST | |||||
| # PyInstaller | |||||
| # Usually these files are written by a python script from a template | |||||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | |||||
| *.manifest | |||||
| *.spec | |||||
| # Installer logs | |||||
| pip-log.txt | |||||
| pip-delete-this-directory.txt | |||||
| # Unit test / coverage reports | |||||
| htmlcov/ | |||||
| .tox/ | |||||
| .nox/ | |||||
| .coverage | |||||
| .coverage.* | |||||
| .cache | |||||
| nosetests.xml | |||||
| coverage.xml | |||||
| *.cover | |||||
| *.py,cover | |||||
| .hypothesis/ | |||||
| .pytest_cache/ | |||||
| cover/ | |||||
| # Translations | |||||
| *.mo | |||||
| *.pot | |||||
| # Django stuff: | |||||
| *.log | |||||
| local_settings.py | |||||
| db.sqlite3 | |||||
| db.sqlite3-journal | |||||
| # Flask stuff: | |||||
| instance/ | |||||
| .webassets-cache | |||||
| # Scrapy stuff: | |||||
| .scrapy | |||||
| # Sphinx documentation | |||||
| docs/_build/ | |||||
| # PyBuilder | |||||
| .pybuilder/ | |||||
| target/ | |||||
| # Jupyter Notebook | |||||
| .ipynb_checkpoints | |||||
| # IPython | |||||
| profile_default/ | |||||
| ipython_config.py | |||||
| # pyenv | |||||
| # For a library or package, you might want to ignore these files since the code is | |||||
| # intended to run in multiple environments; otherwise, check them in: | |||||
| # .python-version | |||||
| # pipenv | |||||
| # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | |||||
| # However, in case of collaboration, if having platform-specific dependencies or dependencies | |||||
| # having no cross-platform support, pipenv may install dependencies that don't work, or not | |||||
| # install all needed dependencies. | |||||
| #Pipfile.lock | |||||
| # poetry | |||||
| # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | |||||
| # This is especially recommended for binary packages to ensure reproducibility, and is more | |||||
| # commonly ignored for libraries. | |||||
| # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | |||||
| #poetry.lock | |||||
| # pdm | |||||
| # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | |||||
| #pdm.lock | |||||
| # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | |||||
| # in version control. | |||||
| # https://pdm.fming.dev/#use-with-ide | |||||
| .pdm.toml | |||||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | |||||
| __pypackages__/ | |||||
| # Celery stuff | |||||
| celerybeat-schedule | |||||
| celerybeat.pid | |||||
| # SageMath parsed files | |||||
| *.sage.py | |||||
| # Environments | |||||
| .env | |||||
| .venv | |||||
| env/ | |||||
| venv/ | |||||
| ENV/ | |||||
| env.bak/ | |||||
| venv.bak/ | |||||
| # Spyder project settings | |||||
| .spyderproject | |||||
| .spyproject | |||||
| # Rope project settings | |||||
| .ropeproject | |||||
| # mkdocs documentation | |||||
| /site | |||||
| # mypy | |||||
| .mypy_cache/ | |||||
| .dmypy.json | |||||
| dmypy.json | |||||
| # Pyre type checker | |||||
| .pyre/ | |||||
| # pytype static type analyzer | |||||
| .pytype/ | |||||
| # Cython debug symbols | |||||
| cython_debug/ | |||||
| # PyCharm | |||||
| # JetBrains specific template is maintained in a separate JetBrains.gitignore that can | |||||
| # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | |||||
| # and can be added to the global gitignore or merged into this file. For a more nuclear | |||||
| # option (not recommended) you can uncomment the following to ignore the entire idea folder. | |||||
| #.idea/ | |||||
| ### Python Patch ### | |||||
| # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration | |||||
| poetry.toml | |||||
| # ruff | |||||
| .ruff_cache/ | |||||
| ### VisualStudioCode ### | |||||
| .vscode/* | |||||
| !.vscode/settings.json | |||||
| !.vscode/tasks.json | |||||
| !.vscode/launch.json | |||||
| !.vscode/extensions.json | |||||
| !.vscode/*.code-snippets | |||||
| # Local History for Visual Studio Code | |||||
| .history/ | |||||
| # Built Visual Studio Code Extensions | |||||
| *.vsix | |||||
| ### VisualStudioCode Patch ### | |||||
| # Ignore all local history of files | |||||
| .history | |||||
| .ionide | |||||
| # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos |
| # LAP | |||||
| An Attention-Based Module for Faithful Interpretation and Knowledge Injection in Convolutional Neural Networks | |||||
| In this supplementary material, we provide a PDF with additional details on some of the sections of the paper, along with the code for our work. | |||||
| # Data Preparation | |||||
| ## Metadata | |||||
| First, download the metadata from [here](https://mega.nz/file/MqhXiLgC#NK1vd9ZLGU-3a182x-BXfvGCSuqHgq9hflI6tddh4Wc) | |||||
| ## RSNA | |||||
| ## CelebA | |||||
| 1. Download the dataset from [here](https://www.kaggle.com/jessicali9530/celeba-dataset) | |||||
| 2. Extract the data | |||||
| ```bash | |||||
| unzip -q -j celeba-dataset.zip '**/*.jpg' -d data/celeba | |||||
| ``` | |||||
| 3. Extract `celeba.tsv` from [this](#metadata) metadata file to `dataset_metadata/celeba.tsv` | |||||
| ## ImageNet | |||||
| 1. Download the ImageNet ILSVRC2012 dataset from [here](https://www.image-net.org/challenges/LSVRC/index.php). | |||||
| 1. Extract `ILSVRC2012_img_train_t3.tar` to `data/imagenet/tars`. | |||||
| 1. Run the following command to extract the train images per class: | |||||
| ```bash | |||||
| python data_preparation/imagenet/extract_train.py -s data/imagenet/tars/ -n <num-threads> | |||||
| ``` | |||||
| 1. Run the following command to extract the validation images per class (replace directory with the directory containing `ILSVRC2012_img_val.tar`, and extract `imagenet_val_maps.pklz` from [this](#metadata) metadata file): | |||||
| ```bash | |||||
| python data_preparation/imagenet/extract_val.py -s <directory> -m <imagenet_val_maps.pklz> | |||||
| ``` | |||||
| 1. Download the localization bounding boxes from [here](https://www.kaggle.com/c/imagenet-object-localization-challenge/data). Put `LOC_train_solution.csv` and `LOC_val_solution.csv` in `data/imagenet/extracted`. | |||||
| # Model checkpoints | |||||
| Download all the checkpoints from [here](https://mega.nz/file/16pWWA5L#b1SwLEv-YO5lL9wLVTcgykn2LmUBHm1exPFNB0JBoZo). | |||||
| # Run | |||||
| ## Evaluate the checkpoints | |||||
| ### RSNA | |||||
| #### Vanilla ResNet 18 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_resnet.pth | |||||
| ``` | |||||
| #### Vanilla Inception V3 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_inception.pth | |||||
| ``` | |||||
| #### WS ResNet 18 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_resnet_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_resnet_interpretations | |||||
| ``` | |||||
| #### WS Inception V3 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_inception_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_inception_interpretations | |||||
| ``` | |||||
| #### BB ResNet 18 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_resnet_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_resnet_interpretations | |||||
| ``` | |||||
| #### BB Inception V3 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_inception_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_inception_interpretations | |||||
| ``` | |||||
| ### CelebA | |||||
| #### Vanilla ResNet 18 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py celeba.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_resnet.pth | |||||
| ``` | |||||
| #### Vanilla Inception V3 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py celeba.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_inception.pth | |||||
| ``` | |||||
| #### WS ResNet 18 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_resnet_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_resnet_interpretations | |||||
| ``` | |||||
| #### WS Inception V3 | |||||
| ##### Evaluate the model performance on the test set | |||||
| ```bash | |||||
| python evaluate.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for positive samples | |||||
| ```bash | |||||
| python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_inception_interpretations | |||||
| ``` | |||||
| ##### Generate the LAP interpretations for negative samples | |||||
| ```bash | |||||
| python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_inception_interpretations | |||||
| ``` | |||||
| ### ImageNet | |||||
| #### WS ResNet 50 (Fine-tuned) | |||||
| ##### Evaluate the model performance on the validation set | |||||
| ```bash | |||||
| python evaluate.py imagenet.lap_resnet50_ft --device cuda:0 --samples-dir val --final-model-dir checkpoints/imagenet.ws_resnet_ft.pth | |||||
| ``` |
| import argparse | |||||
| import glob | |||||
| import os | |||||
| import tarfile | |||||
| from multiprocessing import Pool | |||||
# Command-line interface: where the per-class tar archives live, where to
# unpack them, and how many worker processes to use.
parser = argparse.ArgumentParser()
parser.add_argument('-s', dest='source', help='Class tars directory', required=True)
parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted')
parser.add_argument('-n', dest='num_threads', help='number of threads', default=1, type=int)
args = parser.parse_args()

# Every `*.tar` directly inside the source directory is treated as one class.
class_tars = glob.glob(os.path.join(args.source, '*.tar'))
# Explicit check instead of `assert`, so the validation is not stripped when
# Python runs with -O. The exception type and message are unchanged.
if len(class_tars) != 1000:
    raise AssertionError(f"class_tars length: {len(class_tars)}")
def extract(class_tar):
    """Unpack one per-class tar archive into its own sub-directory.

    The sub-directory is named after the archive file without its extension
    and is created (if missing) under the module-level ``args.target``.

    Args:
        class_tar: Path to the tar archive to unpack.
    """
    filename = os.path.basename(class_tar)
    # Strip only the trailing extension; str.replace('.tar', '') would also
    # remove a '.tar' occurring anywhere else in the file name.
    class_name, _ = os.path.splitext(filename)
    print('Extract ' + filename)
    class_dir = os.path.join(args.target, class_name)
    os.makedirs(class_dir, exist_ok=True)
    with tarfile.open(class_tar) as archive:
        archive.extractall(class_dir)
# Guard the pool start-up: with the "spawn" start method the worker processes
# re-import this module, and an unguarded Pool(...) at import time would make
# each worker try to start a pool of its own.
if __name__ == '__main__':
    # map() blocks until every archive has been processed; leaving the `with`
    # block then releases the worker processes even if extraction raised.
    with Pool(args.num_threads) as pool:
        pool.map(extract, class_tars)
| """Prepare the ImageNet dataset""" | |||||
| # import torch | |||||
| import os | |||||
| import argparse | |||||
| import tarfile | |||||
| import pickle | |||||
| import gzip | |||||
| # import subprocess | |||||
| # from tqdm import tqdm | |||||
| # from mxnet.gluon.utils import check_sha1 | |||||
| # from gluoncv.utils import download, makedirs | |||||
| _VAL_TAR = 'ILSVRC2012_img_val.tar' | |||||
def parse_args():
    """Parse the command line of the validation-set extraction script.

    Returns:
        The parsed namespace with ``download_dir`` (``-s``, required),
        ``mapping`` (``-m``, required) and ``target_dir`` (``--target-dir``,
        defaulting to ``data/imagenet/extracted``).
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Setup the ImageNet dataset.')
    cli.add_argument(
        '-s', required=True, dest='download_dir',
        help="The directory that contains downloaded tar files")
    cli.add_argument(
        '-m', required=True, dest='mapping',
        help="The mapping file for validation set")
    cli.add_argument(
        '--target-dir', default='data/imagenet/extracted',
        help="The directory to store extracted images")
    return cli.parse_args()
def check_file(filename):
    """Raise ``ValueError`` when *filename* does not exist; otherwise do nothing."""
    if os.path.exists(filename):
        return
    raise ValueError('File not found: '+filename)
def extract_val(tar_fname, target_dir, val_maps_file):
    """Unpack the validation tar and sort its files into per-class folders.

    Args:
        tar_fname: Path of the tar archive holding the validation images.
        target_dir: Directory to create and extract into; it must not exist yet
            because it is created without ``exist_ok``.
        val_maps_file: Gzip-compressed pickle holding a pair
            ``(dirs, mappings)``: the sub-directory names to create, and
            entries whose item 0 is a file name and item 1 its sub-directory.
    """
    os.makedirs(target_dir)
    print('Extracting ' + tar_fname)
    with tarfile.open(tar_fname) as archive:
        archive.extractall(target_dir)

    # NOTE(review): the mapping is loaded with pickle, which can execute
    # arbitrary code -- only use a mapping file from a trusted source.
    with gzip.open(val_maps_file, 'rb') as maps_fh:
        class_dirs, file_mappings = pickle.load(maps_fh)

    # Create every class folder first, then move each image into its folder.
    for class_dir in class_dirs:
        os.makedirs(os.path.join(target_dir, class_dir))
    for entry in file_mappings:
        image_name = entry[0]
        class_dir = entry[1]
        os.rename(os.path.join(target_dir, image_name),
                  os.path.join(target_dir, class_dir, image_name))
def main():
    """Extract the ImageNet validation archive into ``<target-dir>/val``.

    Raises:
        ValueError: If the ``val`` output directory already exists, or if the
            validation tar is missing from the download directory.
    """
    args = parse_args()
    target_dir = os.path.expanduser(args.target_dir)
    val_dir = os.path.join(target_dir, 'val')
    # Only the 'val' sub-directory has to be absent. The parent directory is
    # normally already there: the training-set extraction script defaults to
    # the same location, so rejecting an existing parent would block the
    # documented workflow.
    if os.path.exists(val_dir):
        raise ValueError('Target dir ['+val_dir+'] exists. Remove it first')

    download_dir = os.path.expanduser(args.download_dir)
    val_tar_fname = os.path.join(download_dir, _VAL_TAR)
    check_file(val_tar_fname)
    extract_val(val_tar_fname, val_dir, args.mapping)


if __name__ == '__main__':
    main()
| from typing import List, Tuple | |||||
| import argparse | |||||
| from os import makedirs, path, listdir | |||||
| import numpy as np | |||||
| import cv2 | |||||
| from multiprocessing import Pool | |||||
def resize_image(img_dir: str, img_save_dir: str, res: int) -> None:
    """Read one image, resize it to ``res`` x ``res`` and write the result.

    Args:
        img_dir: Path of the image file to read (a file path, despite the name).
        img_save_dir: Path the resized image is written to.
        res: Target width and height in pixels.

    Raises:
        IOError: If OpenCV cannot read the source image.
    """
    img = cv2.imread(img_dir)
    # cv2.imread reports failure by returning None rather than raising; fail
    # here with the offending path instead of an opaque error inside resize.
    if img is None:
        raise IOError(f'Could not read image: {img_dir}')
    img = cv2.resize(img, dsize=(res, res))
    cv2.imwrite(img_save_dir, img)
if __name__ == '__main__':
    # Command-line interface: dataset location, output resolution, worker count.
    parser = argparse.ArgumentParser()
    parser.add_argument('kaggle_dataset_dir', type=str, help='download the dataset from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, extract it, and pass its path as this argument')
    parser.add_argument('resolution', type=int, help='The resolution required for your model, 224 for resnet, 299 for vanilla inception, 256 for modified inception')
    parser.add_argument('cores', type=int, help='The number of cores for multiprocessing.')
    args = parser.parse_args()

    # Validate the input layout before creating anything on disk. Explicit
    # raises (instead of `assert`) keep the checks active under `python -O`.
    kaggle_dataset_dir = args.kaggle_dataset_dir
    if not path.exists(kaggle_dataset_dir):
        raise AssertionError(f'{kaggle_dataset_dir} does not exist!')
    rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images')
    if not path.exists(rsna_imgs_path):
        raise AssertionError('Make sure there is a folder named stage_2_train_images in the passed kaggle_directory!')

    save_dir = f'data/RSNA-Kaggle_R{args.resolution}'
    makedirs(save_dir, exist_ok=True)

    # One (source path, destination path, resolution) job per file. A plain
    # list also copes with an empty directory, where np.vectorize would raise.
    jobs = [
        (path.join(rsna_imgs_path, name), path.join(save_dir, name), args.resolution)
        for name in listdir(rsna_imgs_path)
    ]

    # starmap() blocks until all images are written; the context manager then
    # releases the workers even if one of the jobs raised.
    with Pool(args.cores) as pool:
        pool.starmap(resize_image, jobs)
| import traceback | |||||
| from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
| from torchlap.data.data_loader import DataLoader, RunType | |||||
| from torchlap.experiments._model_loading import load_model | |||||
| from torchlap.utils.hooker import Hooker | |||||
def main() -> None:
    """Evaluate the loaded model on every comma-separated entry of ``conf.samples_dir``.

    A failure in one sample group is reported on stdout together with its
    traceback, and the loop continues with the next group.
    """
    entry = get_entry(PhaseType.EVAL)
    conf = entry.conf
    model = load_model(entry.model, conf)
    model.eval()

    for group in conf.samples_dir.split(','):
        try:
            print('')
            print('>> Running evaluations for %s' % group, flush=True)
            loader = DataLoader(conf, group, RunType.TEST)
            evaluator = conf.evaluator_cls(model, loader, conf)
            # The hooks registered for the current phase stay active only
            # while this group is being evaluated.
            with Hooker(*conf.hooks_by_phase[conf.phase_type]):
                evaluator.evaluate()
        except Exception:
            # Best effort: report the problem and carry on with the next group.
            print('Problem in %s' % group, flush=True)
            print(traceback.format_exc(), flush=True)


if __name__ == '__main__':
    main()
| from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
| from torchlap.experiments._model_loading import load_model | |||||
| from torchlap.interpreting.interpretation_evaluation_runner import InterpretingEvalRunner | |||||
| from torchlap.utils.hooker import Hooker | |||||
def main() -> None:
    """Run the interpretation-evaluation phase for the configured experiment.

    Obtains the entry for ``PhaseType.EVALINTERPRET``, passes its model through
    ``load_model``, switches it to eval mode and runs ``InterpretingEvalRunner``
    with the hooks registered for the current phase active.
    """
    entry = get_entry(PhaseType.EVALINTERPRET)
    conf = entry.conf
    model = load_model(entry.model, conf)
    model.eval()

    eval_runner = InterpretingEvalRunner(conf, model)
    phase_hooks = conf.hooks_by_phase[conf.phase_type]
    with Hooker(*phase_hooks):
        eval_runner.evaluate()


if __name__ == '__main__':
    main()
| from torchlap.experiments._entry_loader import get_entry, PhaseType | |||||
| from torchlap.experiments._model_loading import load_model | |||||
| from torchlap.interpreting.interpreting_runner import InterpretingRunner | |||||
| from torchlap.utils.hooker import Hooker | |||||
def main() -> None:
    """Run the interpretation phase for the configured experiment.

    Obtains the entry for ``PhaseType.INTERPRET``, passes its model through
    ``load_model``, switches it to eval mode and runs ``InterpretingRunner``
    with the hooks registered for the current phase active.
    """
    entry = get_entry(PhaseType.INTERPRET)
    conf = entry.conf
    model = load_model(entry.model, conf)
    model.eval()

    interpreting_runner = InterpretingRunner(conf, model)
    phase_hooks = conf.hooks_by_phase[conf.phase_type]
    with Hooker(*phase_hooks):
        interpreting_runner.interpret()


if __name__ == '__main__':
    main()
| numpy | |||||
| pandas | |||||
| torch>=1.7.0 | |||||
| torchvision>=0.8.1 | |||||
| scikit-image>=0.14.3 | |||||
| scikit-learn>=0.19.1 | |||||
| captum>=0.3.0 | |||||
| opencv-python>=4.4.0.46 |
| __version__ = 'master' |
| from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional | |||||
| import os | |||||
| import random | |||||
| import numpy as np | |||||
| from sys import stderr | |||||
| from enum import Enum | |||||
| import torch | |||||
| from ..interpreting.interpreter_maker import InterpretationType | |||||
| from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser | |||||
| from ..data.batch_choosing.sequential import SequentialBatchChooser | |||||
| from ..utils.hooker import HookPair | |||||
| if TYPE_CHECKING: | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| from ..data.batch_choosing.batch_chooser import BatchChooser | |||||
| from ..data.content_loaders.content_loader import ContentLoader | |||||
class PhaseType(Enum):
    """Closed set of run phases; each member's value is its lowercase name."""

    TRAIN = 'train'
    EVAL = 'eval'
    INTERPRET = 'interpret'
    EVALINTERPRET = 'evalinterpret'
| class BaseConfig: | |||||
def __init__(self,
             try_name: str, try_num: int, data_separation: str,
             input_size: int,
             phase_type: PhaseType,
             content_loader_cls: Type['ContentLoader'],
             evaluator_cls: Type['Evaluator']) -> None:
    """Store the constructor arguments and set every other option to its default.

    Args:
        try_name: Name of the run; used in the default save directory.
        try_num: Number of the run; used in the default save directory.
        data_separation: Stored as ``self.data_separation``.
        input_size: Stored as ``self.input_size``.
        phase_type: The phase this configuration is built for.
        content_loader_cls: Content-loader class, stored for later use.
        evaluator_cls: Evaluator class, stored for later use.
    """
    # ==================== essential options ====================
    self.data_separation = data_separation
    self.try_name = try_name
    self.try_num = try_num
    self.phase_type = phase_type

    # Pluggable classes
    self.evaluator_cls = evaluator_cls
    self.content_loader_cls = content_loader_cls
    self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
    self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser

    # General set-up
    self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
    self.batch_size = 64
    self.random_seed: int = 17
    self.samples_dir: Optional[str] = None

    # Where things are saved to and loaded from
    self.save_dir: Optional[str] = None
    self.report_dir: Optional[str] = None
    self.final_model_dir: Optional[str] = None
    self.epoch: Optional[str] = None

    # Device selection (filled in by update())
    self.dev_name = None
    self.device = None

    # ==================== training options ====================
    self.pretrained_model_file = None
    self.big_batch_size: int = 1
    self.max_epochs: int = 10
    self.iters_per_epoch: Optional[int] = None
    self.val_iters_per_epoch: Optional[int] = None

    # Log clean-up during training
    self.keep_best_and_last_epochs_only = False

    # Automatic stopping of the training
    self.n_unsuccessful_epochs_to_stop: Optional[int] = 100

    self.different_augmentation_per_batch = True
    self.augmentations_dict: Optional[Dict[str, torch.nn.Module]] = None

    self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
    self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

    # Data loader: None selects all labels, a list of ints selects only those
    self.mapped_labels_to_use: Optional[List[int]] = None

    # Hooks: every phase starts with its own, separate, empty list
    self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
        phase: [] for phase in PhaseType
    }

    # ==================== optimisation options ====================
    self.init_lr: float = 1e-4
    self.lr_decay: float = 1e-6
    self.freezing_regexes: Optional[List[str]] = None
    # NOTE: built from the defaults above; assigning init_lr / lr_decay later
    # does not rebuild this creator.
    self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = get_adam_creator(self.init_lr, self.lr_decay)

    # ==================== interpretation options ====================
    self.save_by_file_name = False
    self.interpretation_method = InterpretationType.GuidedBackprop
    self.class_label_for_interpretation: Optional[int] = None
    self.interpret_predictions_vs_gt: bool = True

    # What to skip when saving interpretations
    self.skip_overlay = False
    self.skip_raw = False

    # --- interpretation evaluation ---
    # Interpretation values below this threshold are cut
    self.cut_threshold: float = 0
    # Whether the threshold is applied to the un-normalized interpretation map
    self.global_threshold: bool = False
    # Whether a dynamic threshold is used
    self.dynamic_threshold: bool = False

    self.prediction_key_in_model_output_dict = 'positive_class_probability'

    # Used when interpretation masks are compared with the ground truth as
    # continuous objects
    self.acceptable_min_intersection_threshold: float = 0.5

    # Key of the interpretation to evaluate; None means the first one available
    self.interpretation_tag_to_evaluate = None

    self.n_interpretation_samples: Optional[int] = None
def __repr__(self):
    """Return the instance's attribute dictionary rendered as a string."""
    return str(vars(self))
def _update_based_on_args(self, args) -> None:
    """ Overwrite existing attributes of this object with the non-None values found in ``args``.

    Attributes of ``args`` that do not already exist on this object are ignored,
    and so are attributes whose value is None.
    """
    own_attributes = vars(self)
    for name, value in vars(args).items():
        if value is None or name not in own_attributes:
            continue
        setattr(self, name, value)
def update(self, args):
    """
    Refresh the config from ``args``: copy the argument values, resolve the
    device, seed the random generators and set up the save/report directories.
    """
    # The argument values are applied twice on purpose (the original code marks
    # this as a deliberate repetition); the reason is not visible from here.
    for _ in range(2):
        self._update_based_on_args(args)

    self.dev_name, self.device = _get_device(args)

    self._set_random_seeds()
    self._set_save_dir()
    self._update_report_dir()
def _set_random_seeds(self):
    """Seed torch, numpy and ``random`` with ``self.random_seed`` and configure cuDNN.

    cuDNN is put in deterministic mode (benchmark off) for every phase except
    training; during training the benchmark mode is enabled instead.
    """
    seed = self.random_seed

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it
    # here can only affect processes spawned later, not the running one.
    os.environ['PYTHONHASHSEED'] = '0'

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # numpy is seeded exactly once (the original seeded it twice with the same value)
    np.random.seed(seed)
    random.seed(seed)

    deterministic = self.phase_type != PhaseType.TRAIN
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
def _set_save_dir(self):
    """Use the default results folder when no save dir was given, then make sure the folder exists."""
    target = self.save_dir
    if target is None:
        target = f"../Results/{self.try_num}_{self.try_name}"
        self.save_dir = target
    os.makedirs(target, exist_ok=True)
def _update_report_dir(self):
    """Default the report dir to ``<save_dir>/<phase value>`` unless one is already set."""
    if self.report_dir is not None:
        return
    self.report_dir = '/'.join((self.save_dir, self.phase_type.value))
| def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str: | |||||
| # Builds the report directory for one group of samples and returns it after passing it | |||||
| # through get_unique_save_dir, so the returned path does not exist yet. | |||||
| # samples_specification: 'train', 'test', 'val', a file path, or another specification | |||||
| # string that may contain a ':' separator. | |||||
| # extra_subdir: optional sub-directory appended (prefixed with '/') after the model/epoch part. | |||||
| if extra_subdir is None: | |||||
| extra_subdir = '' | |||||
| else: | |||||
| extra_subdir = '/' + extra_subdir | |||||
| # Making sure of keeping version information! | |||||
| # When a final model path is set, the text between its last '/' and its last '.' names the sub-directory. | |||||
| if self.final_model_dir is not None: | |||||
| report_dir = '%s/%s%s' % ( | |||||
| self.report_dir, | |||||
| self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')], | |||||
| extra_subdir) | |||||
| else: | |||||
| # Otherwise the sub-directory is 'epoch,<epoch>', with 'best' standing in for a missing epoch. | |||||
| epoch = self.epoch | |||||
| if epoch is None: | |||||
| epoch = 'best' | |||||
| report_dir = '%s/epoch,%s%s' % ( | |||||
| self.report_dir, | |||||
| epoch, extra_subdir) | |||||
| # adding sample info | |||||
| if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification): | |||||
| # The part before ':' (or the whole specification) is checked for being a directory. | |||||
| # NOTE(review): the two-name unpacking below raises ValueError if more than one ':' is present. | |||||
| if ':' in samples_specification: | |||||
| base_info, _ = tuple(samples_specification.split(':')) | |||||
| else: | |||||
| base_info = samples_specification | |||||
| if not os.path.isdir(base_info): | |||||
| # ':' is replaced by '_' and '../' is stripped before the specification is used as a path part. | |||||
| report_dir = report_dir + '/' + \ | |||||
| samples_specification.replace(':', '_').replace('../', '') | |||||
| else: | |||||
| report_dir = report_dir + '/' + samples_specification | |||||
| return self.get_unique_save_dir(report_dir) | |||||
@staticmethod
def get_unique_save_dir(sd):
    """Return ``sd`` if it does not exist, otherwise the first free ``sd_2``, ``sd_3``, ... path."""
    candidate = sd
    next_suffix = 2
    while os.path.exists(candidate):
        candidate = sd + '_%d' % next_suffix
        next_suffix += 1
    return candidate
def get_save_dir_for_sample(self, save_dir, sample_name):
    """Join ``save_dir`` and ``sample_name``.

    When ``self.save_by_file_name`` is set, only the part of ``sample_name``
    after its last '/' is used.
    """
    leaf = sample_name
    if self.save_by_file_name and '/' in sample_name:
        leaf = sample_name.rsplit('/', 1)[-1]
    return save_dir + '/' + leaf
def _get_device(args):
    """Resolve ``args.device`` into a ``(dev_name, torch.device)`` pair.

    Accepted values of ``args.device``:
        * ``'cuda'``: all GPUs; the returned name is ``None``.
        * ``'cuda:<n>'``: the n-th GPU; the returned name is the given string.
        * anything else (including ``None``): the CPU; the returned name is ``'cpu'``.

    If the requested GPU index is not below ``torch.cuda.device_count()``,
    a warning is written to stderr and the CPU is used instead.

    Args:
        args: An object with a ``device`` attribute.

    Returns:
        Tuple of the device name and the matching ``torch.device``.
    """
    # A missing device is treated as a CPU request instead of crashing on `in None`
    dev_name = args.device if args.device is not None else 'cpu'

    if 'cuda' in dev_name:
        gpu_num = int(dev_name.split(':')[1]) if ':' in dev_name else 0
        if gpu_num >= torch.cuda.device_count():
            print('No %s, using CPU instead!' % dev_name, file=stderr)
            # The CPU banner is printed once, by the final branch below
            dev_name = 'cpu'

    if dev_name == 'cuda':
        dev_name = None
        the_device = torch.device('cuda')
        print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True)
    elif 'cuda:' in dev_name:
        the_device = torch.device(dev_name)
        print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True)
    else:
        dev_name = 'cpu'
        the_device = torch.device(dev_name)
        print('@ Running on CPU @', flush=True)

    return dev_name, the_device
def get_adam_creator(
        init_lr: float = 1e-4,
        lr_decay: float = 1e-6,
) -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Return a factory that builds an Adam optimizer for the parameters it is given.

    Args:
        init_lr (float): Passed to Adam as ``lr``.
        lr_decay (float): Passed to Adam as ``weight_decay``.
    """

    def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        optimizer = torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay)
        return optimizer

    return create_adam
def get_sgd_creator(
        init_lr: float = 1e-4,
        lr_decay: float = 1e-6,
) -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Return a factory that builds an SGD optimizer for the parameters it is given.

    Args:
        init_lr (float): Passed to SGD as ``lr``.
        lr_decay (float): Passed to SGD as ``weight_decay``.
    """

    def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        optimizer = torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay)
        return optimizer

    return create_sgd
| from .base_config import BaseConfig, PhaseType | |||||
| from typing import Union | |||||
| from torch import nn | |||||
| from torchvision import transforms | |||||
| from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag | |||||
| from ..model_evaluation.binary_evaluator import BinaryEvaluator | |||||
class CelebAConfigs(BaseConfig):
    """Config for the CelebA data: wires in CelebALoader and BinaryEvaluator and replaces some base values."""

    def __init__(
            self,
            try_name: str,
            try_num: int,
            input_size,
            phase_type: PhaseType) -> None:
        super().__init__(
            try_name, try_num, 'default', input_size, phase_type,
            CelebALoader, BinaryEvaluator)

        # replaced configs!
        self.batch_size = 64
        self.iters_per_epoch = 200

        # every member of CelebATag whose name ends with 'Tag'
        tag_names = (tag.name for tag in CelebATag)
        self.tags = [name for name in tag_names if name.endswith('Tag')]
        self.main_tag: Union[CelebATag, None] = None

        self.data_root: str = 'data/celeba'
        self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

        # choosing the best epoch
        self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # augmentation
        x_pipeline = [
            transforms.RandomRotation(45),
            transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5),
            transforms.RandomPerspective(),
        ]
        self.augmentations_dict = {'x': nn.Sequential(*x_pipeline)}
| from .base_config import BaseConfig, PhaseType | |||||
| from ..data.content_loaders.imagenet_loader import ImagenetLoader | |||||
| from ..model_evaluation.multiclass_evaluator import MulticlassEvaluator | |||||
class ImagenetConfigs(BaseConfig):
    """Config for the ImageNet data: wires in ImagenetLoader and a multiclass evaluator and replaces some base values."""

    def __init__(
            self,
            try_name: str,
            try_num: int,
            input_size,
            phase_type: PhaseType) -> None:
        evaluator_creator = MulticlassEvaluator.standard_creator(
            'categorical_probability', include_top5=True)
        super().__init__(
            try_name, try_num, 'data/imagenet/extracted', input_size, phase_type,
            ImagenetLoader, evaluator_creator)

        # replaced configs!
        self.batch_size = 64
        self.big_batch_size = 4
        self.max_epochs = 10

        # choosing the best epoch
        self.title_of_reference_metric_to_choose_best_epoch = 'Accuracy'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True
| from .base_config import BaseConfig, PhaseType | |||||
| from typing import Dict, List, Optional | |||||
| from torch import nn | |||||
| from torchvision import transforms | |||||
| from ..utils.random_brightness_augmentation import ClippedBrightnessAugment | |||||
| from ..data.content_loaders.rsna_loader import RSNALoader | |||||
| from ..model_evaluation.binary_evaluator import BinaryEvaluator | |||||
class RSNAConfigs(BaseConfig):
    """Config for the RSNA data: wires in RSNALoader and BinaryEvaluator and replaces some base values."""

    def __init__(
            self,
            try_name: str,
            try_num: int,
            input_size,
            phase_type: PhaseType) -> None:
        super().__init__(
            try_name, try_num,
            f'dataset_metadata/RSNA/DataSeparation_R{input_size}',
            input_size, phase_type, RSNALoader, BinaryEvaluator)

        # replaced configs!
        self.batch_size = 64
        self.max_epochs = 200

        self.label_map_dict: Dict[str, int] = {
            'healthy': 0, 'pneumonia': 1
        }

        def geometric_transforms():
            # A fresh set of the geometric augmentations shared by all three keys
            return [
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.4),
                transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)),
            ]

        # the image additionally gets a brightness augmentation
        self.augmentations_dict = {
            'x': nn.Sequential(
                *geometric_transforms(),
                ClippedBrightnessAugment(0.5, 1.5, 0, 1)),
            'infection': nn.Sequential(*geometric_transforms()),
            'interpretations': nn.Sequential(*geometric_transforms()),
        }

        # choosing the best epoch
        self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        self.infection_map_size = self.input_size
        self.receptive_field_radii: List[int] = []
        self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'
| from abc import ABC, abstractmethod | |||||
| from functools import partial | |||||
| from typing import Dict, List, Tuple | |||||
| import torch | |||||
| from torch import nn | |||||
| from ..configs.base_config import BaseConfig, PhaseType | |||||
class AuxLoss(ABC):
    """Base class for auxiliary losses computed from the outputs of intermediate layers.

    The instance builds a list of ``(module, hook)`` pairs (``hook_pairs``): one hook per
    given layer that stores the layer output, plus one hook on the model itself that
    computes the loss and adds it to the model's output dictionary. ``configure`` appends
    these pairs to ``config.hooks_by_phase``; the hooks use the ``(module, input, output)``
    call format (presumably they are registered as forward hooks by the consumer of
    ``hooks_by_phase`` -- that registration is not part of this class).
    """

    def __init__(self, title: str, model: nn.Module,
                 layer_by_name: Dict[str, nn.Module],
                 loss_weight: float):
        """This class is used for calculating loss based on the middle layers of the network.

        Args:
            title (str): A unique title. The loss term is added to the output dictionary as ``{title}_loss``.
            model (nn.Module): The base model
            layer_by_name (Dict[str, nn.Module]): Maps a layer name to the layer whose output is needed.
                The names used for one instance of this class should be unique.
            loss_weight (float): The weight of the loss in the final loss. The loss term is automatically
                added to ``loss`` in the model's output dictionary, multiplied by this factor.
        """
        self._title = title
        self._model = model
        self._loss_weight = loss_weight
        # Snapshot of the names: the hooks below are built once, so the names must not
        # follow later changes of the caller's dictionary.
        self._layers_names = tuple(layer_by_name)
        self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()

        self.hook_pairs = [
            (module, partial(self._save_layer_val_hook, name))
            for name, module in layer_by_name.items()
        ] + [(model, self._add_loss)]

    def _save_layer_val_hook(self, layer_name, _, __, layer_out):
        """Hook storing the output of a layer under its name (module and input are ignored)."""
        self._layers_outputs_by_name[layer_name] = layer_out

    def configure(self, config: BaseConfig) -> None:
        """Must be called to add the necessary configs and hooks so the loss will be calculated!

        The hook pairs are appended to the hook list of every phase in ``PhaseType``.

        Args:
            config (BaseConfig): The base config!
        """
        for phase in PhaseType:
            config.hooks_by_phase[phase] += self.hook_pairs

    def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
        """ A hook for calculating the loss and adding it to the model's output dictionary

        Args:
            _ ([type]): Model, useless! (just the hook format)
            __ ([type]): Model input, also useless! (just the hook format)
            model_out (Dict[str, torch.Tensor]): The output dictionary of the model

        Returns:
            Dict[str, torch.Tensor]: The same dictionary, with ``{title}_loss`` added and
            ``loss`` increased by the weighted loss.
        """
        # popping values of tensors, so a stored output is consumed exactly once
        layers_values = []
        for layer_name in self._layers_names:
            assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.'
            layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))

        loss = self._calculate_loss(layers_values, model_out)

        # adding loss to output dictionary
        assert f'{self._title}_loss' not in model_out, f'Trying to add {self._title}_loss to model\'s output multiple times!'
        model_out[f'{self._title}_loss'] = loss.clone()

        # adding loss to the main loss (a zero scalar is used when there is no main loss yet)
        model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss
        return model_out

    @abstractmethod
    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates loss based on the output of the layers given in the constructor.

        Args:
            layers_values (List[Tuple[str, torch.Tensor]]): List of layers names and their output value
            model_output (Dict[str, torch.Tensor]): The output dictionary of the model, if something else is required for the loss calculation! or it can be used to add stuff into the output!

        Returns:
            torch.Tensor: The (unweighted) loss; the weight is applied by ``_add_loss``.
        """
| from typing import List, Tuple, Dict, Union | |||||
| import math | |||||
| import torch | |||||
| from torch.nn import functional as F | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..data.dataloader_context import DataloaderContext | |||||
| from .aux_loss import AuxLoss | |||||
class BBLoss(AuxLoss):
    """Auxiliary loss that compares layer outputs with the infection masks of the current batch.

    For each given layer the output is resized to the mask size and a binary cross entropy is
    built from: the top-k in-mask pixels of positive samples (target 1), the out-of-mask pixels
    of positive samples (target 0) and the top pixels of negative samples (target 0).
    """

    def __init__(
            self, title: str, model: torch.nn.Module,
            layer_by_name: Dict[str, torch.nn.Module],
            loss_weight: float,
            inf_mask_kw: str,
            layers_loss_weights: Union[float, List[float]],
            bb_fill_ratio: float):
        """
        Args:
            title (str): Unique title of the loss.
            model (torch.nn.Module): The base model.
            layer_by_name (Dict[str, torch.nn.Module]): The layers whose outputs are used.
            loss_weight (float): Weight of this loss in the model's final loss.
            inf_mask_kw (str): Keyword used to read the infection masks from the data loader.
            layers_loss_weights (Union[float, List[float]]): One weight per layer, or a single
                weight used for all layers.
            bb_fill_ratio (float): Multiplied by the median mask area of the positive samples
                to get the number of top pixels (k).
        """
        super().__init__(title, model, layer_by_name, loss_weight)
        self._inf_mask_kw = inf_mask_kw
        self._bb_fill_ratio = bb_fill_ratio
        if isinstance(layers_loss_weights, list):
            assert len(layers_loss_weights) == len(layer_by_name), f'There should be as many weights as layers, one per layer. Expected {len(layer_by_name)} found {len(layers_loss_weights)}'
            self._layers_loss_weights = layers_loss_weights
        else:
            self._layers_loss_weights = [layers_loss_weights for _ in range(len(layer_by_name))]

    def _calculate_loss(
            self,
            layers_values: List[Tuple[str, torch.Tensor]],
            model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates the weighted sum of the per-layer losses.

        Also stores each per-layer loss in ``model_output`` under ``{title}_{layer name}_loss``.
        """
        # gathering extra information
        dl: DataLoader = DataloaderContext.instance.dataloader
        xray_y = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)
        infection_mask = dl.get_current_batch_data(keyword=self._inf_mask_kw).to(dl._device)

        xray_y = xray_y.bool()
        # only the masks of the positive samples are kept
        infection_mask = infection_mask[xray_y]
        PB, _, H, W = infection_mask.shape

        # Everything below depends only on the masks, so it is computed once for all layers
        infection_mask = infection_mask.flatten(start_dim=1)
        outside_bb = infection_mask == 0

        k = 0
        if PB > 0:
            bb_area = infection_mask.sum(dim=-1)  # B
            k = int((self._bb_fill_ratio * bb_area.quantile(q=0.5)).ceil())
            # topk needs 0 <= k <= number of pixels, and an empty top-k would turn the
            # mean-reduced cross entropy into NaN, so k is kept within [1, pixels]
            k = min(max(k, 1), infection_mask.shape[-1])

        loss = torch.zeros([], dtype=torch.float32,
                           device=xray_y.device, requires_grad=True)

        for li, (ln, lo) in enumerate(layers_values):
            out = F.interpolate(lo,
                                size=(H, W),
                                mode='bilinear',
                                align_corners=True)
            out = out.flatten(start_dim=1)

            neg_out = out[~xray_y]
            pos_out = out[xray_y]
            NB = neg_out.shape[0]

            neg_losses = []
            pos_losses = []

            if PB > 0:
                # Top positive in bb pixels must be positive
                pos_out_infection = torch.where(infection_mask < 1,
                                                torch.zeros_like(pos_out),
                                                pos_out)
                pos_infection_topk, pos_infection_indices = pos_out_infection\
                    .topk(k, dim=-1, sorted=False)
                # mask value at each selected pixel; floor makes non ones to be zero
                pos_weight = infection_mask.gather(1, pos_infection_indices).floor()
                pos_losses.append(F.binary_cross_entropy(
                    pos_infection_topk,
                    torch.ones_like(pos_infection_topk),
                    pos_weight
                ))

                if outside_bb.any():
                    # All positive out bb pixels must be negative
                    pos_out_non_infection = pos_out[outside_bb]
                    neg_losses.append(F.binary_cross_entropy(
                        pos_out_non_infection,
                        torch.zeros_like(pos_out_non_infection)
                    ))

            if NB > 0:
                if PB > 0:
                    # Top negative pixels must be negative.
                    # With few negatives PB * k / NB can exceed the number of pixels,
                    # which topk rejects, so it is capped.
                    neg_k = min(int(math.ceil(PB * k / NB)), neg_out.shape[-1])
                    neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out_topk,
                        torch.zeros_like(neg_out_topk)
                    ))
                else:
                    # All negative pixels must be negative
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out,
                        torch.zeros_like(neg_out)
                    ))

            # negative and positive parts are averaged separately, then together
            losses = []
            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())

            l_loss = torch.stack(losses).mean()
            model_output[f'{self._title}_{ln}_loss'] = l_loss
            loss = loss + self._layers_loss_weights[li] * l_loss

        return loss
| from typing import Dict, Union, List, Tuple | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch.nn import functional as F | |||||
| from .aux_loss import AuxLoss | |||||
| from ..data.dataloader_context import DataloaderContext | |||||
| from ..data.data_loader import DataLoader | |||||
class PoolConcordanceLossCalculator(AuxLoss):
    """Auxiliary loss comparing the outputs of consecutive pool layers (ordered from the
    widest to the narrowest output) with ``jensen_shannon_divergence``."""

    def __init__(self,
                 title: str, model: torch.nn.Module,
                 layer_by_name: Dict[str, torch.nn.Module],
                 loss_weight: float, weights: Union[float, List[float]],
                 diff_thresholds: Union[float, List[float]],
                 labels_by_channel: Dict[int, List[int]]):
        """
        Args:
            title (str): Unique title of the loss.
            model (torch.nn.Module): The base model.
            layer_by_name (Dict[str, torch.nn.Module]): The pool layers, at least two.
            loss_weight (float): Weight of this loss in the model's final loss.
            weights (Union[float, List[float]]): Weight per pair of pools (or one for all pairs).
            diff_thresholds (Union[float, List[float]]): Difference threshold per pair of pools
                (or one for all pairs).
            labels_by_channel (Dict[int, List[int]]): For each channel, the sample labels
                whose samples take part in the loss of that channel.
        """
        super().__init__(title, model, layer_by_name, loss_weight)

        n_layers = len(layer_by_name)
        n_pairs = n_layers - 1

        # at least two pools are needed
        assert n_layers > 1, 'At least two pool layers are required to calculate this loss!'

        self._title = title
        self._model = model
        self._loss_prefix = f'{title}_JS_loss_'
        self._labels_by_channel = labels_by_channel
        self._pool_names = list(layer_by_name.keys())

        weights_given_as_list = isinstance(weights, list)
        self._weights = weights if weights_given_as_list else [weights] * n_pairs
        if weights_given_as_list:
            assert len(weights) == n_pairs, 'Weights must have a length of pool_layers -1'

        thresholds_given_as_list = isinstance(diff_thresholds, list)
        self._diff_thresholds = diff_thresholds if thresholds_given_as_list else [diff_thresholds] * n_pairs
        if thresholds_given_as_list:
            assert len(diff_thresholds) == n_pairs, 'Diff thresholds must have a length of pool_layers -1'

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Sums the weighted losses of consecutive pool pairs and stores each pair loss in ``model_output``."""
        # the pool outputs, in the order of the given layers
        pool_vals = [value for _, value in layers_values]

        # getting samples labels
        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)

        # sorting by shape from biggest to smallest
        last_dims = np.asarray([p.shape[-1] for p in pool_vals])
        order = np.argsort(-1 * last_dims)
        ordered_vals = [pool_vals[i] for i in order]
        ordered_names = [self._pool_names[i] for i in order]

        loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

        # for each pair of pools, calculate loss!
        for i in range(len(ordered_names) - 1):
            pair_key = f'{self._loss_prefix}{ordered_names[i]}-{ordered_names[i + 1]}'
            p_loss = self._cal_pool_pair_loss(
                ordered_vals[i], ordered_vals[i + 1], self._diff_thresholds[i], labels)
            assert pair_key not in model_output, 'Trying to add ' + pair_key + ' to model output multiple times'
            model_output[pair_key] = p_loss.clone()
            loss = loss + self._weights[i] * p_loss

        return loss

    def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:
        """Sums, over the configured channels, the divergence between two pool outputs for the
        samples whose label belongs to the channel."""
        # down-sampling by max-pool till reaching the same shape as p2!
        if p1.shape[-1] > p2.shape[-1]:
            p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

        # jensen shannon loss, for each channel -> in the related class!
        total = torch.tensor(0.0, requires_grad=True, device=p1.device)
        for channel, related_labels in self._labels_by_channel.items():
            on_mask = self._get_inclusion_mask(labels, related_labels)
            ip1 = p1[on_mask, channel, ...]
            ip2 = p2[on_mask, channel, ...]
            if torch.numel(ip1) == 0:
                continue
            if ip1.shape != ip2.shape:
                print(f'Problem in shape of concordance loss in {self._title}')
            total = total + jensen_shannon_divergence(
                ip1[:, None, ...], ip2[:, None, ...], diff_threshold)

        return total

    def _get_inclusion_mask(
            self,
            samples_labels: torch.Tensor,
            desired_labels: List[int]) -> torch.Tensor:
        """Boolean mask over the samples: True where the sample's label is one of ``desired_labels``."""
        with torch.no_grad():
            per_label_hits = torch.stack(
                [samples_labels == label for label in desired_labels], dim=0)
            return per_label_hits.any(dim=0)
def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
    """
    Calculates the jensen shannon loss between two distributions p1 and p2

    Only the positions where p1 and p2 differ by at least ``diff_threshold`` (in any class)
    contribute; when there is no such position the loss is zero.

    Args:
        p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1] that is the probabilities for each neuron in the 1st distribution.
        p2 (torch.Tensor): A tensor of the same shape as p1, in range [0, 1] that is the probabilities for each neuron in the 2nd distribution.
        diff_threshold (float): Threshold between p1 and p2 to decide whether to consider one pixel in loss

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert p1.shape == p2.shape, 'The tensors must have the same shape'
    assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'

    # Reshaping tensors to (positions, classes)
    p1 = p1.transpose(0, 1).flatten(start_dim=1).transpose(0, 1)
    p2 = p2.transpose(0, 1).flatten(start_dim=1).transpose(0, 1)

    # if binary, append class 0!
    is_binary = p1.shape[1] == 1
    if is_binary:
        p1 = torch.cat([1 - p1, p1], dim=1)
        p2 = torch.cat([1 - p2, p2], dim=1)

    # positions where at least one class differs by the threshold or more
    with torch.no_grad():
        far_enough = (p1 - p2).abs() >= diff_threshold
        mask = far_enough.any(dim=-1)

    # to make sure error does not result in log(negative)!
    floor_1 = torch.full_like(p1, 1e-6)
    floor_2 = torch.full_like(p2, 1e-6)
    lp1 = torch.maximum(p1 + 1e-4, floor_1).log()
    lp2 = torch.maximum(p2 + 1e-4, floor_2).log()

    forward_term = _smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1))
    backward_term = _smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
    return 0.5 * (forward_term + backward_term)


def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    """Mean of the masked values, or a zero scalar when the mask selects nothing."""
    if not mask.bool().any():
        return torch.tensor(0.0, requires_grad=True, device=mask.device)
    return val[mask].mean()
| from typing import Dict, List, Tuple, Optional | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.nn import functional as F | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..data.dataloader_context import DataloaderContext | |||||
| from .aux_loss import AuxLoss | |||||
class DiscriminativeWeaklySupervisedLoss(AuxLoss):
    """Weakly supervised loss on an attention layer, optionally with a discrimination head."""

    def __init__(
            self, title: str, model: nn.Module,
            att_score_layer: nn.Module,
            loss_weight: float,
            neg_ratio: float, pos_ratio_range: Tuple[float, float],
            on_labels_by_channel: Dict[int, List[int]],
            discr_score_layer: Optional[nn.Module] = None,
            w_attention_in_ordering: float = 1,
            w_discr_in_ordering: float = 1):
        """
        Calculates binary weakly supervised score for an attention layer
        with extra discrimination head.

        Args:
            title (str): The title of the loss, must be unique!
            model (Model): the base model, so the output can be modified
            att_score_layer (torch.nn.Module): A layer that gives attention score (B C ...)
            loss_weight (float): The weight of the loss
            neg_ratio (float): top ratio to apply loss for negative samples; 0 disables this term.
            pos_ratio_range (Tuple[float, float]): low and top ratios to apply loss for positive samples.
            on_labels_by_channel (Dict[int, List[int]]): The dictionary that specifies the samples related to which labels should be on in each channel.
            discr_score_layer (torch.nn.Module, optional): A layer that gives discriminative score (B C ...).
                When omitted, the attention scores are also used for ordering the pixels.
            w_attention_in_ordering (float): The weight of the attention score used in ordering the pixels.
            w_discr_in_ordering (float): The weight of the reference score used in ordering the pixels.
        """
        layers = dict(
            att=att_score_layer,
        )
        if discr_score_layer is not None:
            layers['discr'] = discr_score_layer

        super().__init__(title, model, layers, loss_weight)
        self._has_discr = discr_score_layer is not None
        self._neg_ratio = neg_ratio
        self._pos_ratio_range = pos_ratio_range
        self._on_labels_by_channel = on_labels_by_channel
        self._w_attention_in_ordering = w_attention_in_ordering
        self._w_discr_in_ordering = w_discr_in_ordering

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Returns the sum of the discrimination loss (if a discrimination layer exists) and
        the attention loss, and stores both terms in ``model_out``.

        The returned value is NOT multiplied by the loss weight: ``AuxLoss._add_loss``
        applies the weight when adding the term to the model's loss.
        """
        discriminative_scores = None
        probabilities = None

        for ln, lv in layers_values:
            if ln == 'att':
                probabilities = lv
            else:
                discriminative_scores = lv

        dl: DataLoader = DataloaderContext.instance.dataloader
        # numpy array of the labels; converted to a tensor in _get_inclusion_mask
        labels = dl.get_current_batch_samples_labels()

        if discriminative_scores is not None:
            discr_key = self._title + '_discr_loss'
            discrimination_loss = self._calculate_discrimination_loss(
                labels, discriminative_scores)
            assert discr_key not in model_out, 'Trying to add ' + discr_key + ' to model output multiple times'
            model_out[discr_key] = discrimination_loss.clone()
        else:
            # the device is taken from the attention scores; the labels are a numpy array
            discrimination_loss = torch.zeros([], requires_grad=True, device=probabilities.device)

        ws_key = self._title + '_ws_loss'
        attention_loss = self._calculate_attention_loss(
            labels, probabilities, (discriminative_scores if discriminative_scores is not None else probabilities))
        assert ws_key not in model_out, 'Trying to add ' + ws_key + ' to model output multiple times'
        model_out[ws_key] = attention_loss.clone()

        return discrimination_loss + attention_loss

    def _calculate_discrimination_loss(
            self,
            samples_labels: np.ndarray,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        """Mean of the cross entropies pushing the scores of on-samples to 1 and off-samples to 0, per channel."""
        losses = []

        for channel, labels in self._on_labels_by_channel.items():
            on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
            on_ps = discrimination_scores[on_mask, channel, ...]
            off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]

            if torch.numel(on_ps) > 0:
                losses.append(self._cal_loss(1, True, on_ps))
            if torch.numel(off_ps) > 0:
                losses.append(self._cal_loss(1, False, off_ps))

        return torch.mean(torch.stack(losses))

    def _calculate_attention_loss(
            self,
            samples_labels: np.ndarray,
            attention_scores: torch.Tensor,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        """Mean of the negative and positive attention terms collected over the channels.

        The discrimination scores are detached: they are passed on only for ordering the pixels.
        """
        losses = []

        for channel, labels in self._on_labels_by_channel.items():
            on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
            off_mask = torch.logical_not(on_mask)

            on_atts = attention_scores[on_mask, channel, ...]
            on_discr = discrimination_scores[on_mask, channel, ...].detach()
            off_atts = attention_scores[off_mask, channel, ...]
            off_discr = discrimination_scores[off_mask, channel, ...].detach()

            neg_losses = []
            pos_losses = []

            # loss injection to the model
            if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
                neg_losses.append(self._cal_loss(
                    self._neg_ratio, False, off_atts, off_discr, largest=True
                ))

            if torch.numel(on_atts) > 0:
                # Calculate positive top k to be positive
                if self._pos_ratio_range[0] > 0:
                    pos_losses.append(self._cal_loss(
                        self._pos_ratio_range[0], True, on_atts, on_discr, True
                    ))

                # Calculate positive bottom k to be negative
                if self._pos_ratio_range[1] < 1:
                    neg_losses.append(self._cal_loss(
                        1 - self._pos_ratio_range[1], False, on_atts, on_discr, False
                    ))

            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())

        return torch.stack(losses).mean()

    def _get_inclusion_mask(
            self,
            samples_labels: np.ndarray,
            desired_labels: List[int], device: torch.device) -> torch.Tensor:
        """Boolean mask (on ``device``) over the samples: True where the label is one of ``desired_labels``."""
        with torch.no_grad():
            samples_labels = torch.from_numpy(samples_labels).to(device)
            inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
            return inclusion_mask.any(dim=0)

    def _cal_loss(
            self,
            ratio: float, positive_label: bool,
            att_scores: torch.Tensor,
            discr_scores: Optional[torch.Tensor] = None,
            largest: bool = True):
        """Binary cross entropy of the selected attention scores against an all-ones
        (``positive_label``) or all-zeros target.

        With ``ratio == 1`` all scores are used; otherwise ``ceil(ratio * H * W)`` pixels per
        sample are chosen by ``_get_topk`` (H, W being the last two dimensions).
        """
        if ratio == 1:
            ps = att_scores
        else:
            k = int(np.ceil(
                ratio * att_scores.shape[-1] * att_scores.shape[-2]))
            ps = self._get_topk(att_scores, discr_scores, k, largest=largest)

        ps = ps.flatten()
        if positive_label:
            gt = torch.ones_like(ps)
        else:
            gt = torch.zeros_like(ps)

        return F.binary_cross_entropy(ps, gt)

    def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
                  k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor:
        """Returns, per sample, the attention scores of the k pixels ranked first (or last when
        ``largest`` is False) by the combined ordering score; optionally with their indices."""
        scores = self._pixels_scores(att_scores, discr_scores)

        b = att_scores.shape[0]
        top_inds = (scores.flatten(1)).topk(k, dim=dim, largest=largest, sorted=False).indices

        # B K
        ret_val = att_scores.flatten(1)[
            torch.repeat_interleave(
                torch.arange(b, device=att_scores.device), k).reshape(b, k),
            top_inds]  # B K

        if not return_indices:
            return ret_val
        else:
            return ret_val, top_inds

    def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: torch.Tensor) -> torch.Tensor:
        """Weighted sum of the attention and discrimination scores, used for ordering the pixels."""
        return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores
| """ | |||||
| The BatchChooser | |||||
| """ | |||||
| from abc import abstractmethod | |||||
| from typing import List, TYPE_CHECKING | |||||
| import numpy as np | |||||
| if TYPE_CHECKING: | |||||
| from ..data_loader import DataLoader | |||||
class BatchChooser:
    """Base class for strategies that decide which samples form each batch.

    Subclasses provide ``get_next_batch_sample_indices``; this class stores the
    chosen indices and tracks whether the current iteration is finished.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        """Keep the data loader and the desired batch size, and start with an empty batch."""
        self.data_loader = data_loader
        self.batch_size = batch_size
        self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int)
        self.completed_iteration = False
        # Mapping obtained from the data loader; its keys are used as the class labels.
        self.class_samples_indices = self.data_loader.get_class_sample_indices()
        self._classes_labels: List[int] = list(self.class_samples_indices.keys())

    def prepare_next_batch(self):
        """Ask the strategy for the next batch and store it.

        A finished iteration is reset first. An empty batch marks the
        iteration as completed.
        """
        if self.finished_iteration():
            self.reset()

        next_indices = self.get_next_batch_sample_indices()
        self.current_batch_sample_indices = next_indices.astype(int)

        if len(self.current_batch_sample_indices) == 0:
            self.completed_iteration = True

    @abstractmethod
    def get_next_batch_sample_indices(self) -> np.ndarray:
        """Return the indices of the samples for the next batch (strategy specific)."""
        pass

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """Return the indices stored by the last ``prepare_next_batch`` call."""
        return self.current_batch_sample_indices

    def finished_iteration(self):
        """Return True once an empty batch has been produced since the last reset."""
        return self.completed_iteration

    def reset(self):
        """Clear the completion flag and the stored batch."""
        self.completed_iteration = False
        self.current_batch_sample_indices = np.asarray([], dtype=int)

    def get_current_batch_size(self) -> int:
        """Return how many samples the stored batch contains."""
        return len(self.current_batch_sample_indices)
| from typing import TYPE_CHECKING | |||||
| from .batch_chooser import BatchChooser | |||||
| import numpy as np | |||||
| if TYPE_CHECKING: | |||||
| from ..data_loader import DataLoader | |||||
class ClassBalancedShuffledSequentialBatchChooser(BatchChooser):
    """Batch chooser that draws (nearly) the same number of samples from every class.

    Each class keeps its own shuffled list of sample indices and a cursor into
    it. Samples are consumed sequentially; whenever a class list is exhausted
    it is reshuffled and consumption restarts from its beginning, so a batch
    always receives the requested number of samples from every non-empty class.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        super().__init__(data_loader, batch_size)
        # Index of the class that receives the next "extra" sample when
        # batch_size is not divisible by the number of classes.
        self._cursor = 0
        # Work on private copies stored in a *new* dict, so that neither the
        # arrays nor the mapping handed out by the data loader get modified.
        self.class_samples_indices = {
            label: np.copy(indices)
            for label, indices in self.class_samples_indices.items()}
        for indices in self.class_samples_indices.values():
            np.random.shuffle(indices)
        self.classes_cursors = np.zeros(len(self.class_samples_indices), dtype=int)

    def get_next_batch_sample_indices(self):
        """ Returns an array containing the indices of the samples chosen for the next batch."""
        n_samples_per_class = self._get_n_samples_per_class()
        return np.concatenate(tuple(
            self._sample_from_class(class_index, n_samples_per_class[class_index])
            for class_index in range(len(self._classes_labels))
        ), axis=0)

    def _sample_from_class(self, class_index, n_requested):
        """Take the next ``n_requested`` indices of one class, reshuffling on every wrap-around."""
        class_samples = self.class_samples_indices[self._classes_labels[class_index]]
        n_available = len(class_samples)
        if n_available == 0:
            # Nothing to draw from; an empty slice keeps the dtype of the class array.
            return np.copy(class_samples)

        cursor = self.classes_cursors[class_index]
        chunks = []
        n_missing = n_requested
        # Loop instead of a single wrap-around so that a class smaller than its
        # quota is cycled as often as needed and the cursor always stays in range.
        while n_missing > 0:
            n_taken = min(n_missing, n_available - cursor)
            chunks.append(np.copy(class_samples[cursor:cursor + n_taken]))
            cursor += n_taken
            n_missing -= n_taken
            if cursor == n_available:
                np.random.shuffle(class_samples)
                cursor = 0
        self.classes_cursors[class_index] = cursor

        if not chunks:
            return np.copy(class_samples[:0])
        return np.concatenate(chunks, axis=0)

    def _get_n_samples_per_class(self):
        """Split ``batch_size`` over the classes.

        Every class gets ``batch_size // n_classes`` samples; the remainder is
        handed out one sample per class, starting at ``self._cursor`` and
        wrapping around, so the extra samples rotate over the classes from
        batch to batch.
        """
        n_classes = len(self._classes_labels)
        n_samples_per_class = np.full(
            len(self.class_samples_indices),
            int(self.batch_size // len(self.class_samples_indices)))
        n_extra = self.batch_size - np.sum(n_samples_per_class)
        if n_extra:
            if self._cursor + n_extra >= n_classes:
                # The extras reach the end of the class list: fill the tail and wrap.
                n_samples_per_class[self._cursor:n_classes] += 1
                n_extra -= (n_classes - self._cursor)
                self._cursor = 0
            if n_extra > 0:
                n_samples_per_class[self._cursor:self._cursor + n_extra] += 1
                self._cursor += n_extra
        return n_samples_per_class
| from typing import TYPE_CHECKING | |||||
| from .batch_chooser import BatchChooser | |||||
| import numpy as np | |||||
| if TYPE_CHECKING: | |||||
| from ..data_loader import DataLoader | |||||
class SequentialBatchChooser(BatchChooser):
    """Batch chooser that walks over the samples once, in index order."""

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        super().__init__(data_loader, batch_size)
        # Position of the first sample of the next batch.
        self.cursor = 0
        # Filled lazily on the first request after construction or reset.
        self.desired_samples_inds = None

    def reset(self):
        """Restart the walk from the first sample."""
        super().reset()
        self.cursor = 0
        self.desired_samples_inds = None

    def get_next_batch_sample_indices(self):
        """Return the next ``batch_size`` indices; fewer (possibly none) at the end."""
        if self.desired_samples_inds is None:
            n_samples = self.data_loader.get_number_of_samples()
            self.desired_samples_inds = np.arange(n_samples)

        start = self.cursor
        stop = min(len(self.desired_samples_inds), start + self.batch_size)
        self.cursor = stop
        return self.desired_samples_inds[start:stop]
| from enum import Enum | |||||
| import os | |||||
| import glob | |||||
| from typing import Callable, TYPE_CHECKING, Union | |||||
| import imageio | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import pandas as pd | |||||
| from .content_loader import ContentLoader | |||||
| if TYPE_CHECKING: | |||||
| from ...configs.celeba_configs import CelebAConfigs | |||||
class CelebATag(Enum):
    """Names of the per-image metadata columns; each value is the column name.

    NOTE(review): the upstream CelebA attribute list also contains
    ``Heavy_Makeup``, which has no member here -- confirm whether that
    omission is intentional.
    """

    # --- attribute tags ---
    FiveOClockShadowTag = '5_o_Clock_Shadow'
    ArchedEyebrowsTag = 'Arched_Eyebrows'
    AttractiveTag = 'Attractive'
    BagsUnderEyesTag = 'Bags_Under_Eyes'
    BaldTag = 'Bald'
    BangsTag = 'Bangs'
    BigLipsTag = 'Big_Lips'
    BigNoseTag = 'Big_Nose'
    BlackHairTag = 'Black_Hair'
    BlondHairTag = 'Blond_Hair'
    BlurryTag = 'Blurry'
    BrownHairTag = 'Brown_Hair'
    BushyEyebrowsTag = 'Bushy_Eyebrows'
    ChubbyTag = 'Chubby'
    DoubleChinTag = 'Double_Chin'
    EyeglassesTag = 'Eyeglasses'
    GoateeTag = 'Goatee'
    GrayHairTag = 'Gray_Hair'
    HighCheekbonesTag = 'High_Cheekbones'
    MaleTag = 'Male'
    MouthSlightlyOpenTag = 'Mouth_Slightly_Open'
    MustacheTag = 'Mustache'
    NarrowEyesTag = 'Narrow_Eyes'
    NoBeardTag = 'No_Beard'
    OvalFaceTag = 'Oval_Face'
    PaleSkinTag = 'Pale_Skin'
    PointyNoseTag = 'Pointy_Nose'
    RecedingHairlineTag = 'Receding_Hairline'
    RosyCheeksTag = 'Rosy_Cheeks'
    SideburnsTag = 'Sideburns'
    SmilingTag = 'Smiling'
    StraightHairTag = 'Straight_Hair'
    WavyHairTag = 'Wavy_Hair'
    WearingEarringsTag = 'Wearing_Earrings'
    WearingHatTag = 'Wearing_Hat'
    WearingLipstickTag = 'Wearing_Lipstick'
    WearingNecklaceTag = 'Wearing_Necklace'
    WearingNecktieTag = 'Wearing_Necktie'
    YoungTag = 'Young'

    # --- bounding box ---
    BoundingBoxX = 'x_1'
    BoundingBoxY = 'y_1'
    BoundingBoxW = 'width'
    BoundingBoxH = 'height'

    # --- partition column ---
    Partition = 'partition'

    # --- landmark coordinates ---
    LeftEyeX = 'lefteye_x'
    LeftEyeY = 'lefteye_y'
    RightEyeX = 'righteye_x'
    RightEyeY = 'righteye_y'
    NoseX = 'nose_x'
    NoseY = 'nose_y'
    LeftMouthX = 'leftmouth_x'
    LeftMouthY = 'leftmouth_y'
    RightMouthX = 'rightmouth_x'
    RightMouthY = 'rightmouth_y'
class SampleField(Enum):
    """Names of the metadata columns that identify a sample and its specification."""

    # Column holding the image identifier / file path.
    ImageId = 'image_id'
    # Column compared against the requested data specification.
    Specification = 'specification'
class CelebALoader(ContentLoader):
    """Content loader serving CelebA images and their per-image tag values.

    ``data_specification`` is either a name matched against the
    ``specification`` column of the tab-separated metadata file
    (``conf.dataset_metadata``) or, when it contains a ``/``, a directory whose
    ``*.jpg`` files are used directly without any tag metadata.
    """

    def __init__(self, conf: 'CelebAConfigs', data_specification: str):
        """Load the metadata of the samples selected by ``data_specification``."""
        super().__init__(conf, data_specification)
        self.conf = conf
        # True while samples come from the metadata file; ``_load_metadata``
        # switches it off when a plain image directory is given instead.
        self._datasep = True
        self._samples_metadata = self._load_metadata(data_specification)
        self._data_root = conf.data_root
        self._warning_count = 0

    def _load_metadata(self, data_specification: str) -> pd.DataFrame:
        """Build the samples table for a directory of jpgs or for a named specification."""
        if '/' in data_specification:
            self._datasep = False
            jpg_pattern = os.path.join(data_specification, '*.jpg')
            return pd.DataFrame({SampleField.ImageId.value: glob.glob(jpg_pattern)})

        spec_column = SampleField.Specification.value
        metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t')
        metadata = metadata[metadata[spec_column] == data_specification]
        metadata = metadata.drop(spec_column, axis=1)
        # Renumber the remaining rows from zero.
        return metadata.reset_index().drop('index', axis=1)

    def get_samples_names(self):
        """Return the image ids of all samples, in table order."""
        image_ids = self._samples_metadata[SampleField.ImageId.value]
        return image_ids.values

    def get_samples_labels(self):
        r""" Values of ``conf.main_tag`` per sample, or all ones when no main tag is configured """
        if self.conf.main_tag is None:
            return np.ones((len(self._samples_metadata),))
        return self._samples_metadata[self.conf.main_tag.value].values

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove the samples whose entry in ``drop_mask`` is True."""
        keep_mask = np.logical_not(drop_mask)
        self._samples_metadata = self._samples_metadata[keep_mask]

    def get_placeholder_name_to_fill_function_dict(self):
        """ Maps each supported placeholder name to the function that fills it.

        ``'x'`` yields the image batch; every ``CelebATag`` member contributes
        an entry, keyed by the member's name, that yields the values of that
        tag. Each function receives the indices of the batch samples."""
        fill_funcs = {'x': self._get_x}
        fill_funcs.update({
            tag.name: self._generate_tag_getter(tag) for tag in CelebATag
        })
        return fill_funcs

    def _get_x(self, samples_inds: np.ndarray)\
            -> np.ndarray:
        """Read the requested images and return them channel-first, scaled by 1/255."""
        batch = np.stack(
            [self._read_image(index) for index in samples_inds], axis=0)  # B I I 3
        batch = batch.transpose((0, 3, 1, 2))  # B 3 I I
        return batch.astype(float) / 255.

    def _read_image(self, sample_index: int) -> np.ndarray:
        """Read one image and resize it to ``conf.input_size`` on both sides."""
        sample_row = self._samples_metadata.iloc[sample_index]
        filepath = sample_row[SampleField.ImageId.value]
        if self._datasep:
            # Ids coming from the metadata file are joined onto the data root.
            filepath = os.path.join(self._data_root, filepath)
        raw_image = imageio.imread(filepath)  # H W 3
        side = self.conf.input_size
        return cv2.resize(raw_image, (side, side))  # I I 3

    def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]:
        """Create the fill function returning the values of ``tag`` for a batch."""

        def get_tag(samples_inds: np.ndarray) -> np.ndarray:
            if not self._datasep:
                # No metadata when samples were read from a plain directory.
                return np.zeros(len(samples_inds))
            return self._samples_metadata.iloc[samples_inds][tag.value].values

        return get_tag
| """ | |||||
| A class for loading one content type needed for the run. | |||||
| """ | |||||
| from abc import abstractmethod, ABC | |||||
| from typing import TYPE_CHECKING, Dict, Callable, Union, List | |||||
| import numpy as np | |||||
| if TYPE_CHECKING: | |||||
| from .. import BaseConfig | |||||
| LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray] | |||||
class ContentLoader(ABC):
    """ Base class for loading one content type needed for the run. """

    def __init__(self, conf: 'BaseConfig', data_specification: str):
        """ Stores the configuration object and the data specification.

        data_specification is the string that specifies the data and where it
        is, e.g. train, test, val."""
        self.conf = conf
        self.data_specification = data_specification
        # Cache of the placeholder-name -> fill-function mapping; built on the
        # first call to fill_placeholders.
        self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None

    @abstractmethod
    def get_samples_names(self):
        """ Returns the names of all the samples of the content loader.

        Every sample must own a unique ID and the order must stay the same
        during one run. For example, this can be the ID column of a table or
        the path of each image."""

    @abstractmethod
    def get_samples_labels(self):
        """ Returns the labels of all the samples.

        The order must stay the same during one run."""

    @abstractmethod
    def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]:
        """ Returns a dictionary from the names of the placeholders this content
        loader supports to the functions used for filling them.

        fill_placeholders calls each function with the indices of the samples
        of the current batch and expects an array back.
        IMPORTANT: Better to use a fixed prefix in the names of the placeholders
        so it is clear which content loader they belong to."""

    def fill_placeholders(self,
                          keys: List[str],
                          samples_inds: np.ndarray)\
            -> Dict[str, Union[None, np.ndarray]]:
        """ Fills the requested placeholders for the current batch.

        keys are the placeholder names to fill and samples_inds the indices of
        the batch samples. Each key is resolved through the dictionary returned
        by get_placeholder_name_to_fill_function_dict (fetched once and cached).
        Raises an Exception for an unknown key or when a fill function returns None."""
        fill_funcs = self._fill_func_by_var_name
        if fill_funcs is None:
            fill_funcs = self.get_placeholder_name_to_fill_function_dict()
            self._fill_func_by_var_name = fill_funcs

        filled = dict()
        for key in keys:
            if key not in fill_funcs:
                raise Exception(f'Unknown key for content loader: {key}')
            value = fill_funcs[key](samples_inds)
            if value is None:
                raise Exception('None value for key %s' % key)
            filled[key] = value
        return filled

    def get_batch_true_interpretations(self, samples_inds: np.ndarray,
                                       samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\
            -> np.ndarray:
        """
        Receives the indices of samples and their elements (if elemented) in the current batch,
        returns the true interpretations as expected (bounding box or a boolean mask).
        Not implemented in the base class.
        """
        raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.')

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """
        Receives a boolean drop mask and drops the samples whose mask is true.
        Afterwards sample indices refer to the remaining samples.
        Not implemented in the base class.
        """
        raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.')
| import os | |||||
| from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
| import pickle | |||||
| from PIL import Image | |||||
| from enum import Enum | |||||
| import pandas as pd | |||||
| import numpy as np | |||||
| import torch | |||||
| from torchvision import transforms | |||||
| from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS | |||||
| from .content_loader import ContentLoader | |||||
| from ...utils.bb_generator import generate_bb_map | |||||
| if TYPE_CHECKING: | |||||
| from ...configs.imagenet_configs import ImagenetConfigs | |||||
def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose:
    """Build the image transform pipeline for a data specification.

    ``train`` uses a random resized crop to ``input_size`` plus a random
    horizontal flip; ``test`` and ``val`` resize to 256 and center-crop to
    ``input_size``. Both end with ``ToTensor``. Any other specification
    raises ``ValueError``.
    """
    if data_specification == 'train':
        steps = [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
        ]
    elif data_specification in ('test', 'val'):
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(input_size),
        ]
    else:
        raise ValueError('Unknown data specification: {}'.format(data_specification))

    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]:
    """Return the (path, class index) samples of ``imagenet_dir``, cached on disk.

    The list is pickled to ``.cache/<cache_name>`` (relative to the current
    working directory) the first time it is built and read back from there on
    later calls.
    """
    cache_path = os.path.join('.cache/', cache_name)
    print(cache_path)

    if not os.path.isfile(cache_path):
        print('Creating cache for {}'.format(cache_name))
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        _, class_to_idx = find_classes(imagenet_dir)
        samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS)
        with open(cache_path, 'wb') as cache_file:
            pickle.dump(samples, cache_file)
        return samples

    print('Loading cached samples from {}'.format(cache_path))
    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # as long as the .cache directory is written by this function alone.
    with open(cache_path, 'rb') as cache_file:
        return pickle.load(cache_file)
class BBoxField(Enum):
    """Column names of the bounding-box solution CSV."""

    # Identifier of the image a row belongs to.
    ImageId = 'ImageId'
    # Whitespace-separated string holding the boxes of that image.
    PredictionString = 'PredictionString'
def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame:
    """Read ``LOC_<data_specification>_solution.csv`` from ``imagenet_root``."""
    solution_file = f'LOC_{data_specification}_solution.csv'
    solution_path = os.path.join(imagenet_root, solution_file)
    return pd.read_csv(solution_path)
class ImagenetLoader(ContentLoader):
    """Content loader for an ImageNet-style folder dataset with bounding boxes.

    Samples are the (path, class index) pairs found under
    ``<conf.data_separation>/<data_specification>``; bounding boxes come from
    the matching ``LOC_<data_specification>_solution.csv`` file.
    """

    def __init__(self, conf: 'ImagenetConfigs', data_specification: str):
        super().__init__(conf, data_specification)
        imagenet_dir = os.path.join(conf.data_separation, data_specification)
        self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir)
        # Single source for the side length, used by both the transform and
        # the NaN placeholder returned for images without boxes.
        self.__input_size = conf.input_size
        self.__transform = _get_image_transform(data_specification, self.__input_size)
        self.__bboxes = load_bbox_df(conf.data_separation, data_specification)
        # Per sample index: (placeholder kind that stored it, torch RNG state).
        self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {}

    def get_samples_names(self):
        ''' Returns the file path of every sample; paths serve as the unique sample names. '''
        return np.array([path for path, _ in self.__samples])

    def get_samples_labels(self):
        """Return the class index of every sample, in sample order."""
        return np.array([label for _, label in self.__samples])

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove the samples whose entry in ``drop_mask`` is true."""
        self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop]

    def get_placeholder_name_to_fill_function_dict(self):
        """ Maps each supported placeholder name to the function that fills it.

        ``x`` yields the transformed images, ``y`` the labels, and both
        ``bbox`` and ``interpretations`` the transformed bounding-box maps.
        Each function receives the indices of the batch samples."""
        return {
            'x': self.__get_x,
            'y': self.__get_y,
            'bbox': self.__get_bbox,
            'interpretations': self.__get_bbox,
        }

    def __get_x(self, samples_inds: np.ndarray)\
            -> torch.Tensor:
        """Load and transform the requested images, stacked along a new first dimension."""
        def get_image(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            sample = default_loader(path)
            self.__handle_random_state(index, 'x')
            return self.__transform(sample)

        return torch.stack(
            tuple([get_image(index) for index in samples_inds]),
            dim=0)

    def __get_y(self, samples_inds: np.ndarray)\
            -> np.ndarray:
        """Return the labels of the requested samples."""
        return self.get_samples_labels()[samples_inds]

    def __handle_random_state(self, idx: int, label: str) -> None:
        """Synchronise torch's global RNG between the placeholders of one sample.

        The first kind of placeholder requested for ``idx`` (or a repeated
        request of that same kind) records the current RNG state. When the
        other kind is requested for the same sample, the recorded state is
        restored -- presumably so the image and its box map go through the same
        random crop/flip. Note that restoring changes the global torch RNG.
        """
        if idx not in self.__rng_states or self.__rng_states[idx][0] == label:
            self.__rng_states[idx] = (label, torch.get_rng_state())
        else:
            torch.set_rng_state(self.__rng_states[idx][1])

    def __get_bbox(self, sample_inds: np.ndarray)\
            -> torch.Tensor:
        """Return the transformed bounding-box maps of the requested samples.

        Samples without an entry in the boxes file get an all-NaN map of shape
        (1, input_size, input_size).
        """
        def extract_bb(prediction_string: str) -> np.ndarray:
            # The string is a sequence of 5-token groups; the first token of
            # each group is skipped and the other four form a 2 x 2 array.
            splitted = prediction_string.split()
            n = len(splitted) // 5
            return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\
                .reshape((-1, 2, 2))

        def make_bb(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            image_id = os.path.basename(path).split('.')[0]
            # Only the size is needed; close the file handle right away.
            with Image.open(path) as image:
                image_size = np.array(image.size)[::-1]  # 2, reversed PIL size
            bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value]
            if len(bboxes) == 0:
                return torch.full(
                    (1, self.__input_size, self.__input_size), float('nan'))
            bboxes = bboxes.values[0]
            # Reverse each point to match image_size, then scale by the image size.
            bboxes = extract_bb(bboxes)[..., ::-1] / image_size  # N 2 2
            start_points = bboxes[:, 0]  # N 2
            end_points = bboxes[:, 1]  # N 2
            bb_map = generate_bb_map(start_points, end_points, tuple(image_size))
            bb_map = Image.fromarray(bb_map)
            self.__handle_random_state(index, 'bb')
            return self.__transform(bb_map)

        return torch.stack([make_bb(i) for i in sample_inds], dim=0)
| from typing import Dict, List, TYPE_CHECKING, Optional, Tuple | |||||
| from os import path, listdir | |||||
| from sys import stderr | |||||
| from functools import partial | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import pandas as pd | |||||
| from .content_loader import ContentLoader | |||||
| from ...utils.bb_generator import generate_bb_map | |||||
| if TYPE_CHECKING: | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| class RSNALoader(ContentLoader): | |||||
def __init__(self, conf: 'RSNAConfigs', data_specification: str):
    """ Loads the samples table selected by data_specification and, when
    conf.infections_bb_dir is set, the bounding boxes of the infections. """
    super().__init__(conf, data_specification)
    self.conf = conf
    self.img_size = conf.input_size
    self.infection_map_size = conf.infection_map_size
    self.samples_info = None
    self.load_samples(data_specification)
    self.loaded_images = None
    # Image id -> (list of box start points, list of box end points);
    # stays None unless _load_infections reads a bounding-box file.
    self._infections_bbs: Optional[
        Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]] = None
    self._load_infections()
def load_samples(self, data_specification):
    """ Loads the samples table into self.samples_info.

    data_specification may end with ':<filter>'; then only samples whose path
    contains '/<filter>/' are kept. The part before the colon is one of:

    - 'train', 'val' or 'test': reads '<conf.data_separation>/<name>.txt';
    - the path of a tab-separated samples file;
    - a directory: every entry of it becomes a sample with label -1 and the
      label mapping below is skipped.

    For the two file-based cases the first three characters (the '../'
    prefix) are cut from the 'sample' and 'Interpretation_dir' columns,
    labels are mapped case-insensitively through conf.label_map_dict, and
    samples whose label has no mapping are dropped.

    Raises ValueError when data_specification is none of the above. """
    if ':' in data_specification:
        data_specification, filter_str = data_specification.split(':')
    else:
        filter_str = None

    def passes_filter(samples_names):
        # A comprehension instead of np.vectorize, which cannot be called on
        # empty inputs without explicit otypes.
        needle = '/%s/' % filter_str
        return np.array([needle in name for name in samples_names], dtype=bool)

    if data_specification in ['train', 'val', 'test']:
        samples_file = '%s/%s.txt' % (self.conf.data_separation, data_specification)
    elif path.isfile(data_specification):
        samples_file = data_specification
    elif path.isdir(data_specification):
        samples_names = np.asarray(
            [data_specification + '/' + x for x in listdir(data_specification)])
        print('%d samples discovered in %s' % (len(samples_names), data_specification))
        if filter_str is not None:
            samples_names = samples_names[passes_filter(samples_names)]
            print('%d samples remained after filtering' % len(samples_names))
        # ``object`` replaces the ``np.object`` alias removed in NumPy 1.24.
        self.samples_info = pd.DataFrame({
            'label': np.full(len(samples_names), -1, dtype=int),
            'sample': samples_names,
            'view': np.full(len(samples_names), 'Unknown', dtype=object),
            'dataset': np.full(len(samples_names), 'unknown', dtype=object),
            'Interpretation_dir': np.full(len(samples_names), '', dtype=object),
        })
        return
    else:
        # Fail here with a clear message rather than later on ``len(None)``.
        raise ValueError(
            'Unsupported data specification: %s '
            '(expected train/val/test, a samples file or a directory)'
            % data_specification)

    self.samples_info = pd.read_csv(samples_file, sep='\t', header=0)
    for column in ('sample', 'Interpretation_dir'):
        # NOTE(review): the prefix is cut unconditionally -- assumes every
        # entry starts with '../'; confirm against the data separation files.
        self.samples_info[column] = self.samples_info[column].apply(
            lambda s: s[len('../'):])

    org_cnt = len(self.samples_info)

    # filtering by filter string!
    if filter_str is not None:
        self.samples_info = self.samples_info[
            passes_filter(self.samples_info['sample'].values)]
        print(
            '%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification),
            flush=True)
        org_cnt = len(self.samples_info)

    # label mapping/ordering!
    label_mapping_dict = self.conf.label_map_dict
    self.samples_info['label'] = [
        label_mapping_dict.get(raw_label.lower(), -1)
        for raw_label in self.samples_info['label'].values]

    # filtering unmapped labels
    self.samples_info = self.samples_info[self.samples_info['label'].values != -1]
    print('%d out of %d samples remained after label mapping!' %
          (len(self.samples_info), org_cnt), flush=True)

    # counting the labels!
    u_labels, labels_cnt = np.unique(
        self.samples_info['label'].values, return_counts=True)
    print('[%s]' % ', '.join(
        'class-%s: %d' % (str(u_label), label_cnt)
        for u_label, label_cnt in zip(u_labels, labels_cnt)),
        flush=True)
| def _load_infections(self) -> None: | |||||
| """ | |||||
| If a file address is given as infections_bb_dir in configs | |||||
| (containing imgs and their bounding boxes as a str) | |||||
| the information related to bounding boxes will be taken, | |||||
| otherwise it is expected from the data separation to have a column | |||||
| indicating the address of the infection mask | |||||
| """ | |||||
| if self.conf.infections_bb_dir is not None: | |||||
| bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0) | |||||
| imgs_dirs = np.vectorize(self._get_id_by_path)(bbs_info['img_dir'].values) | |||||
| imgs_bbs = [ | |||||
| [] if '-' not in str(im_bbs) else [ | |||||
| tuple([np.clip(float(x), 0, 1) for x in im_bb.split('-')]) | |||||
| for im_bb in im_bbs.split(',') | |||||
| ] | |||||
| for im_bbs in bbs_info['img_bb'].values] | |||||
| # reformatting to tuple of list of starts and list of ends | |||||
| imgs_bbs_starts = [[ | |||||
| (r1, c1) for r1, c1, _, _ in img_bbs] | |||||
| for img_bbs in imgs_bbs] | |||||
| imgs_bbs_ends = [[ | |||||
| (r2, c2) for _, _, r2, c2 in img_bbs] | |||||
| for img_bbs in imgs_bbs] | |||||
| imgs_bbs = [ | |||||
| (imgs_bbs_starts[i], imgs_bbs_ends[i]) | |||||
| for i in range(len(imgs_bbs_starts))] | |||||
| self._infections_bbs = dict(zip(imgs_dirs, imgs_bbs)) | |||||
def get_samples_names(self):
    """ Returns the values of the 'sample' column of the samples table, in table order. """
    sample_column = self.samples_info['sample']
    return sample_column.values
def get_samples_labels(self):
    """ Returns the values of the 'label' column of the samples table, in table order. """
    label_column = self.samples_info['label']
    return label_column.values
def drop_samples(self, drop_mask: np.ndarray) -> None:
    """ Removes the samples whose entry in drop_mask is True.

    drop_mask is a boolean array with one entry per row of the samples table;
    rows marked True are dropped, all others are kept. """
    keep_mask = np.logical_not(drop_mask)
    self.samples_info = self.samples_info[keep_mask]
def get_placeholder_name_to_fill_function_dict(self):
    """ Maps each supported placeholder name to the function that fills it.

    'infection', 'interpretations' and 'has_bb' are served from the bounding-box
    file when one was loaded, otherwise from the mask files. One extra
    'ex_infection_<r>' entry is added per radius in conf.receptive_field_radii,
    which requires the bounding-box file. """
    has_bbs_file = self._infections_bbs is not None

    if has_bbs_file:
        infection_getter = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
        interpretations_getter = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
        has_bb_getter = self.get_batch_has_bb_mask_from_bb_file
    else:
        infection_getter = self.get_batch_bbs_map_from_mask_file
        interpretations_getter = self.get_batch_bbs_map_from_mask_file
        has_bb_getter = self.get_batch_has_bb_mask_from_mask_file

    fill_funcs = {
        'x': self.get_batch_scaled_images,
        'y': self.get_batch_label,
        'infection': infection_getter,
        'interpretations': interpretations_getter,
        'has_bb': has_bb_getter,
    }

    # a non-empty list of radii needs the boxes dictionary to extend from
    radii = self.conf.receptive_field_radii
    if len(radii) > 0:
        assert has_bbs_file, \
            "When having receptive field radii, " \
            "you should provide bounding boxes as a file in config.infections_bb_dir"
        for r in radii:
            fill_funcs[f'ex_infection_{r}'] = partial(
                self.get_batch_extended_bbs_map_from_bbs_file, r)

    return fill_funcs
def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \
        -> np.ndarray:
    """Return a boolean mask telling which batch samples have a usable map.

    A sample is reported as True when its label is 0, or when its
    'Interpretation_dir' entry is not NaN.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        Boolean array with the same shape as ``samples_inds``.

    Raises:
        AssertionError: If ``self.samples_info`` has no 'Interpretation_dir'
            column.
    """
    assert 'Interpretation_dir' in self.samples_info.columns, \
        'If you have not specified infections_bb_dir in your config, ' \
        'you need to have Interpretation_dir column in your data separation'

    def has_inf(si):
        row = self.samples_info.iloc[si]
        if row['label'] == 0:
            return True
        # a missing path is read from the table as NaN, which stringifies to 'nan'
        return str(row['Interpretation_dir']) != 'nan'

    # otypes is given explicitly so that an empty batch does not raise
    return np.vectorize(has_inf, otypes=[bool])(samples_inds)
def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \
        -> np.ndarray:
    """Return a boolean mask telling which batch samples have bounding boxes.

    A sample is reported as True when its label is 0, or when its id (derived
    from the 'sample' path via ``self._get_id_by_path``) is a key of
    ``self._infections_bbs`` with a non-empty first element.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        Boolean array with the same shape as ``samples_inds``.
    """
    def has_inf(si):
        row = self.samples_info.iloc[si]
        if row['label'] == 0:
            return True
        sample_key = self._get_id_by_path(row['sample'])
        return sample_key in self._infections_bbs and \
            len(self._infections_bbs[sample_key][0]) > 0

    # otypes is given explicitly so that an empty batch does not raise
    return np.vectorize(has_inf, otypes=[bool])(samples_inds)
def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray)\
        -> np.ndarray:
    """Load the infection maps of the given samples from their mask files.

    For each sample the result holds one map of size
    ``infection_map_size x infection_map_size``:
    all zeros when the label is 0, all -1 when 'Interpretation_dir' is NaN,
    otherwise the array stored in the file, thresholded and resized when its
    shape differs from the target size.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        Float array of shape (len(samples_inds), 1, size, size).
    """
    map_shape = (self.infection_map_size, self.infection_map_size)

    def read_interpretation(im_ind):
        row = self.samples_info.iloc[im_ind]

        # for healthy, return full 0
        if row['label'] == 0:
            return np.full((1,) + map_shape, 0, dtype=float)

        int_dir = row['Interpretation_dir']

        # if it does not exist, return full -1
        if str(int_dir) == 'nan':
            return np.full((1,) + map_shape, -1, dtype=float)

        if 'npy' in int_dir:
            interpretation = np.load(int_dir)
        else:
            # other files are read as an archive holding the map under 'arr_0';
            # close the archive as soon as the array has been extracted
            with np.load(int_dir) as archive:
                interpretation = archive['arr_0']

        interpretation = interpretation.astype(float)

        if interpretation.shape != map_shape:
            # scale to 0..255, resize, then binarize again at the midpoint
            interpretation = (cv2.resize(np.round(255 * interpretation, 0),
                                         dsize=map_shape) >= 128).astype(float)

        return interpretation[np.newaxis, :, :]

    batch_interpretations = np.stack(
        [read_interpretation(si) for si in samples_inds], axis=0)
    return batch_interpretations
@staticmethod
def _get_id_by_path(path: str) -> str:
    """Return the part of ``path`` that starts at the first 'png_images/'.

    Raises:
        ValueError: If 'png_images/' does not occur in ``path``.
    """
    marker_position = path.index('png_images/')
    return path[marker_position:]
def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray:
    """Build the bounding-box maps of the given samples from the loaded boxes.

    For each sample the result holds one map of size
    ``infection_map_size x infection_map_size``:
    all zeros when the label is 0, the output of ``generate_bb_map`` when the
    sample has at least one box in ``self._infections_bbs``, and all -1
    otherwise.

    Args:
        radius: Passed through unchanged to ``generate_bb_map``.
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        Array of shape (len(samples_inds), 1, size, size).
    """
    map_shape = (self.infection_map_size, self.infection_map_size)

    def make_map(im_ind):
        row = self.samples_info.iloc[im_ind]

        if row['label'] == 0:
            return np.full((1,) + map_shape, 0, dtype=float)

        sample_key = self._get_id_by_path(row['sample'])

        if sample_key in self._infections_bbs and \
                len(self._infections_bbs[sample_key][0]) > 0:
            start_points, end_points = self._infections_bbs[sample_key]
            mask = generate_bb_map(start_points, end_points, map_shape, radius)
            return mask[np.newaxis, :, :]

        # no box information for this sample
        return np.full((1,) + map_shape, -1, dtype=float)

    batch_interpretations = np.stack(
        [make_map(si) for si in samples_inds], axis=0)
    return batch_interpretations
def get_batch_scaled_images(self, samples_inds: np.ndarray) -> np.ndarray:
    """Read the images of the given samples and scale them by 1/255.

    Each image is read with ``cv2.imread`` from the path in the 'sample'
    column. For a 3-dimensional image only channel 0 is kept. Images are
    resized to ``img_size x img_size`` when needed.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        float32 array of shape (len(samples_inds), 1, img_size, img_size).

    Raises:
        Exception: If an image cannot be read.
    """
    def read_img(im_ind):
        im = cv2.imread(self.samples_info.iloc[im_ind]['sample'])

        if im is None:
            print(self.samples_info.iloc[im_ind]['sample'] + ' is missing!', file=stderr)
            raise Exception('Missing image')

        single_channel = im[:, :, 0] if len(im.shape) == 3 else im

        target_shape = (self.img_size, self.img_size)
        if single_channel.shape != target_shape:
            single_channel = cv2.resize(single_channel, dsize=target_shape)

        return single_channel[np.newaxis, :, :]

    loaded = [read_img(si) for si in samples_inds]
    batch_imgs = np.stack(tuple(loaded), axis=0)
    return batch_imgs.astype(np.float32) / 255
def get_batch_label(self, samples_inds: np.ndarray) -> np.ndarray:
    """Return the labels of the given samples as a float32 array.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.
    """
    label_column = self.samples_info['label']
    batch_labels = label_column.iloc[samples_inds].values
    return batch_labels.astype(np.float32)
def get_batch_true_interpretations(self, samples_inds: np.ndarray) \
        -> np.ndarray:
    """Load the ground-truth maps of the given samples at image resolution.

    For each sample the result holds one map of size ``img_size x img_size``:
    all zeros when the label is 0, all NaN when 'Interpretation_dir' is NaN,
    otherwise the array stored in the file, thresholded and resized when its
    shape differs from the target size.

    Args:
        samples_inds: Positional row indices into ``self.samples_info``.

    Returns:
        Array of shape (len(samples_inds), 1, img_size, img_size).
    """
    img_shape = (self.img_size, self.img_size)

    def read_interpretation(im_ind):
        row = self.samples_info.iloc[im_ind]

        # for healthy, return full 0
        if row['label'] == 0:
            return np.full((1,) + img_shape, 0, dtype=float)

        int_dir = row['Interpretation_dir']

        # unknown ground truth is marked with NaN
        if str(int_dir) == 'nan':
            return np.full((1,) + img_shape, np.nan, dtype=float)

        if 'npy' in int_dir:
            interpretation = np.load(int_dir)
        else:
            # other files are read as an archive holding the map under 'arr_0';
            # close the archive as soon as the array has been extracted
            with np.load(int_dir) as archive:
                interpretation = archive['arr_0'].astype(int)

        if interpretation.shape != img_shape:
            # scale to 0..255, resize, then binarize again at the midpoint
            interpretation = (cv2.resize(np.round(255 * interpretation, 0).astype(np.uint8),
                                         dsize=img_shape) >= 128).astype(float)

        return interpretation[np.newaxis, :, :]

    batch_interpretations = np.stack(
        [read_interpretation(si) for si in samples_inds], axis=0)
    return batch_interpretations
| """ A class for loading the whole data needed for the run! """ | |||||
| from typing import Dict, List, Union, TYPE_CHECKING | |||||
| import random | |||||
| import numpy as np | |||||
| import torch | |||||
| from .batch_choosing.batch_chooser import BatchChooser | |||||
| from .content_loaders.content_loader import ContentLoader | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class RunType:
    """String constants naming the mode a data loader is used in."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class DataLoader():
    """ A class for loading the whole data needed for the run! """

    def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: str):
        """Set up the content loader, the per-class indices and the batch chooser.

        Args:
            conf: Configuration object. This constructor reads ``device``,
                ``content_loader_cls``, ``mapped_labels_to_use``,
                ``different_augmentation_per_batch``, ``augmentations_dict``,
                ``batch_size`` and the two batch-chooser classes from it.
            data_specification: String specifying the samples; it is passed
                through to the content loader.
            run_type: One of the ``RunType`` strings (train/val/test).
                Augmentations and the train batch chooser are only used for
                ``RunType.TRAIN``.
        """
        self.conf = conf
        self._device = conf.device
        self.data_specification = data_specification
        self.run_type = run_type

        self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification)
        self.samples_names: np.ndarray = self._content_loader.get_samples_names()
        self.samples_labels = self._content_loader.get_samples_labels()

        # restrict everything to the labels of interest, if any were configured
        keep_mask = self.get_samples_keep_mask()
        if keep_mask is not None:
            drop_mask = np.logical_not(keep_mask)
            self._content_loader.drop_samples(drop_mask)
            self.samples_names = self.samples_names[keep_mask]
            self.samples_labels = self.samples_labels[keep_mask]

        print('%d samples' % len(self.samples_names), flush=True)

        # label -> array of positions of the samples carrying that label
        indices_by_class = dict()
        for i in range(len(self.samples_names)):
            indices_by_class.setdefault(self.samples_labels[i], []).append(i)
        self.class_samples_indices = {
            label: np.asarray(indices) for label, indices in indices_by_class.items()}

        # for augmentation
        # Only do augmentations in the training phase
        self._different_augmentation_per_batch = conf.different_augmentation_per_batch
        self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None
        if run_type == RunType.TRAIN:
            self._augmentations_dict = conf.augmentations_dict

        # for preventing reprocessing
        self._processed_data_dictionary: Dict[str, torch.Tensor] = dict()

        # for preserving same augmentations in one batch (keyed by augmentation type)
        self._transformations_seeds_dict: Dict[type, Union[int, np.ndarray]] = dict()

        if run_type == RunType.TRAIN:
            self._batch_chooser: BatchChooser = self.conf.train_batch_chooser_cls(self, conf.batch_size)
        else:
            self._batch_chooser: BatchChooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size)

    def get_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples based on the first content loader.
        IMPORTANT: These all must be the same for all of the content loaders."""
        return self.samples_names

    def get_samples_labels(self):
        """ Returns the labels of the samples in a numpy array. """
        return self.samples_labels

    def get_number_of_samples(self):
        """ Returns the number of samples loaded. """
        return len(self.samples_names)

    def get_class_sample_indices(self):
        """ Returns a dictionary, containing arrays of samples indices belonging to each class label."""
        return self.class_samples_indices

    def prepare_next_batch(self):
        """Forget the cached data and seeds of the last batch and advance the batch chooser."""
        # Resetting information
        self._processed_data_dictionary = dict()
        self._transformations_seeds_dict = dict()
        self._batch_chooser.prepare_next_batch()

    def finished_iteration(self) -> bool:
        """Return what the batch chooser reports for ``finished_iteration``."""
        return self._batch_chooser.finished_iteration()

    def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None:
        """ Receives as input a dictionary of placeholders and fills them for the current batch.

        Entries already produced for this batch are replaced by the cached
        tensors. The remaining ones are requested from the content loader,
        copied into their placeholders, augmented when an augmentation is
        registered for them, and then cached."""
        missed_placeholders_names: List[str] = []
        for k in list(placeholders_dict):
            if k in self._processed_data_dictionary:
                placeholders_dict[k] = self._processed_data_dictionary[k]
            else:
                missed_placeholders_names.append(k)

        if not missed_placeholders_names:
            return

        new_batch_info = self._content_loader.fill_placeholders(
            missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices())

        # filling all missed keys!
        with torch.no_grad():
            for k in missed_placeholders_names:
                self._fill_placeholder(placeholders_dict[k], new_batch_info[k])

                if self._augmentations_dict is not None and k in self._augmentations_dict:
                    placeholders_dict[k] = self._apply_augmentation(k, placeholders_dict[k])

                # Keeping a reference to the data for later requests in this batch
                self._processed_data_dictionary[k] = placeholders_dict[k]

    def get_current_batch_data(self, keyword: str) -> torch.Tensor:
        """ Builds placeholder and retrieves the requirement from loaders and returns it!

        Args:
            keyword (str): The keyword to load for the current batch.

        Returns:
            torch.Tensor: The information related to the keyword for the current batch.
        """
        if keyword in self._processed_data_dictionary:
            return self._processed_data_dictionary[keyword]

        placeholders_dict = {keyword: create_empty_placeholder(self._device)}
        self.fill_placeholders(placeholders_dict)
        return placeholders_dict[keyword]

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Returns the indices of the samples chosen for the current batch. """
        return self._batch_chooser.get_current_batch_sample_indices()

    def get_max_class_samples_num(self):
        """Return the size of the largest class."""
        return max([len(x) for x in self.class_samples_indices.values()])

    def get_classes_num(self):
        """Return the number of distinct labels among the loaded samples."""
        return len(self.class_samples_indices)

    def get_samples_keep_mask(self) -> Union[None, np.ndarray]:
        """Return a boolean mask of the samples whose label should be kept.

        Returns None when ``conf.mapped_labels_to_use`` is None; otherwise an
        array that is True where the sample's label is one of those labels.
        """
        labels_to_use = self.conf.mapped_labels_to_use
        if labels_to_use is None:
            return None

        allowed_labels = set(labels_to_use)
        # otypes is given explicitly so that an empty label array does not raise
        keep_mask = np.vectorize(
            lambda x: x in allowed_labels, otypes=[bool])(self.samples_labels)
        print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.')
        return keep_mask

    def get_current_batch_size(self) -> int:
        """Return the batch size reported by the batch chooser."""
        return self._batch_chooser.get_current_batch_size()

    def get_current_batch_samples_names(self) -> np.ndarray:
        """Return the names of the samples in the current batch."""
        return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_labels(self) -> np.ndarray:
        """Return the labels of the samples in the current batch."""
        return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_interpretations(self) -> np.ndarray:
        """Return the 'interpretations' data of the current batch."""
        return self.get_current_batch_data('interpretations')

    @staticmethod
    def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray):
        """ Fills the torch placeholder with the given numpy value,
        If shape mismatches, resizes the placeholder so the data would fit. """
        # Resize if the shape mismatches
        if list(placeholder.shape) != list(val.shape):
            placeholder.resize_(*tuple(val.shape))

        # feeding the value
        placeholder.copy_(torch.Tensor(val))

    def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor:
        """ This function would be called for the variables we are sure they need augmentation
        (they are presented in augmentation dictionary), applies the specified augmentation,
        makes sure all augmentations be the same on the same batch elements,
        and returns the augmented data"""

        def run_single_aug(aug, inp):
            aug_type = type(aug)
            if aug_type in self._transformations_seeds_dict:
                seed = self._transformations_seeds_dict[aug_type]
            else:
                # first use of this augmentation type in the batch: draw and remember a seed
                seed = np.random.randint(2147483647)  # make a seed with numpy generator
                self._transformations_seeds_dict[aug_type] = seed

            # reseed all generators so the same type reproduces the same randomness
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

            if self._different_augmentation_per_batch:
                # apply the augmentation to every batch element separately
                return torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0)
            return aug(inp)

        field_transformation = self._augmentations_dict[var_name]

        if isinstance(field_transformation, torch.nn.Sequential):
            # run the children one by one so each gets its own per-type seed
            aug_val = var_val
            for single_transform in field_transformation.children():
                aug_val = run_single_aug(single_transform, aug_val)
            return aug_val

        return run_single_aug(field_transformation, var_val)
def create_empty_placeholder(device: torch.device) -> torch.Tensor:
    """ Create a zero-filled float32 placeholder with shape (1) on the given device.

    Args:
        device (torch.device): Device to create the placeholder on.

    Returns:
        torch.Tensor: The placeholder; it does not require gradients.
    """
    placeholder = torch.zeros(
        1, dtype=torch.float32, device=device, requires_grad=False)
    return placeholder
| """ | |||||
| Runs flow of data, from loading to forwarding the model. | |||||
| """ | |||||
| from typing import Iterator, Union, TypeVar, Generic, Dict | |||||
| import torch | |||||
| from ..data.data_loader import DataLoader, create_empty_placeholder | |||||
| from ..data.dataloader_context import DataloaderContext | |||||
| from ..models.model import Model | |||||
TReturn = TypeVar('TReturn')


class DataFlow(Generic[TReturn]):
    """
    Runs flow of data, from loading to forwarding the model.

    Must be used as a context manager: entering builds the placeholders,
    leaving releases them.
    """

    def __init__(self, model: 'Model', dataloader: 'DataLoader',
                 device: torch.device, print_debug_info: bool = False):
        """
        Args:
            model: The model called with the filled placeholders.
            dataloader: The data loader that chooses batches and fills placeholders.
            device: Device on which the placeholders are created.
            print_debug_info: When True, the sample names of every batch are printed.
        """
        self._model = model
        self._device = device
        self._dataloader = dataloader
        self._placeholders: Union[None, Dict[str, torch.Tensor]] = None
        self._print_debug_info = print_debug_info

    def iterate(self) -> Iterator[TReturn]:
        """
        Iterates on data and forwards the model

        Yields the result of the final stage for each batch until the
        dataloader reports that the iteration is finished.

        Raises:
            Exception: If called outside of a `with` block on this object.
        """
        if self._placeholders is None:
            raise Exception("Please use `with` statement before calling iterate method:\n" +
                            "with dataflow:\n" +
                            "    for model_output in dataflow.iterate():\n" +
                            "        pass\n")

        # make this dataloader reachable through the shared context object
        DataloaderContext.instance.dataloader = self._dataloader

        while True:
            self._dataloader.prepare_next_batch()

            if self._print_debug_info:
                print('> Next Batch:')
                print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names()))

            # check if the iteration is done
            if self._dataloader.finished_iteration():
                break

            # Filling the placeholders for running the model
            self._dataloader.fill_placeholders(self._placeholders)

            # Running the model
            yield self._final_stage(self._placeholders)

    def __enter__(self):
        """Build the placeholders and return this object for use in `with ... as`."""
        self._placeholders = self._build_placeholders()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop all references to the placeholders; exceptions are not suppressed."""
        for name in self._placeholders:
            self._placeholders[name] = None
        self._placeholders = None

    def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn:
        """Call the model with the placeholders passed as keyword arguments."""
        return self._model(**model_input)

    def _build_placeholders(self) -> Dict[str, torch.Tensor]:
        """ Creates placeholders for feeding the model.

        Returns a dictionary with one empty placeholder for every name the
        model reports through ``get_forward_required_kws_kwargs_and_defaults``
        (both the forward arguments and the additional keyword arguments).
        The reported default values are not used here."""
        model_forward_args, _, additional_kwargs = \
            self._model.get_forward_required_kws_kwargs_and_defaults()

        placeholders = dict()
        for argname in model_forward_args:
            placeholders[argname] = create_empty_placeholder(self._device)
        for argname in additional_kwargs:
            placeholders[argname] = create_empty_placeholder(self._device)

        return placeholders
| from typing import TYPE_CHECKING | |||||
| if TYPE_CHECKING: | |||||
| from .data_loader import DataLoader | |||||
class classproperty(property):
    """A ``property`` whose accessor functions receive the class, not the instance."""

    def __get__(self, obj, objtype=None):
        # hand the owner class to the getter in place of the instance
        return super().__get__(objtype)

    def __set__(self, obj, value):
        owner = type(obj)
        super().__set__(owner, value)

    def __delete__(self, obj):
        owner = type(obj)
        super().__delete__(owner)
class DataloaderContext:
    """Holder of a ``dataloader`` attribute, reachable through a lazily created shared instance."""

    # the shared instance; created on first access of ``instance``
    _instance: 'DataloaderContext' = None

    def __init__(self) -> None:
        self.dataloader: 'DataLoader' = None

    @classproperty
    def instance(cls) -> 'DataloaderContext':
        """Return the shared instance, creating it on first use."""
        shared = cls._instance
        if shared is None:
            shared = DataloaderContext()
            cls._instance = shared
        return shared
| from typing import Type | |||||
| from .entrypoint import BaseEntrypoint | |||||
| from ..configs.base_config import PhaseType | |||||
| from importlib import import_module | |||||
| import sys | |||||
def get_entry(phase_type: PhaseType):
    """Instantiate the ``EntryPoint`` of the experiment named on the command line.

    The experiment name is read from ``sys.argv[1]`` and the module
    ``torchlap.experiments.<name>`` is imported to look up its ``EntryPoint``.

    Args:
        phase_type: Passed to the ``EntryPoint`` constructor.
    """
    experiment_name = sys.argv[1]
    experiment_module = import_module(f'torchlap.experiments.{experiment_name}')
    entry_cls: Type[BaseEntrypoint] = experiment_module.EntryPoint
    return entry_cls(phase_type)
| from os import path | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from ..models.model import Model | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
def load_model(the_model: Model, conf: 'BaseConfig') -> Model:
    """Load saved weights into ``the_model`` and move it to ``conf.device``.

    The location is ``conf.final_model_dir`` when it is set, otherwise
    ``conf.save_dir``. If the location is a file, it is loaded directly as a
    state dict. If it is a directory, the checkpoint named after the epoch is
    loaded; the epoch is ``conf.epoch`` when given, otherwise the
    'best_val_epoch' entry of the 'GeneralInfo' file inside that directory.

    NOTE: ``torch.load`` unpickles the file, so only trusted checkpoints
    should be loaded.

    Args:
        the_model: Model that receives the weights.
        conf: Configuration; ``dev_name``, ``device``, ``final_model_dir``,
            ``save_dir`` and ``epoch`` are read.

    Returns:
        The same model object, with the loaded weights.

    Raises:
        Exception: If the location or the epoch checkpoint does not exist, or
            if no epoch is configured and 'GeneralInfo' is missing.
    """
    dev_name = conf.dev_name

    if conf.final_model_dir is not None:
        load_dir = conf.final_model_dir
        if not path.exists(load_dir):
            raise Exception('No path %s' % load_dir)
    else:
        load_dir = conf.save_dir

    print('>>> loading from: ' + load_dir, flush=True)
    if not path.exists(load_dir):
        raise Exception('Problem in loading the model. %s does not exist!' % load_dir)

    # remap stored tensors only for cpu or for an explicitly indexed device
    map_location = None
    if dev_name == 'cpu':
        map_location = 'cpu'
    elif dev_name is not None and ':' in dev_name:
        map_location = dev_name

    the_model.to(conf.device)

    if path.isfile(load_dir):
        # a single file is taken to be the state dict itself
        the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
        print('Loaded the model at %s' % load_dir, flush=True)
    else:
        general_info_path = load_dir + '/GeneralInfo'

        if conf.epoch is not None:
            epoch = int(conf.epoch)
        elif path.exists(general_info_path):
            checkpoint = torch.load(general_info_path, map_location=map_location)
            epoch = checkpoint['best_val_epoch']
            print(f'Epoch has not been specified in the config, '
                  f'using {epoch} which is best_val_epoch instead')
        else:
            raise Exception('Either epoch or pt dir (final_model) must be given as GeneralInfo was not found')

        epoch_path = '%s/%d' % (load_dir, epoch)
        if not path.exists(epoch_path):
            raise Exception('Problem in loading the model. %s/%d does not exist!' %
                            (load_dir, epoch))

        checkpoint = torch.load(epoch_path, map_location=map_location)

        # backward compatibility: newer checkpoints wrap the weights in a dict
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']

        the_model.load_state_dict(checkpoint)
        print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True)

    the_model.to(conf.device)
    return the_model
| from typing import Tuple | |||||
| from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.celeba.single_tag_org_inception import CelebAORGInception | |||||
class EntryPoint(BaseEntrypoint):
    """Entry point of the 'CelebA_Inception_Org' experiment."""

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        """Create the configuration and the ``CelebAORGInception`` model for the smiling tag."""
        conf = CelebAConfigs('CelebA_Inception_Org', 8, 299, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        main_tag_name = conf.main_tag.name
        model = CelebAORGInception(main_tag_name, 0.4)

        return conf, model
| from typing import Tuple | |||||
| from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.celeba.org_resnet import CelebAORGResNet18 | |||||
class EntryPoint(BaseEntrypoint):
    """Entry point of the 'CelebA_ResNet_Org' experiment."""

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        """Create the configuration and the ``CelebAORGResNet18`` model for the smiling tag."""
        conf = CelebAConfigs('CelebA_ResNet_Org', 9, 224, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        main_tag_name = conf.main_tag.name
        model = CelebAORGResNet18(main_tag_name)

        return conf, model
| from typing import Tuple | |||||
| from collections import OrderedDict | |||||
| from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.celeba.lap_inception import CelebALAPInception | |||||
| from ...modules.lap import LAP | |||||
| from ...modules.adaptive_lap import AdaptiveLAP | |||||
| from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
| from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel):
    """Create the ``LAP`` module used by this experiment; ``channel`` is its first argument."""
    lap_module = LAP(
        channel, 2, 2,
        hidden_channels=[8],
        n_attention_heads=3,
        sigmoid_scale=0.1,
        discriminative_attention=True)
    return lap_module
def adaptive_lap_factory(channel):
    """Create the ``AdaptiveLAP`` module used by this experiment; ``channel`` is its first argument."""
    adaptive_lap_module = AdaptiveLAP(
        channel, [], 0.1,
        n_attention_heads=3,
        discriminative_attention=True)
    return adaptive_lap_module
class EntryPoint(BaseEntrypoint):
    """Entry point of the 'CelebA_Inception_WS' experiment.

    Builds the LAP-Inception model and attaches to the configuration the
    weakly supervised losses, the concordance losses, the evaluators, the
    auxiliary outputs and the output modifier.
    """

    def __init__(self, phase_type) -> None:
        # ratios handed to the weakly supervised losses; set before the base
        # constructor runs
        self.active_min_ratio: float = 0.02
        self.active_max_ratio: float = 0.4
        self.inactive_ratio: float = 0.01
        self.common_max_ratio = 0.02
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        """Create the configuration and model and register all extras on the configuration."""
        conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        model = CelebALAPInception(conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory)

        def pools():
            # (title suffix, pooling module) for the four pooling modules used below
            return (
                ('2', model.maxpool2),
                ('6', model.Mixed_6a.pool),
                ('7', model.Mixed_7a.pool),
                ('G', model.avgpool[0]),
            )

        def titled_score_layers(prefix):
            # (title, attention layer, discrimination layer) per pooling module
            return [
                (prefix + suffix, pool.attention_layer, pool.discrimination_layer)
                for suffix, pool in pools()
            ]

        # aux loss for free head
        fws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 / 3, 0,
                (0, self.common_max_ratio),
                {2: [0, 1]}, discr_score_layer,
                w_attention_in_ordering=1, w_discr_in_ordering=0)
            for title, att_score_layer, discr_score_layer in titled_score_layers('FPool')
        ]
        for fws_loss in fws_losses:
            fws_loss.configure(conf)

        # weakly supervised losses based on discriminative head and attention head
        ws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 * 2 / 3,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
                discr_score_layer,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for title, att_score_layer, discr_score_layer in titled_score_layers('Pool')
        ]
        for ws_loss in ws_losses:
            ws_loss.configure(conf)

        # concordance loss for attention
        attention_layers = OrderedDict(
            ('att' + suffix, pool.attention_layer) for suffix, pool in pools())
        concordance_loss = PoolConcordanceLossCalculator(
            'AC', model, attention_layers,
            loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss.configure(conf)

        # concordance loss for discrimination head
        discrimination_layers = OrderedDict(
            ('D-att' + suffix, pool.discrimination_layer) for suffix, pool in pools())
        concordance_loss2 = PoolConcordanceLossCalculator(
            'DC', model, discrimination_layers,
            loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss2.configure(conf)

        evaluators = OrderedDict({
            'b': BinaryEvaluator,
            'l': LossEvaluator,
            'f': BinForetellerEvaluator.standard_creator('foretell'),
            'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
        })
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        ###################################
        ########### Foreteller ############
        ###################################
        foretell_layers = dict(
            foretell_pool2=model.maxpool2.attention_layer,
            foretell_pool6=model.Mixed_6a.pool.attention_layer,
            foretell_pool7=model.Mixed_7a.pool.attention_layer,
            foretell_avgpool=model.avgpool[0].attention_layer,
        )
        aux = AuxOutput(model, foretell_layers)
        aux.configure(conf)

        output_modifier = OutputModifier(
            model,
            lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
            'foretell_pool2',
            'foretell_pool6',
            'foretell_pool7',
            'foretell_avgpool',
        )
        output_modifier.configure(conf)

        return conf, model
| from typing import Tuple | |||||
| from collections import OrderedDict | |||||
| from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.celeba.lap_resnet import CelebALAPResNet18 | |||||
| from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
| from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `CelebA_ResNet_WS` configuration and its LAP ResNet-18 model."""

    def __init__(self, phase_type) -> None:
        # These ratios are read inside _get_conf_model(), so they are assigned
        # before delegating to the base initializer.
        self.active_min_ratio: float = 0.02
        self.active_max_ratio: float = 0.4
        self.inactive_ratio: float = 0.01
        self.common_max_ratio = 0.02
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        """Create the config and the model, then attach losses, evaluators and foreteller outputs.

        Returns:
            The configured `CelebAConfigs` together with the `CelebALAPResNet18` model.
        """
        conf = CelebAConfigs('CelebA_ResNet_WS', 11, 224, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        model = CelebALAPResNet18(conf.main_tag.name, sigmoid_scale=0.1)

        def pooling_heads():
            # Read from the model afresh on every call:
            # (suffix, attention layer, discrimination layer) per pooling stage.
            stages = [
                ('2', model.layer2[0].pool),
                ('3', model.layer3[0].pool),
                ('4', model.layer4[0].pool),
                ('5', model.avgpool[0]),
            ]
            return [(suffix, stage.attention_layer, stage.discrimination_layer)
                    for suffix, stage in stages]

        # Auxiliary losses for the free head: build all of them first, then configure.
        free_head_losses = [
            DiscriminativeWeaklySupervisedLoss(
                f'Fatt{suffix}', model, attention, 0.025 / 3, 0,
                (0, self.common_max_ratio),
                {2: [0, 1]}, discrimination,
                w_attention_in_ordering=1, w_discr_in_ordering=0)
            for suffix, attention, discrimination in pooling_heads()
        ]
        for loss in free_head_losses:
            loss.configure(conf)

        # Weakly supervised losses driven by the discrimination and attention heads.
        weak_losses = [
            DiscriminativeWeaklySupervisedLoss(
                f'att{suffix}', model, attention, 0.025 * 2 / 3,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
                discrimination,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for suffix, attention, discrimination in pooling_heads()
        ]
        for loss in weak_losses:
            loss.configure(conf)

        # Concordance loss across the attention heads.
        attention_by_name = OrderedDict(
            (f'att{suffix}', attention) for suffix, attention, _ in pooling_heads())
        attention_concordance = PoolConcordanceLossCalculator(
            'AC', model, attention_by_name,
            loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        attention_concordance.configure(conf)

        # Concordance loss across the discrimination heads.
        discrimination_by_name = OrderedDict(
            (f'att{suffix}', discrimination) for suffix, _, discrimination in pooling_heads())
        discrimination_concordance = PoolConcordanceLossCalculator(
            'DC', model, discrimination_by_name,
            loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        discrimination_concordance.configure(conf)

        evaluators = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller: expose the attention maps as auxiliary outputs ----
        foretell_layers = {
            'foretell_pool2': model.layer2[0].pool.attention_layer,
            'foretell_pool3': model.layer3[0].pool.attention_layer,
            'foretell_pool4': model.layer4[0].pool.attention_layer,
            'foretell_avgpool': model.avgpool[0].attention_layer,
        }
        AuxOutput(model, foretell_layers).configure(conf)

        def foretold_label(x):
            # Sum the above-0.5 activation per channel, then pick the larger of the first two channels.
            return (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1)

        foretell_names = ('foretell_pool2', 'foretell_pool3', 'foretell_pool4', 'foretell_avgpool')
        OutputModifier(model, foretold_label, *foretell_names).configure(conf)

        return conf, model
| import sys | |||||
| from typing import Tuple | |||||
| import argparse | |||||
| import os | |||||
| from abc import ABC, abstractmethod | |||||
| from ..models.model import Model | |||||
| from ..configs.base_config import BaseConfig, PhaseType | |||||
class BaseEntrypoint(ABC):
    """Base class for all entrypoints.

    A subclass implements `_get_conf_model()`. On construction the config and
    model it returns are stored on `self.conf` / `self.model`, a command-line
    parser is built, and the parsed arguments are passed to `self.conf.update`.
    """
    description = ''

    def __init__(self, phase_type: 'PhaseType') -> None:
        """Build the config/model pair and apply command-line overrides to the config.

        Args:
            phase_type: Stored on `self.phase_type` before `_get_conf_model()` is called.
        """
        super().__init__()
        self.phase_type = phase_type
        self.conf, self.model = self._get_conf_model()

        # sys.argv is expected to look like [command, entrypoint_name, *options].
        # Fall back to empty strings so that constructing an entrypoint with a
        # shorter argv does not fail with an IndexError.
        argv = sys.argv
        command = argv[0] if argv else ''
        entrypoint_name = argv[1] if len(argv) > 1 else ''

        self.parser = self._create_parser(command, entrypoint_name)
        self.conf.update(self.parser.parse_args(argv[2:]))

    def _create_parser(self, command: str, entrypoint_name: str) -> argparse.ArgumentParser:
        """Create the argument parser shared by all entrypoints.

        Args:
            command: Path of the invoked command; only its base name is shown in `prog`.
            entrypoint_name: Name of the entrypoint, appended to `prog` when non-empty.

        Returns:
            The parser with all common options registered.
        """
        # Joining only the non-empty parts avoids a dangling space in `prog`.
        prog = ' '.join(part for part in (os.path.basename(command), entrypoint_name) if part)
        parser = argparse.ArgumentParser(
            prog=prog,
            description=self.description or None
        )
        parser.add_argument('--device', default='cpu', type=str,
                            help='Analysis will run over device, use cpu, cuda:#, cuda (for all gpus)')
        parser.add_argument('--samples-dir', default=None, type=str,
                            help='The directory/dataGroup to be evaluated. In the case of dataGroup can be train/test/val. In the case of directory must contain 0 and 1 subdirectories. Use:FilterName to do over samples containing /FilterName/ in their path')
        parser.add_argument('--final-model-dir', default=None, type=str,
                            help='The directory to load the model from, when not given will be calculated!')
        parser.add_argument('--save-dir', default=None, type=str,
                            help='The directory to save the model from, when not given will be calculated!')
        parser.add_argument('--report-dir', default=None, type=str,
                            help='The dir to save reports per slice per sample in.')
        parser.add_argument('--epoch', default=None, type=str,
                            help='The epoch to load.')
        parser.add_argument('--try-name', default=None, type=str,
                            help='The run name specifying what run is doing')
        parser.add_argument('--try-num', default=None, type=int,
                            help='The try number to load')
        parser.add_argument('--data-separation', default=None, type=str,
                            help='The data_separation to be used.')
        parser.add_argument('--batch-size', default=None, type=int,
                            help='The batch size to be used.')
        parser.add_argument('--pretrained-model-file', default=None, type=str,
                            help='Address of .pt pretrained model')
        parser.add_argument('--max-epochs', default=None, type=int,
                            help='The maximum epochs of training!')
        parser.add_argument('--big-batch-size', default=None, type=int,
                            help='The big batch size (iteration per optimization)!')
        parser.add_argument('--iters-per-epoch', default=None, type=int,
                            help='The number of big batches per epoch!')
        parser.add_argument('--interpretation-method', default=None, type=str,
                            help='The method used for interpreting the results!')
        parser.add_argument('--cut-threshold', default=None, type=float,
                            help='The threshold for cutting interpretations!')
        parser.add_argument('--global-threshold', action='store_true',
                            help='Whether the given cut threshold must be applied global or to the relative values!')
        parser.add_argument('--dynamic-threshold', action='store_true',
                            help='Whether to use dynamic threshold in interpretation!')
        parser.add_argument('--class-label-for-interpretation', default=None, type=int,
                            help='The class label we want to explain why it has been chosen. None means the decision of the model!')
        # Parsed as a boolean: only the literal string '1' yields True.
        parser.add_argument('--interpret-predictions-vs-gt', default='1', type=(lambda x: x == '1'),
                            help='If the class_label for_interpretation is None this will be considered. If 1, interpretations would be done for the predicted label, otherwise for the ground truth.')
        # Parsed from a comma separated string into a list of ints.
        parser.add_argument('--mapped-labels-to-use', default=None, type=(lambda x: [int(y) for y in x.split(',')]),
                            help='The labels to do the analysis on them only (comma separated), default is all the labels.')
        parser.add_argument('--skip-overlay', action='store_true', default=None,
                            help='Passing this flag prevents the interpretation phase from storing overlay images.')
        parser.add_argument('--skip-raw', action='store_true', default=None,
                            help='Passing this flag prevents the interpretation phase from storing the raw interpretation values.')
        parser.add_argument('--overlay-only', action='store_true',
                            help='Passing this flag makes the interpretation phase to store just overlay images '
                            'and not the `.npy` files.')
        parser.add_argument('--save-by-file-name', action='store_true',
                            help='Saves sample-specific files by their file names only, not the whole path.')
        parser.add_argument('--n-interpretation-samples', default=None, type=int,
                            help='The max number of samples to be interpreted.')
        parser.add_argument('--interpretation-tag-to-evaluate', default=None, type=str,
                            help='The tag to be used as interpretation for evaluation phase, otherwise the first one will be used.')
        return parser

    @abstractmethod
    def _get_conf_model(self) -> Tuple['BaseConfig', 'Model']:
        """Return the `(config, model)` pair used by this entrypoint."""
        pass
| from typing import Tuple | |||||
| from torchlap.configs.imagenet_configs import ImagenetConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `ImagenetFT` configuration and an ImageNet LAP ResNet-50."""

    def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]:
        """Return the `ImagenetFT` config (with its freezing patterns) and the model."""
        conf = ImagenetConfigs('ImagenetFT', 2, 224, self.phase_type)
        net = ImagenetLAPResNet50(sigmoid_scale=0.1)

        # The pattern matches any name that does not start with `layer4.` or `fc.`.
        # Presumably it is applied to parameter names - confirm against the config class.
        frozen_patterns = [r'(?!^layer4\..*$)(?!^fc\..*$)(^.*$)']
        conf.freezing_regexes = frozen_patterns

        return conf, net
| from typing import Tuple | |||||
| from torchlap.configs.imagenet_configs import ImagenetConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `ImagenetNFT` configuration and an ImageNet LAP ResNet-50."""

    def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]:
        """Return the `ImagenetNFT` config (with its freezing patterns) and the model."""
        conf = ImagenetConfigs('ImagenetNFT', 1, 224, self.phase_type)
        net = ImagenetLAPResNet50(sigmoid_scale=0.1)

        # The pattern matches any name that does not contain `.pool.`.
        # Presumably it is applied to parameter names - confirm against the config class.
        frozen_patterns = [r'(?!^.*\.pool\..*$)(^.*$)']
        conf.freezing_regexes = frozen_patterns

        return conf, net
| from collections import OrderedDict | |||||
| from typing import Tuple | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.lap_inception import RSNALAPInception | |||||
| from ...modules.lap import LAP | |||||
| from ...modules.adaptive_lap import AdaptiveLAP | |||||
| from ...criteria.bb_supervised import BBLoss | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel):
    """Create a `LAP` module for `channel` with the discriminative attention head disabled."""
    lap_options = dict(
        hidden_channels=[8],
        sigmoid_scale=0.1,
        discriminative_attention=False,
    )
    return LAP(channel, 2, 2, **lap_options)
def adaptive_lap_factory(channel):
    """Create an `AdaptiveLAP` module for `channel` with the discriminative attention head disabled."""
    adaptive_pool = AdaptiveLAP(channel, [], 0.1, discriminative_attention=False)
    return adaptive_pool
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `RSNA_Inception_BB` configuration and its LAP Inception model."""

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the config and the model, then attach the BB loss, evaluators and foreteller outputs.

        Returns:
            The configured `RSNAConfigs` together with the `RSNALAPInception` model.
        """
        conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type)
        model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

        def attention_layers():
            # Read from the model afresh on every call, in pooling order.
            return (
                model.maxpool2.attention_layer,
                model.Mixed_6a.pool.attention_layer,
                model.Mixed_7a.pool.attention_layer,
                model.avgpool[0].attention_layer,
            )

        supervised_layers = dict(zip(('Pool2', 'Pool6', 'Pool7', 'PoolG'), attention_layers()))
        bb_loss = BBLoss('BB', model, supervised_layers, 1, 'interpretations', 1, 0.5)
        bb_loss.configure(conf)

        evaluators = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller: expose the attention maps as auxiliary outputs ----
        foretell_names = ('foretell_pool2', 'foretell_pool6', 'foretell_pool7', 'foretell_avgpool')
        AuxOutput(model, dict(zip(foretell_names, attention_layers()))).configure(conf)

        def strongest_activation(x):
            # Maximum over everything except the first dimension.
            return x.flatten(1).max(dim=-1).values

        OutputModifier(model, strongest_activation, *foretell_names).configure(conf)

        return conf, model
| from collections import OrderedDict | |||||
| from typing import Tuple | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...criteria.bb_supervised import BBLoss | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `RSNA_ResNet_BB` configuration and its LAP ResNet-18 model."""

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the config and the model, then attach the BB loss, evaluators and foreteller outputs.

        Returns:
            The configured `RSNAConfigs` together with the `RSNALAPResNet18` model.
        """
        conf = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type)
        net = RSNALAPResNet18(
            pool_factory=get_pool_factory(discriminative_attention=False),
            adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
            sigmoid_scale=0.1,
        )

        def attention_layers():
            # Read from the model afresh on every call, in pooling order.
            return (
                net.layer2[0].pool.attention_layer,
                net.layer3[0].pool.attention_layer,
                net.layer4[0].pool.attention_layer,
                net.avgpool[0].attention_layer,
            )

        supervised_layers = dict(zip(('att2', 'att3', 'att4', 'att5'), attention_layers()))
        bb_loss = BBLoss('BB', net, supervised_layers, 1, 'interpretations', 1, 0.5)
        bb_loss.configure(conf)

        evaluators = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller: expose the attention maps as auxiliary outputs ----
        foretell_names = ('foretell_pool2', 'foretell_pool3', 'foretell_pool4', 'foretell_avgpool')
        AuxOutput(net, dict(zip(foretell_names, attention_layers()))).configure(conf)

        def strongest_activation(x):
            # Maximum over everything except the first dimension.
            return x.flatten(1).max(dim=-1).values

        OutputModifier(net, strongest_activation, *foretell_names).configure(conf)

        return conf, net
| from typing import Tuple | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.org_inception import RSNAORGInception | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that pairs the `RSNA_Inception_Org` configuration with `RSNAORGInception`."""

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Return the config and the model; no extra losses or evaluators are attached."""
        conf = RSNAConfigs('RSNA_Inception_Org', 2, 299, self.phase_type)
        net = RSNAORGInception(0.4)
        return conf, net
| from typing import Tuple | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.org_resnet import RSNAORGResNet18 | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that pairs the `RSNA_ResNet_Org` configuration with `RSNAORGResNet18`."""

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Return the config and the model; no extra losses or evaluators are attached."""
        conf = RSNAConfigs('RSNA_ResNet_Org', 1, 224, self.phase_type)
        net = RSNAORGResNet18()
        return conf, net
| from typing import Tuple | |||||
| from collections import OrderedDict | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.lap_inception import RSNALAPInception | |||||
| from ...modules.lap import LAP | |||||
| from ...modules.adaptive_lap import AdaptiveLAP | |||||
| from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
| from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
def lap_factory(channel):
    """Create a `LAP` module for `channel` with the discriminative attention head enabled."""
    lap_options = dict(
        hidden_channels=[8],
        sigmoid_scale=0.1,
        discriminative_attention=True,
    )
    return LAP(channel, 2, 2, **lap_options)
def adaptive_lap_factory(channel):
    """Create an `AdaptiveLAP` module for `channel` with the discriminative attention head enabled."""
    adaptive_pool = AdaptiveLAP(channel, sigmoid_scale=0.1, discriminative_attention=True)
    return adaptive_pool
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `RSNA_Inception_WS` configuration and its LAP Inception model."""

    def __init__(self, phase_type) -> None:
        # These ratios are read inside _get_conf_model(), so they are assigned
        # before delegating to the base initializer.
        self.active_min_ratio: float = 0.1
        self.active_max_ratio: float = 0.5
        self.inactive_ratio: float = 0.1
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the config and the model, then attach losses, evaluators and foreteller outputs.

        Returns:
            The configured `RSNAConfigs` together with the `RSNALAPInception` model.
        """
        conf = RSNAConfigs('RSNA_Inception_WS', 4, 256, self.phase_type)
        model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

        def pooling_heads():
            # Read from the model afresh on every call:
            # (suffix, attention layer, discrimination layer) per pooling stage.
            stages = [
                ('2', model.maxpool2),
                ('6', model.Mixed_6a.pool),
                ('7', model.Mixed_7a.pool),
                ('G', model.avgpool[0]),
            ]
            return [(suffix, stage.attention_layer, stage.discrimination_layer)
                    for suffix, stage in stages]

        # Weakly supervised losses driven by the discrimination and attention heads.
        weak_losses = [
            DiscriminativeWeaklySupervisedLoss(
                f'Pool{suffix}', model, attention, 0.25,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [1]},
                discrimination,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for suffix, attention, discrimination in pooling_heads()
        ]
        for loss in weak_losses:
            loss.configure(conf)

        # Concordance loss across the attention heads.
        attention_by_name = OrderedDict(
            (f'att{suffix}', attention) for suffix, attention, _ in pooling_heads())
        attention_concordance = PoolConcordanceLossCalculator(
            'AC', model, attention_by_name,
            loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
            labels_by_channel={0: [1]})
        attention_concordance.configure(conf)

        # Concordance loss across the discrimination heads.
        discrimination_by_name = OrderedDict(
            (f'D-att{suffix}', discrimination) for suffix, _, discrimination in pooling_heads())
        discrimination_concordance = PoolConcordanceLossCalculator(
            'DC', model, discrimination_by_name,
            loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
            labels_by_channel={0: [1]})
        discrimination_concordance.configure(conf)

        evaluators = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller: expose the attention maps as auxiliary outputs ----
        foretell_layers = {
            'foretell_pool2': model.maxpool2.attention_layer,
            'foretell_pool6': model.Mixed_6a.pool.attention_layer,
            'foretell_pool7': model.Mixed_7a.pool.attention_layer,
            'foretell_avgpool': model.avgpool[0].attention_layer,
        }
        AuxOutput(model, foretell_layers).configure(conf)

        def strongest_activation(x):
            # Maximum over everything except the first dimension.
            return x.flatten(1).max(dim=-1).values

        foretell_names = ('foretell_pool2', 'foretell_pool6', 'foretell_pool7', 'foretell_avgpool')
        OutputModifier(model, strongest_activation, *foretell_names).configure(conf)

        return conf, model
| from collections import OrderedDict | |||||
| from typing import Tuple | |||||
| from ...utils.output_modifier import OutputModifier | |||||
| from ...utils.aux_output import AuxOutput | |||||
| from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||||
| from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||||
| from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||||
| from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||||
| from ...model_evaluation.loss_evaluator import LossEvaluator | |||||
| from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||||
| from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||||
| from ...configs.rsna_configs import RSNAConfigs | |||||
| from ...models.model import Model | |||||
| from ..entrypoint import BaseEntrypoint | |||||
| from ...models.rsna.lap_resnet import RSNALAPResNet18 | |||||
class EntryPoint(BaseEntrypoint):
    """Entrypoint that builds the `RSNA_ResNet_WS` configuration and its LAP ResNet-18 model."""

    def __init__(self, phase_type) -> None:
        # These ratios are read inside _get_conf_model(), so they are assigned
        # before delegating to the base initializer.
        self.active_min_ratio: float = 0.1
        self.active_max_ratio: float = 0.5
        self.inactive_ratio: float = 0.1
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the config and the model, then attach losses, evaluators and foreteller outputs.

        Returns:
            The configured `RSNAConfigs` together with the `RSNALAPResNet18` model.
        """
        conf = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type)
        net = RSNALAPResNet18()

        def pooling_heads():
            # Read from the model afresh on every call:
            # (name, attention layer, discrimination layer) per pooling stage.
            stages = [
                ('att2', net.layer2[0].pool),
                ('att3', net.layer3[0].pool),
                ('att4', net.layer4[0].pool),
                ('att5', net.avgpool[0]),
            ]
            return [(name, stage.attention_layer, stage.discrimination_layer)
                    for name, stage in stages]

        # Weakly supervised losses driven by the discrimination and attention heads.
        weak_losses = [
            DiscriminativeWeaklySupervisedLoss(
                name, net, attention, 0.25,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [1]},
                discrimination,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for name, attention, discrimination in pooling_heads()
        ]
        for loss in weak_losses:
            loss.configure(conf)

        # Concordance loss across the attention heads.
        attention_by_name = OrderedDict(
            (name, attention) for name, attention, _ in pooling_heads())
        attention_concordance = PoolConcordanceLossCalculator(
            'AC', net, attention_by_name,
            loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
            labels_by_channel={0: [1]})
        attention_concordance.configure(conf)

        # Concordance loss across the discrimination heads.
        discrimination_by_name = OrderedDict(
            (name, discrimination) for name, _, discrimination in pooling_heads())
        discrimination_concordance = PoolConcordanceLossCalculator(
            'DC', net, discrimination_by_name,
            loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
            labels_by_channel={0: [1]})
        discrimination_concordance.configure(conf)

        evaluators = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = \
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller: expose the attention maps as auxiliary outputs ----
        foretell_layers = {
            'foretell_pool2': net.layer2[0].pool.attention_layer,
            'foretell_pool3': net.layer3[0].pool.attention_layer,
            'foretell_pool4': net.layer4[0].pool.attention_layer,
            'foretell_avgpool': net.avgpool[0].attention_layer,
        }
        AuxOutput(net, foretell_layers).configure(conf)

        def strongest_activation(x):
            # Maximum over everything except the first dimension.
            return x.flatten(1).max(dim=-1).values

        foretell_names = ('foretell_pool2', 'foretell_pool3', 'foretell_pool4', 'foretell_avgpool')
        OutputModifier(net, strongest_activation, *foretell_names).configure(conf)

        return conf, net
| """ | |||||
| Interpretation Modules | |||||
| """ | |||||
| from .interpretable import InterpretableModel, CamInterpretableModel | |||||
| from .interpreter import Interpreter | |||||
| from .interpretable_wrapper import InterpretableWrapper | |||||
| from .interpreter_maker import create_interpreter, InterpretationType |
| from abc import ABC, abstractmethod | |||||
| from collections import OrderedDict as OrdDict | |||||
| from typing import Dict, List, Tuple, Union | |||||
| import torch | |||||
| from torch import nn | |||||
| import torch.nn.functional as F | |||||
| from .interpreter import Interpreter | |||||
| from .interpretable import InterpretableModel | |||||
| from ..modules import LAP | |||||
| from ..utils.hooker import Hooker, Hook | |||||
| AttentionDict = Dict[torch.Size, torch.Tensor] | |||||
class AttentionInterpretableModel(InterpretableModel, ABC):
    """Interface for interpretable models that expose their LAP attention layers in named groups."""

    @property
    @abstractmethod
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        Attention layers organised into named groups.

        Returns:
            A mapping from a group name to the list of `LAP` modules belonging to that group.
        """
| class AttentionInterpreter(Interpreter): | |||||
def __init__(self, model: AttentionInterpretableModel):
    """Register a forward hook on the attention layer of every LAP module of `model`.

    Args:
        model: The model whose `attention_layers` groups are hooked.
    """
    assert isinstance(model, AttentionInterpretableModel),\
        f"Expected AttentionInterpretableModel, got {type(model)}"
    super().__init__(model)

    # One (module, hook) pair per LAP layer, flattened across all groups.
    hook_pairs = [
        (layer.attention_layer, self._generate_forward_hook(group_name, layer))
        for group_name, group_layers in model.attention_layers.items()
        for layer in group_layers
    ]
    self._hooker = Hooker(*hook_pairs)

    # Per group: accumulated attention maps, keyed by their spatial shape.
    self._attention_groups: Dict[str, AttentionDict] = {
        group_name: {} for group_name in model.attention_layers}

    # Weight applied to the current attention level, and its per-level decay factor.
    self._impact = 1.
    self._impact_decay = .8

    self._input_size: torch.Size = None
    self._device = None
def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook:
    """Create a forward hook that accumulates a layer's output into the group `group_name`.

    Outputs are summed per spatial shape (`out.shape[2:]`) inside
    `self._attention_groups[group_name]`. The `layer` argument is not used by the hook.
    """
    def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
        snapshot = out.clone()
        spatial_shape = snapshot.shape[2:]
        # Looked up at call time, because _reset() replaces the dictionaries.
        accumulated = self._attention_groups[group_name]
        if spatial_shape not in accumulated:
            accumulated[spatial_shape] = torch.zeros_like(snapshot)
        accumulated[spatial_shape] += snapshot

    return forward_hook
def _reset(self) -> None:
    """Drop all accumulated attention maps and restore the impact weight to 1."""
    cleared = {}
    for group_name in self._model.attention_layers:
        cleared[group_name] = {}
    self._attention_groups = cleared
    self._impact = 1.
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """ Merge the current attention level into the map built from the previous levels.

    The impact weight is decayed on every call, including the first one.

    Args:
        result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1)
            `None` for the first (smallest) level.
        attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

    Returns:
        torch.Tensor: New attention map. (B, 1, H2, W2)
    """
    self._impact *= self._impact_decay

    # First (smallest) level: there is nothing to merge with yet.
    if result is None:
        return attention

    target_hw = attention.shape[2:]

    # The attention map is larger than the result; this is the per-axis ratio, rounded up.
    size_ratio = torch.tensor(target_hw) / torch.tensor(result.shape[2:])
    scale = torch.ceil(size_ratio).int().tolist()

    # A window is "active" when any attention value inside it is positive.
    window_max = F.max_pool2d(attention, kernel_size=scale, stride=scale)
    window_max = F.interpolate(window_max, target_hw, mode='nearest')
    is_any_active = window_max > 0

    # Bring the previous map to the resolution of the current level.
    upsampled = F.interpolate(result, target_hw, mode='nearest')

    # Element-wise maximum of the previous map and the weighted attention,
    # zeroed wherever either the attention or the previous map is zero.
    impacted = torch.max(upsampled, attention * self._impact)
    impacted[(attention == 0) | (upsampled == 0)] = 0

    # Inside active windows use the merged value; elsewhere keep the previous map.
    upsampled[is_any_active] = impacted[is_any_active]
    return upsampled
| def _process_attentions(self) -> Dict[str, torch.Tensor]: | |||||
| sorted_attentions = { | |||||
| group_name: OrdDict(sorted(group_attention.items(), | |||||
| key=lambda pair: pair[0])) | |||||
| for group_name, group_attention in self._attention_groups.items() | |||||
| } | |||||
| results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions} | |||||
| for group_name, group_attention in sorted_attentions.items(): | |||||
| self._impact = 1.0 | |||||
| # from smallest to largest | |||||
| for attention in group_attention.values(): | |||||
| attention = (attention - 0.5).relu() | |||||
| results[group_name] = self._process_attention(results[group_name], attention) | |||||
| interpretations = {} | |||||
| for group_name, group_result in results.items(): | |||||
| group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True) | |||||
| interpretations[group_name] = group_result | |||||
| ''' | |||||
| if group_result.shape[1] == 1: | |||||
| interpretations[group_name] = group_result | |||||
| else: | |||||
| interpretations.update({ | |||||
| f"{group_name}_{i}": group_result[:, i:i+1] | |||||
| for i in range(group_result.shape[1]) | |||||
| }) | |||||
| ''' | |||||
| return interpretations | |||||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Run the model under the attention hooks and return the processed maps.

    Args:
        labels (Union[int, torch.Tensor]): Accepted for interface
            compatibility; not read by this method.
        **inputs (torch.Tensor): Keyword inputs forwarded to the model.

    Returns:
        Dict[str, torch.Tensor]: The result of ``self._process_attentions()``.
    """
    self._reset()

    # Batch size, spatial size and device are taken from the first placeholder.
    primary = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]]
    self._batch_size = primary.shape[0]
    self._input_size = primary.shape[2:]
    self._device = primary.device

    # The hooks are only active for the duration of this forward pass.
    with self._hooker:
        self._model(**inputs)

    return self._process_attentions()
| from typing import Dict | |||||
| from collections import OrderedDict as OrdDict | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from .attention_interpreter import AttentionInterpreter | |||||
| AttentionDict = Dict[torch.Size, torch.Tensor] | |||||
class AttentionInterpreterSmoothIntegrator(AttentionInterpreter):
    """Attention interpreter that replaces the merged map with a running sum.

    Where the merged map is positive the output is the sum of all raw
    attention levels; elsewhere it is that sum divided by twice the number
    of levels.
    """

    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        """Merge each group's attentions and blend with their running sum.

        Returns:
            Dict[str, torch.Tensor]: Group name mapped to its map resized to
            ``self._input_size`` with bilinear interpolation.
        """
        results: Dict[str, torch.Tensor] = {name: None for name in self._attention_groups}

        for name, by_size in self._attention_groups.items():
            self._impact = 1.0
            n_levels = len(by_size)

            # Starts as the int 0 so the first level needs no resizing.
            running_sum = 0
            # Ascending key order, i.e. from the smallest map to the largest.
            for size in sorted(by_size):
                level = by_size[size]
                if torch.is_tensor(running_sum):
                    # Resize the sum so far to the current level's spatial size.
                    running_sum = F.interpolate(running_sum, size=level.shape[-2:])
                running_sum += level
                results[name] = self._process_attention(results[name], (level - 0.5).relu())

            # Keep the full sum where the merged map is active, damp it elsewhere.
            results[name] = torch.where(
                results[name] > 0, running_sum, running_sum / (2 * n_levels))

        return {
            name: F.interpolate(group_map, self._input_size, mode='bilinear', align_corners=True)
            for name, group_map in results.items()
        }
| import torch | |||||
| import torch.nn.functional as F | |||||
| from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel | |||||
class AttentionSumInterpreter(AttentionInterpreter):
    """Attention interpreter that merges levels by plain addition."""

    def __init__(self, model: AttentionInterpretableModel):
        assert isinstance(model, AttentionInterpretableModel),\
            f"Expected AttentionInterpretableModel, got {type(model)}"
        super().__init__(model)

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        """Add the current attention level to the running map.

        Args:
            result (torch.Tensor): Map accumulated from previous levels, or
                ``None`` for the first level.
            attention (torch.Tensor): Current level's attention map.

        Returns:
            torch.Tensor: ``attention`` itself for the first level, otherwise
            the running map bilinearly resized to ``attention``'s spatial
            size plus ``attention``.
        """
        self._impact *= self._impact_decay

        if result is not None:
            upsampled = F.interpolate(
                result, attention.shape[2:], mode='bilinear', align_corners=True)
            return upsampled + attention

        # First level: nothing accumulated yet.
        return attention
| from typing import List, TYPE_CHECKING, Tuple | |||||
| import numpy as np | |||||
| from skimage import measure | |||||
| from .utils import process_interpretations | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
| from . import Interpreter | |||||
class BinaryInterpretationEvaluator2D:
    """Accumulates IoU statistics between model interpretations and ground truth.

    Four length-2 accumulators are kept, indexed by "soundness": index 1
    collects samples whose prediction equals the label, index 0 collects the
    mistakes. Two of them hold intersection/union of the binarized
    interpretations, the other two hold the same quantities for the "top-k"
    thresholded interpretations.
    """

    def __init__(self, n_samples: int, config: 'BaseConfig'):
        self._config = config
        self._n_samples = n_samples
        self._min_intersection_threshold = config.acceptable_min_intersection_threshold
        self.reset()

    def reset(self):
        """Zero all accumulators."""
        self._normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._tk_normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._tk_normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)

    def update_summaries(
            self,
            m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray,
            net_preds: np.ndarray, ground_truth_labels: np.ndarray,
            batch_inds: np.ndarray, interpreter: 'Interpreter'
    ) -> None:
        """Add one batch to the accumulators.

        Args:
            m_interpretations: Model interpretations, 3-D (batch first).
            ground_truth_interpretations: Ground-truth maps, 3-D (batch first).
                A sample containing any NaN is treated as having no ground
                truth and is skipped.
            net_preds: Model outputs; thresholded at 0.5 when 1-D or with a
                single column, otherwise reduced with argmax over the last axis.
            ground_truth_labels: Class labels, one per sample.
            batch_inds: Sample indices of the batch; currently not used.
            interpreter: Passed through to ``process_interpretations``.
        """
        assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}'
        assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}'

        # Skip samples without a ground-truth interpretation (marked by NaNs).
        has_interpretation_mask = np.logical_not(
            np.any(np.isnan(ground_truth_interpretations), axis=(1, 2)))
        if not np.any(has_interpretation_mask):
            # Nothing to evaluate; the reductions below would fail on empty input.
            return

        m_interpretations = m_interpretations[has_interpretation_mask]
        ground_truths = ground_truth_interpretations[has_interpretation_mask]
        # Predictions must be filtered too, or they misalign with the labels.
        net_preds = net_preds[has_interpretation_mask]
        ground_truth_labels = ground_truth_labels[has_interpretation_mask]

        # Turn model outputs into class labels.
        if len(net_preds.shape) == 1:
            net_preds = (net_preds >= 0.5).astype(int)
        elif net_preds.shape[1] == 1:
            net_preds = (net_preds[:, 0] >= 0.5).astype(int)
        else:
            net_preds = net_preds.argmax(axis=-1)

        # Shape check kept from the original implementation.
        if net_preds.shape == ground_truth_labels.shape:
            net_preds = np.round(net_preds, 0).astype(int)
        else:
            net_preds = np.argmax(net_preds, axis=1)

        # 1 where the prediction matches the label, 0 otherwise.
        soundnesses = (net_preds == ground_truth_labels).astype(int)

        # Non-negative interpretations, used for the top-k thresholding.
        c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations))

        # Binarized interpretations, processed one sample at a time.
        b_interpretations = np.stack(tuple(
            process_interpretations(sample[None, ...], self._config, interpreter)
            for sample in m_interpretations
        ), axis=0)[:, 0, ...]
        b_interpretations = (b_interpretations > 0).astype(bool)

        # Make sure values are 0 and 1 even if a resize has been applied.
        ground_truths = (ground_truths >= 0.5).astype(bool)

        assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}'

        norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2]

        # np.add.at accumulates correctly when a soundness index repeats.
        np.add.at(self._normed_intersection_per_soundness, soundnesses,
                  np.sum(b_interpretations & ground_truths, axis=(1, 2)) * 1.0 / norm_factor)
        np.add.at(self._normed_union_per_soundness, soundnesses,
                  np.sum(b_interpretations | ground_truths, axis=(1, 2)) * 1.0 / norm_factor)

        n_pixels = ground_truths.shape[-1] * ground_truths.shape[-2]
        for i, s in enumerate(soundnesses):
            # Keep the top-k pixels with k = number of ground-truth pixels by
            # thresholding at the matching quantile. 1 is added because the
            # threshold comparison is > rather than >=.
            n_on_gt = np.sum(ground_truths[i])
            q = (1 + n_on_gt) * 1.0 / n_pixels
            if q < 1:
                th = max(0, np.quantile(c_interpretations[i].reshape(-1), 1 - q))
                tints = c_interpretations[i] > th
            else:
                tints = c_interpretations[i] > 0

            # Top-k metrics.
            tk_intersection = np.sum(tints & ground_truths[i])
            tk_union = np.sum(tints | ground_truths[i])
            self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor
            self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor

    @staticmethod
    def get_titles_of_evaluation_metrics() -> List[str]:
        """Return metric titles in the order used by ``get_values_of_evaluation_metrics``."""
        return ['S-IOU', 'S-TK-IOU',
                'M-IOU', 'M-TK-IOU',
                'A-IOU', 'A-TK-IOU']

    @staticmethod
    def _get_eval_metrics(normed_intersection, normed_union, title,
                          tk_normed_intersection, tk_normed_union) -> \
            Tuple[str, str]:
        """Format IoU and top-k IoU as percentages with four decimals.

        ``title`` is accepted for readability at the call sites and is not used.
        A small epsilon on both sides makes an empty union yield 100.
        """
        iou = (1e-6 + normed_intersection) / (1e-6 + normed_union)
        tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union)
        return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,)

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return formatted metrics for sound samples, mistakes, and all samples."""
        return \
            list(self._get_eval_metrics(
                self._normed_intersection_per_soundness[1],
                self._normed_union_per_soundness[1],
                'Sounds',
                self._tk_normed_intersection_per_soundness[1],
                self._tk_normed_union_per_soundness[1],
            )) + \
            list(self._get_eval_metrics(
                self._normed_intersection_per_soundness[0],
                self._normed_union_per_soundness[0],
                'Mistakes',
                self._tk_normed_intersection_per_soundness[0],
                self._tk_normed_union_per_soundness[0],
            )) + \
            list(self._get_eval_metrics(
                sum(self._normed_intersection_per_soundness),
                sum(self._normed_union_per_soundness),
                'All',
                sum(self._tk_normed_intersection_per_soundness),
                sum(self._tk_normed_union_per_soundness),
            ))

    def print_summaries(self):
        """Print the metrics, one line per soundness category."""
        titles = self.get_titles_of_evaluation_metrics()
        vals = self.get_values_of_evaluation_metrics()
        nc = 2
        for r in range(3):
            print(', '.join(['%s: %s' % (titles[nc * r + i], vals[nc * r + i]) for i in range(nc)]))
| from typing import Dict, Union, Tuple | |||||
| import torch | |||||
| from . import Interpreter | |||||
| from ..data.dataflow import DataFlow | |||||
| from ..models.model import ModelIO | |||||
| from ..data.data_loader import DataLoader | |||||
class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]):
    """Data flow whose final stage runs an interpreter on each batch.

    Every item produced is a pair of the model input and the detached
    interpretation output for that input.
    """

    def __init__(self, interpreter: Interpreter,
                 dataloader: DataLoader,
                 phase: str,
                 device: torch.device,
                 dtype: torch.dtype,
                 print_debug_info: bool,
                 label_for_interpreter: Union[int, None],
                 give_gt_as_label: bool):
        """
        Args:
            interpreter: Interpreter to run; its ``_model`` is handed to the base flow.
            dataloader: Source of batches and of the per-sample labels.
            phase: Forwarded to the base ``DataFlow``.
            device: Forwarded to the base ``DataFlow``.
            dtype: Forwarded to the base ``DataFlow``.
            print_debug_info: Forwarded to the base ``DataFlow``.
            label_for_interpreter: Fixed label passed to the interpreter when
                ``give_gt_as_label`` is false.
            give_gt_as_label: When true, the current batch's ground-truth
                labels are passed to the interpreter instead.
        """
        super().__init__(interpreter._model, dataloader, phase, device,
                         dtype, print_debug_info=print_debug_info)
        self._interpreter = interpreter
        self._label_for_interpreter = label_for_interpreter
        self._give_gt_as_label = give_gt_as_label
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
    def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detach every tensor in ``output``; ``None`` is passed through."""
        if output is None:
            return None
        return {
            name: o.detach() for name, o in output.items()
        }

    def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]:
        """Interpret ``model_input`` and return it with the detached result."""
        if self._give_gt_as_label:
            # Ground-truth labels of the samples in the current batch.
            labels = torch.Tensor(self._sample_labels[
                self._dataloader.get_current_batch_sample_indices()])
        else:
            labels = self._label_for_interpreter
        return model_input, self._detach_output(
            self._interpreter.interpret(labels, **model_input))
| """ | |||||
| DeepLift Interpretation | |||||
| """ | |||||
| from typing import Dict, Union | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from captum import attr | |||||
| from . import Interpreter, InterpretableModel, InterpretableWrapper | |||||
class DeepLift(Interpreter):  # pylint: disable=too-few-public-methods
    """Interpreter backed by Captum's ``DeepLift`` attribution.

    The batch size of the first call is remembered; later, smaller batches
    are zero-padded along the batch axis up to that size before attribution
    and the result is cropped back.
    """

    def __init__(self, model: InterpretableModel):
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreter = attr.DeepLift(self._model_wrapper)
        # Batch size seen on the first call to ``interpret``.
        self.__B = None

    @staticmethod
    def _pad_batch(tensor: torch.Tensor, padding: int) -> torch.Tensor:
        """Append ``padding`` zero entries at the end of the first axis."""
        # F.pad lists pads from the last axis to the first, so the final
        # element of the tuple is the trailing pad of axis 0.
        return F.pad(tensor, (*([0] * (tensor.ndim * 2 - 1)), padding))

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Attribute the given inputs with respect to the target class.

        Args:
            labels: Target class, or per-sample target classes.
            **inputs: Keyword inputs of the model.

        Returns:
            Dict[str, torch.Tensor]: ``{'default': attribution}`` where the
            attribution is that of the first interpreted input, cropped to
            the original batch size.
        """
        B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
        if self.__B is None:
            self.__B = B

        if B < self.__B:
            padding = self.__B - B

            def is_batched(value) -> bool:
                return torch.is_tensor(value) and value.ndim > 0 and value.shape[0] == B

            # Pad batch-sized tensors; pass everything else through unchanged
            # so short batches receive the same arguments as full ones.
            inputs = {
                k: self._pad_batch(v, padding) if is_batched(v) else v
                for k, v in inputs.items()
            }
            if is_batched(labels):
                labels = self._pad_batch(labels, padding)

        labels = self._get_target(labels, **inputs)
        separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
        result = self._interpreter.attribute(
            separated_inputs.inputs,
            target=labels,
            additional_forward_args=separated_inputs.additional_inputs)[0]
        # Drop the rows that belong to the padding.
        result = result[:B]

        return {
            'default': result
        }  # TODO: must return and support tuple
| from typing import Dict, Union | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from captum import attr | |||||
| from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||||
class GradCam(Interpreter):
    """Interpreter that sums Captum ``LayerGradCam`` maps over the target layers."""

    def __init__(self, model: CamInterpretableModel):
        super().__init__(model)
        self.__model_wrapper = InterpretableWrapper(model)
        # One LayerGradCam instance per target convolutional layer.
        self.__interpreters = []
        for conv in model.target_conv_layers:
            self.__interpreters.append(attr.LayerGradCam(self.__model_wrapper, conv))

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the summed Grad-CAM map resized to the input's spatial size.

        Args:
            labels: Target class, or per-sample target classes.
            **inputs: Keyword inputs of the model.

        Returns:
            Dict[str, torch.Tensor]: ``{'default': map}``.
        """
        first_name = self._model.ordered_placeholder_names_to_be_interpreted[0]
        inp_shape = inputs[first_name].shape[2:]

        target = self._get_target(labels, **inputs)
        separated = self.__model_wrapper.convert_inputs_to_args(**inputs)

        raw_cams = []
        for interpreter in self.__interpreters:
            raw_cams.append(interpreter.attribute(
                separated.inputs,
                target=target,
                additional_forward_args=separated.additional_inputs))

        # attribute() may return a tensor or a sequence; take the first item then.
        layer_cams = [cam if torch.is_tensor(cam) else cam[0] for cam in raw_cams]

        summed = torch.stack(layer_cams).sum(dim=0)
        resized = F.interpolate(summed, size=inp_shape, mode='bilinear', align_corners=True)

        return {
            'default': resized,
        }  # TODO: must return and support tuple
| """ | |||||
| Guided Backprop Interpretation | |||||
| """ | |||||
| from typing import Dict, Union | |||||
| import torch | |||||
| from captum import attr | |||||
| from . import Interpreter, InterpretableModel, InterpretableWrapper | |||||
class GuidedBackprop(Interpreter):  # pylint: disable=too-few-public-methods
    """Interpreter backed by Captum's ``GuidedBackprop`` attribution."""

    def __init__(self, model: InterpretableModel):
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreter = attr.GuidedBackprop(self._model_wrapper)

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Attribute the given inputs with respect to the target class.

        Args:
            labels: Target class, or per-sample target classes.
            **inputs: Keyword inputs of the model.

        Returns:
            Dict[str, torch.Tensor]: ``{'default': attribution}`` holding the
            attribution of the first interpreted input.
        """
        target = self._get_target(labels, **inputs)
        separated = self._model_wrapper.convert_inputs_to_args(**inputs)
        attributions = self._interpreter.attribute(
            separated.inputs,
            target=target,
            additional_forward_args=separated.additional_inputs)

        return {
            'default': attributions[0]
        }  # TODO: must return and support tuple
| """ | |||||
| Guided Grad Cam Interpretation | |||||
| """ | |||||
| from typing import Dict, Union | |||||
| import torch | |||||
| from captum import attr | |||||
| from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||||
class GuidedGradCam(Interpreter):
    """Interpreter that sums Captum ``GuidedGradCam`` maps over the target layers."""

    def __init__(self, model: CamInterpretableModel):
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        # One GuidedGradCam instance per target convolutional layer.
        self._interpreters = []
        for conv in model.target_conv_layers:
            self._interpreters.append(attr.GuidedGradCam(self._model_wrapper, conv))

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Interpret the given inputs with respect to the target class.

        Args:
            labels (Union[int, torch.Tensor]): target class
            **inputs (torch.Tensor): model inputs

        Returns:
            Dict[str, torch.Tensor]: ``{'default': map}`` where the map is the
            sum over layers of each layer's first attribution.
        """
        target = self._get_target(labels, **inputs)
        separated = self._model_wrapper.convert_inputs_to_args(**inputs)

        per_layer = []
        for interpreter in self._interpreters:
            attributions = interpreter.attribute(
                separated.inputs,
                target=target,
                additional_forward_args=separated.additional_inputs)
            per_layer.append(attributions[0])

        summed = torch.stack(per_layer).sum(dim=0)

        return {
            'default': summed,
        }  # TODO: must return and support tuple
| from typing import Callable, Type | |||||
| import torch | |||||
| from ..models.model import ModelIO | |||||
| from ..utils.hooker import Hook, Hooker | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..data.dataloader_context import DataloaderContext | |||||
| from .interpreter import Interpreter | |||||
| from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter | |||||
class ImagenetPredictionInterpreter(Interpreter):
    """Wraps an attention interpreter and reports the model's top-k classes.

    The wrapped interpreter's hooker is replaced by one that also hooks the
    model itself, so the top-k predicted classes of every forward pass are
    printed and remembered.
    """

    def __init__(self, model: AttentionInterpretableModel, k: int, base_interpreter: Type[AttentionInterpreter]):
        """
        Args:
            model: Model to interpret.
            k: Number of top classes to keep per sample.
            base_interpreter: Interpreter class instantiated on ``model``.
        """
        super().__init__(model)
        self.__base_interpreter = base_interpreter(model)
        self.__k = k
        self.__topk_classes: torch.Tensor = None

        # Re-register the base hooks on element [-2] of each hooked layer.
        rewired_pairs = [
            (attention_layer[-2], hook)
            for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs
        ]
        self.__base_interpreter._hooker = Hooker(
            (model, self._generate_prediction_hook()),
            *rewired_pairs)

    @staticmethod
    def standard_factory(k: int, base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']:
        """Return a one-argument factory that builds this interpreter for a model."""
        def factory(model):
            return ImagenetPredictionInterpreter(model, k, base_interpreter)
        return factory

    def _generate_prediction_hook(self) -> Hook:
        """Build the model-level hook that prints and stores the top-k classes."""
        def hook(_, __, output: ModelIO):
            probabilities = output['categorical_probability'].detach()
            topk = probabilities.topk(self.__k, dim=-1)

            dataloader: DataLoader = DataloaderContext.instance.dataloader
            sample_names = dataloader.get_current_batch_samples_names()

            for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values):
                classes = top_class.detach().cpu().numpy().tolist()
                probs = top_prob.cpu().numpy().tolist()
                print(f'Top classes of '
                      f'{name}: '
                      f'{classes} - '
                      f'{probs}', flush=True)

            # Flattened so it can index (sample, class) pairs directly.
            self.__topk_classes = topk.indices.flatten().cpu()

        return hook

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        """Keep only the top-k class channels, then delegate to the base interpreter."""
        # Each sample index repeated k times, paired with its k stored classes.
        sample_index = torch.arange(attention.shape[0]).repeat_interleave(self.__k)
        selected = attention[sample_index, self.__topk_classes]
        selected = selected.reshape(-1, self.__k, *selected.shape[-2:])
        return self.__base_interpreter._process_attention(result, selected)

    def interpret(self, labels, **inputs):
        """Delegate interpretation to the wrapped interpreter."""
        return self.__base_interpreter.interpret(labels, **inputs)
| """ | |||||
| An Interpretable Pytorch Model | |||||
| """ | |||||
| from abc import abstractmethod, ABC | |||||
| from typing import Iterable, List | |||||
| import torch | |||||
| from torch import nn | |||||
| from ..models.model import Model | |||||
class InterpretableModel(Model, ABC):
    """
    Interface a model must implement to be usable by the interpreters.
    """

    @property
    @abstractmethod
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        Returns:
            Names of the inputs to be interpreted, in order
        """

    @abstractmethod
    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """
        Computes the probabilities the model assigns to every class in its forward.

        Args:
            *inputs: Inputs to the model
            **kwargs: Additional arguments

        Returns:
            Tensor of categorical probabilities with shape (B, C), where B is
            the batch size and C is the number of classes
        """
class CamInterpretableModel(InterpretableModel, ABC):
    """
    Interface of a model that can be interpreted by Grad-CAM.
    """

    @property
    @abstractmethod
    def target_conv_layers(self) -> List[nn.Module]:
        """
        Returns:
            The convolutional layers to be interpreted. The interpretation
            result is the sum of the Grad-CAM maps of these layers
        """
| """ | |||||
| An adapter wrapper to adapt our models to Captum interpreters | |||||
| """ | |||||
| import inspect | |||||
| from typing import Iterable, Dict, Tuple | |||||
| from dataclasses import dataclass | |||||
| from itertools import chain | |||||
| import torch | |||||
| from torch import nn | |||||
| from . import InterpretableModel | |||||
@dataclass
class InterpreterArgs:
    """Model inputs split into attributed tensors and extra forward arguments."""
    # Tensors to be interpreted, ordered as the model's interpreted placeholders.
    inputs: Tuple[torch.Tensor, ...]
    # Remaining forward arguments, ordered as the wrapper's additional names.
    additional_inputs: Tuple[torch.Tensor, ...]


class InterpretableWrapper(nn.Module):
    """
    An adapter wrapper to adapt our models to Captum interpreters.

    The wrapped model takes keyword inputs; the wrapper exposes a positional
    ``forward`` and provides the conversions between the two conventions.
    """

    def __init__(self, model: 'InterpretableModel'):
        super().__init__()
        self._model = model

    @property
    def _additional_names(self) -> Iterable[str]:
        """Names of ``forward`` parameters that are not interpreted.

        Parameters whose default value is ``None`` are left out; parameters
        without a default are kept.
        """
        to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
        signature = inspect.signature(self._model.forward)
        return [name for name, parameter in signature.parameters.items()
                if name not in to_be_interpreted_names
                and parameter.default is not None]

    def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Converts an ordered *args to **kwargs.

        Names are taken from the interpreted placeholders followed by the
        additional names. When fewer values than names are given, the trailing
        names are left out instead of raising ``IndexError``.
        """
        names = chain(self._model.ordered_placeholder_names_to_be_interpreted,
                      self._additional_names)
        return dict(zip(names, args))

    def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs:
        """
        Converts a **kwargs to ordered *args.

        Names that are absent from ``kwargs`` are skipped.
        """
        return InterpreterArgs(
            tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted
                  if name in kwargs),
            tuple(kwargs[name] for name in self._additional_names
                  if name in kwargs)
        )

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        """
        Forwards the model and returns its categorical probabilities.
        """
        inputs = self.convert_inputs_to_kwargs(*args)
        return self._model.get_categorical_probabilities(**inputs)
| from typing import Dict, Union, Tuple | |||||
| import torch | |||||
| from .interpreter import Interpreter | |||||
| from ..data.dataflow import DataFlow | |||||
| from ..data.data_loader import DataLoader | |||||
class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]):
    """Data flow whose final stage runs an interpreter on each batch.

    Every item produced is a pair of the model input and the detached
    interpretation output for that input.
    """

    def __init__(self, interpreter: Interpreter,
                 dataloader: DataLoader,
                 device: torch.device,
                 print_debug_info: bool,
                 label_for_interpreter: Union[int, None],
                 give_gt_as_label: bool):
        """
        Args:
            interpreter: Interpreter to run; its ``_model`` is handed to the base flow.
            dataloader: Source of batches and of the per-sample labels.
            device: Forwarded to the base ``DataFlow``.
            print_debug_info: Forwarded to the base ``DataFlow``.
            label_for_interpreter: Fixed label passed to the interpreter when
                ``give_gt_as_label`` is false.
            give_gt_as_label: When true, the current batch's ground-truth
                labels are passed to the interpreter instead.
        """
        super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info)
        self._interpreter = interpreter
        self._label_for_interpreter = label_for_interpreter
        self._give_gt_as_label = give_gt_as_label
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
    def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detach every tensor in ``output``; ``None`` is passed through."""
        if output is not None:
            return {name: tensor.detach() for name, tensor in output.items()}
        return None

    def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Interpret ``model_input`` and return it with the detached result."""
        if self._give_gt_as_label:
            # Ground-truth labels of the samples in the current batch.
            batch_indices = self._dataloader.get_current_batch_sample_indices()
            labels = torch.Tensor(self._sample_labels[batch_indices])
        else:
            labels = self._label_for_interpreter

        interpretation = self._interpreter.interpret(labels, **model_input)
        return model_input, self._detach_output(interpretation)
| from typing import TYPE_CHECKING | |||||
| import traceback | |||||
| from time import time | |||||
| from ..data.data_loader import DataLoader | |||||
| from .interpretable import InterpretableModel | |||||
| from .interpreter_maker import create_interpreter | |||||
| from .interpreter import Interpreter | |||||
| from .interpretation_dataflow import InterpretationDataFlow | |||||
| from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class InterpretingEvalRunner:
    """Evaluates a model's interpretations against ground truth on test groups."""

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        self.conf = conf
        self.model = model

    def evaluate(self):
        """Evaluate every comma-separated group in ``conf.samples_dir``.

        A failure in one group is printed with its traceback and does not
        stop the remaining groups.
        """
        interpreter = create_interpreter(
            self.conf.interpretation_method,
            self.model)

        for test_group_info in self.conf.samples_dir.split(','):
            try:
                print('>> Evaluating interpretations for %s' % test_group_info, flush=True)
                started_at = time()

                test_data_loader = DataLoader(self.conf, test_group_info, 'test')
                evaluator = BinaryInterpretationEvaluator2D(
                    test_data_loader.get_number_of_samples(),
                    self.conf
                )
                self._eval_interpretations(test_data_loader, interpreter, evaluator)
                evaluator.print_summaries()

                elapsed = time() - started_at
                print('Evaluating Interpretations was done in %.2f secs.' % (elapsed,), flush=True)
            except Exception:
                # Best effort: report the problem and continue with the next group.
                print('Problem in %s' % test_group_info, flush=True)
                print(traceback.format_exc(), flush=True)

    def _eval_interpretations(self,
                              data_loader: DataLoader, interpreter: Interpreter,
                              evaluator: BinaryInterpretationEvaluator2D) -> None:
        """Feed the interpretations of the data loader's samples to ``evaluator``.

        Resets the evaluator, puts the model in evaluation mode, iterates over
        the batches produced by an ``InterpretationDataFlow`` and updates the
        evaluator with each batch, stopping once
        ``conf.n_interpretation_samples`` samples have been processed (when
        that limit is set).

        Raises:
            Exception: If the model is not an ``InterpretableModel``.
        """
        if not isinstance(self.model, InterpretableModel):
            raise Exception('Model has not implemented the requirements of the InterpretableModel')

        # Running the model in evaluation mode
        self.model.eval()

        evaluator.reset()

        label_for_interpreter = self.conf.class_label_for_interpretation
        # Ground-truth labels are used only when no fixed label is configured
        # and predictions are not the interpretation target.
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        max_n_samples = self.conf.n_interpretation_samples
        n_interpreted_samples = 0
        first_placeholder = None

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():
                first_placeholder = self.model.ordered_placeholder_names_to_be_interpreted[0]
                n_batch = model_input[first_placeholder].shape[0]

                # Never evaluate more samples than the configured limit allows.
                if max_n_samples is None:
                    n_samples_to_save = n_batch
                else:
                    n_samples_to_save = min(
                        max_n_samples - n_interpreted_samples,
                        n_batch)

                model_outputs = self.model(**model_input)
                prediction_key = self.conf.prediction_key_in_model_output_dict
                model_preds = model_outputs[prediction_key].detach().cpu().numpy()

                # Use the configured tag, or the first interpretation when none is set.
                tag = self.conf.interpretation_tag_to_evaluate
                if tag:
                    interpretations = interpretation_output[tag].detach()
                else:
                    interpretations = list(interpretation_output.values())[0].detach()
                interpretations = interpretations.detach().cpu().numpy()

                gt_interpretations = data_loader.get_current_batch_samples_interpretations()
                evaluator.update_summaries(
                    interpretations[:n_samples_to_save, 0],
                    gt_interpretations[:n_samples_to_save, 0].cpu().numpy(),
                    model_preds[:n_samples_to_save],
                    data_loader.get_current_batch_samples_labels()[:n_samples_to_save],
                    data_loader.get_current_batch_sample_indices()[:n_samples_to_save],
                    interpreter
                )

                del interpretation_output
                del model_outputs

                # if the number of interpreted samples has reached the limit, break
                n_interpreted_samples += n_batch
                if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
                    break
| """ | |||||
| An abstract class to apply different interpretation algorithms on torch models. | |||||
| """ | |||||
| from abc import abstractmethod, ABC | |||||
| from typing import Dict, List, Union | |||||
| import torch | |||||
| from torch import nn | |||||
| import numpy as np | |||||
| from . import InterpretableModel | |||||
class Interpreter(ABC):  # pylint: disable=too-few-public-methods
    """
    Base class of the interpretation algorithms that are applied on torch models.
    Sub-classes implement `interpret`; the helpers here are shared utilities.
    """

    def __init__(self, model: InterpretableModel):
        self._model = model

    def _get_all_leaf_children(self, module: nn.Module = None) -> List[nn.Module]:
        """
        Collects, depth first, every descendant module that has no children itself.
        Starts from the wrapped model when no module is given.
        """
        root = self._model if module is None else module

        leaves: List[nn.Module] = []
        for child in root.children():
            if list(child.children()):
                # not a leaf: descend into it
                leaves.extend(self._get_all_leaf_children(child))
            else:
                leaves.append(child)
        return leaves

    @staticmethod
    def _get_one_hot_output(output: torch.Tensor,
                            target: Union[int, torch.Tensor, np.ndarray] = None) -> torch.Tensor:
        """
        Builds a tensor shaped like `output` holding 1 at (row, target) for every row
        and 0 elsewhere. When no target is given, the argmax over dim 1 is used.
        """
        if target is None:
            target = output.argmax(dim=1)

        rows = torch.arange(output.shape[0])
        one_hot = torch.zeros_like(output)
        one_hot[rows, target] = 1
        return one_hot

    def _get_target(self, target: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Union[int, torch.Tensor]:
        """
        Returns the given target untouched, or the model's argmax prediction
        (over dim 1 of its categorical probabilities) when the target is None.
        """
        if target is None:
            return self._model.get_categorical_probabilities(**inputs).argmax(dim=1)
        return target

    @abstractmethod
    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Interprets the given inputs with respect to the given labels.
        Returns a dictionary of interpretation results.
        """

    def dynamic_threshold(self, x: np.ndarray) -> np.ndarray:
        """
        Thresholds one sample dynamically: subtracts (mean + std) of the sample
        and clips the negative values to zero.
        """
        cutoff = x.mean() + x.std()
        return (x - cutoff).clip(min=0)
| """ | |||||
| Creates an interpreter object and returns it | |||||
| """ | |||||
| from typing import Callable | |||||
| from . import InterpretableModel, Interpreter | |||||
class InterpretationType:
    """Plain string constants naming the known interpretation methods."""

    # names that create_interpreter in this file has a branch for
    GuidedBackprop = "GuidedBackprop"
    GuidedGradCam = "GuidedGradCam"
    GradCam = "GradCam"
    RelCam = "RelCam"
    DeepLift = "DeepLift"
    Attention = "Attention"
    AttentionSmooth = "AttentionSmooth"
    AttentionSum = "AttentionSum"
    ImagenetAttention = "ImagenetAttention"

    # NOTE(review): no branch for this one is visible in create_interpreter
    GT = "GT"
# Anything that builds an Interpreter out of an InterpretableModel.
InterpreterMaker = Callable[[InterpretableModel], Interpreter]


def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter:
    """
    Creates an interpreter object and returns it

    The implementing module is imported lazily, so only the module of the
    requested method gets loaded.

    :param interpretation_method: The method name (one of the InterpretationType constants)
    :param model: The model to be interpreted
    :return: an interpreter object that can run the model and interpret its output
    :raises ValueError: if the method name is not one of the handled names
    """
    if interpretation_method == InterpretationType.GuidedBackprop:
        from .guided_backprop import GuidedBackprop as interpreter_maker
    elif interpretation_method == InterpretationType.GuidedGradCam:
        from .guided_gradcam import GuidedGradCam as interpreter_maker
    elif interpretation_method == InterpretationType.GradCam:
        from .gradcam import GradCam as interpreter_maker
    elif interpretation_method == InterpretationType.RelCam:
        from .relcam import RelCamInterpreter as interpreter_maker
    elif interpretation_method == InterpretationType.DeepLift:
        from .deep_lift import DeepLift as interpreter_maker
    elif interpretation_method == InterpretationType.Attention:
        from .attention_interpreter import AttentionInterpreter as interpreter_maker
    elif interpretation_method == InterpretationType.AttentionSmooth:
        from .attention_interpreter_smooth_integrator import AttentionInterpreterSmoothIntegrator as interpreter_maker
    elif interpretation_method == InterpretationType.AttentionSum:
        from .attention_sum_interpreter import AttentionSumInterpreter as interpreter_maker
    elif interpretation_method == InterpretationType.ImagenetAttention:
        from .imagenet_attention_interpreter import ImagenetPredictionInterpreter as interpreter_maker
    else:
        # A single formatted message: passing the name as a second positional
        # argument made the exception render as a tuple. ValueError is still
        # an Exception, so existing `except Exception` handlers keep working.
        raise ValueError(f'Unknown interpretation method {interpretation_method!r}')

    return interpreter_maker(model)
| from typing import Dict, TYPE_CHECKING | |||||
| from os import makedirs, path | |||||
| import traceback | |||||
| from time import time | |||||
| import warnings | |||||
| import imageio | |||||
| import torch | |||||
| import numpy as np | |||||
| from ..data.data_loader import DataLoader | |||||
| from .interpretable import InterpretableModel | |||||
| from .interpreter_maker import create_interpreter | |||||
| from .interpreter import Interpreter | |||||
| from .interpretation_dataflow import InterpretationDataFlow | |||||
| from .utils import overlay_interpretation | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class InterpretingRunner:
    """
    Runs the configured interpretation method over the configured sample groups
    and stores the resulting maps (raw arrays and/or overlay images) on disk.
    """

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        self.conf = conf
        self.model = model

    def interpret(self):
        """
        Interprets every sample group listed (comma separated) in conf.samples_dir.

        Each group gets its own report directory, in which the config is dumped
        to conf_info.txt before the interpretations are saved. A failure in one
        group is printed with its traceback and does not stop the other groups.
        """
        interpreter = create_interpreter(
            self.conf.interpretation_method,
            self.model)

        for test_group_info in self.conf.samples_dir.split(','):
            try:
                print('>> Finding interpretations for %s' % test_group_info, flush=True)

                t1 = time()

                mapped_labels = self.conf.mapped_labels_to_use
                if mapped_labels is None:
                    labels_to_use = 'All'
                else:
                    labels_to_use = ','.join(str(x) for x in mapped_labels)

                report_dir = self.conf.get_sample_group_specific_report_dir(
                    test_group_info,
                    extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}')
                makedirs(report_dir, exist_ok=True)

                # writing the whole config; `with` closes the file even if writing fails
                with open(path.join(report_dir, 'conf_info.txt'), 'w') as conf_file:
                    conf_file.write(str(self.conf) + '\n')

                test_data_loader = DataLoader(self.conf, test_group_info, 'test')
                self._interpret(report_dir, test_data_loader, interpreter)

                print('Interpretations were saved in %s' % report_dir)
                print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True)

            except Exception:
                # deliberate best effort: report the problem and go on with the next group
                print('Problem in %s' % test_group_info, flush=True)
                print(traceback.format_exc(), flush=True)

    def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None:
        """ Iterates over the samples (as much as and how data_loader specifies),
        finds the interpretations based on the specified interpretation method
        and saves the results in the received save_dir!

        :param save_dir: directory (created if missing) the results are saved under
        :param data_loader: the loader providing the batches
        :param interpreter: the interpreter producing the interpretation maps
        :raises Exception: if the model is not an InterpretableModel
        """
        makedirs(save_dir, exist_ok=True)
        if not isinstance(self.model, InterpretableModel):
            raise Exception('Model has not implemented the requirements of the InterpretableModel')

        # Running the model in evaluation mode
        self.model.eval()

        label_for_interpreter = self.conf.class_label_for_interpretation
        # ground-truth labels are only given to the interpreter when no fixed class
        # was requested and predictions are not the thing being interpreted
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        n_interpreted_samples = 0
        max_n_samples = self.conf.n_interpretation_samples
        input_name = self.model.ordered_placeholder_names_to_be_interpreted[0]

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():
                n_batch = model_input[input_name].shape[0]

                # the last batch may have to be truncated to respect the limit
                n_samples_to_save = n_batch
                if max_n_samples is not None:
                    n_samples_to_save = min(
                        max_n_samples - n_interpreted_samples,
                        n_batch)

                self._save_interpretations_of_batch(
                    model_input[input_name],
                    interpretation_output, save_dir, data_loader, n_samples_to_save,
                    interpreter)

                del interpretation_output

                # if the number of interpreted samples has reached the limit, break
                n_interpreted_samples += n_batch
                if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
                    break

    def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor],
                                       save_dir: str, data_loader: DataLoader,
                                       n_samples_to_save: int, interpreter: Interpreter) -> None:
        """
        Receives the output of the interpreter and saves the interpretations in the received directory
        in a file named as the sample name.
        The behaviour can be overwritten in children if extra stuff are required.

        For every saved sample and every interpretation name, the raw map is stored
        with np.save (unless conf.skip_raw) and an overlay png is written
        (unless conf.skip_overlay).

        :param model_input: the batch of inputs the interpretations belong to
        :param interpreter_output: interpretation maps of the batch, by name
        :param save_dir: the root directory to save under
        :param data_loader: used to get the names of the samples of the current batch
        :param n_samples_to_save: only the first n samples of the batch are saved
        :param interpreter: provides dynamic_threshold when conf.dynamic_threshold is set
        :return: None
        """
        batch_samples_names = data_loader.get_current_batch_samples_names()
        save_dirs = [
            self.conf.get_save_dir_for_sample(save_dir, sample_name.replace('../', ''))
            for sample_name in batch_samples_names]

        for sd in save_dirs:
            makedirs(path.dirname(sd), exist_ok=True)

        interpreter_output = {name: output.cpu().numpy() for name, output in interpreter_output.items()}

        # Make inputs grayscale (mean over dim 1)
        model_input = model_input.mean(dim=1).cpu().numpy()

        for bi in range(min(len(model_input), n_samples_to_save)):
            for name, batch_output in interpreter_output.items():
                output = batch_output[bi]
                filename = f'{save_dirs[bi]}_{name}'

                if self.conf.dynamic_threshold:
                    output = interpreter.dynamic_threshold(output)

                if not self.conf.skip_raw:
                    np.save(filename, output)
                else:
                    warnings.warn('Skipping raw interpretations')

                if not self.conf.skip_overlay:
                    # at most 3 maps can be overlaid on one image
                    if output.shape[0] > 3:
                        output = output[:3]
                    overlayed = overlay_interpretation(model_input[bi][np.newaxis, ...], output, self.conf)
                    imageio.imwrite(f'{filename}_overlay.png', overlayed)
                else:
                    warnings.warn('Skipping overlayed interpretations')
| from .interpreter import RelCamInterpreter |
| from typing import Dict, Union | |||||
| import warnings | |||||
| import numpy as np | |||||
| import torch.nn.functional as F | |||||
| import torch | |||||
| from ..interpreter import Interpreter | |||||
| from ..interpretable import CamInterpretableModel | |||||
| from .relprop import RPProvider | |||||
class RelCamInterpreter(Interpreter):
    """
    Interpreter that propagates relevance through the model with the RelProp
    objects of RPProvider and combines what is captured at the model's target conv layers.
    """

    def __init__(self, model: CamInterpretableModel):
        super().__init__(model)
        self.__targets = model.target_conv_layers

        # every module that has a registered propagation rule gets its RelProp created
        for name, module in model.named_modules():
            if RPProvider.propable(module):
                RPProvider.create(module)
            else:
                warnings.warn(f"Module {name} of type {type(module)} is not propable! Hope you know what you are doing!")

    @staticmethod
    def __normalize(x: torch.Tensor) -> torch.Tensor:
        """Min-max normalizes every sample (dim 0) over all of its other dims; indexes as 4-D."""
        flat = x.flatten(1)
        lowest = flat.min(dim=1).values[:, None, None, None]
        highest = flat.max(dim=1).values[:, None, None, None]
        return (x - lowest) / (1e-6 + highest - lowest)

    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Runs the model, propagates the one-hot output of the given labels back and
        returns the normalized map, resized to the spatial size of the first
        interpreted input, under the key 'default'.
        """
        input_name = self._model.ordered_placeholder_names_to_be_interpreted[0]
        x_shape = inputs[input_name].shape

        with RPProvider.capture(self.__targets) as Rs:
            z = self._model.get_categorical_probabilities(**inputs)
            one_hot_y = self._get_one_hot_output(z, labels)
            RPProvider.get(self._model)(one_hot_y)
            result = list(Rs.values())

        # several target layers: their maps are summed
        if len(result) > 1:
            relcam = torch.stack(result).sum(dim=0)
        else:
            relcam = result[0]

        relcam = self.__normalize(relcam)
        relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear', align_corners=False)

        return {
            'default': relcam,
        }
| from torch import nn | |||||
| import torch | |||||
class Add(nn.Module):
    """Module form of torch.add, so that the operation can be hooked like any other layer."""

    def forward(self, inputs):
        """Unpacks `inputs` as the positional arguments of torch.add and returns the result."""
        result = torch.add(*inputs)
        return result
class Multiply(nn.Module):
    """Module form of torch.mul, so that the operation can be hooked like any other layer."""

    def forward(self, inputs):
        """Unpacks `inputs` as the positional arguments of torch.mul and returns the result."""
        result = torch.mul(*inputs)
        return result
class Cat(nn.Module):
    """Module form of torch.cat that remembers the dim of its latest call."""

    def forward(self, inputs, dim):
        """Concatenates `inputs` along `dim`; the dim is stored on the module as `self.dim`."""
        # kept on the module so that it can be read back after the forward pass
        self.dim = dim
        return torch.cat(inputs, dim)
| from contextlib import ExitStack, contextmanager | |||||
| from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar | |||||
| from uuid import UUID, uuid4 | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import nn | |||||
| import torchvision | |||||
| from . import modules as M | |||||
| TModule = TypeVar('TModule', bound=nn.Module) | |||||
| TType = TypeVar('TType', bound=type) | |||||
def safe_divide(a, b):
    """
    Element-wise a / b that yields 0 wherever b is 0.

    The zeros of b are nudged by 1e-9 before dividing (so nothing is divided by
    zero) and the quotient is then multiplied by the non-zero mask of b.
    """
    tensor_type = b.type()
    zero_mask = b.eq(0).type(tensor_type)
    nonzero_mask = b.ne(0).type(tensor_type)
    return a / (b + zero_mask * 1e-9) * nonzero_mask
class classdict(Dict[type, TType], Generic[TType]):
    """
    A dict keyed by classes whose membership test and item access fall back
    along the MRO of the asked class: the first class of the MRO that is an
    actual key of the dict is the one used.
    """

    def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]:
        """Returns the first class in the MRO of the key that is stored in the dict, or None."""
        for candidate in __k.mro():
            # plain dict lookup: going through `in self` would recurse into __contains__
            if dict.__contains__(self, candidate):
                return candidate
        return None

    def __contains__(self, __k: Type[TModule]) -> bool:
        resolved = self.__nearest_available_resolution(__k)
        return resolved is not None

    def __getitem__(self, __k: Type[TModule]) -> TType:
        resolved = self.__nearest_available_resolution(__k)
        if resolved is not None:
            return super().__getitem__(resolved)
        raise KeyError(f"{__k} not found!")
class RPProvider:
    """
    Class-level registry connecting module types to their RelProp classes and
    module instances to their RelProp objects. All the state lives on the class,
    so it is shared by every user of RPProvider.
    """

    # module type -> RelProp class (lookups fall back along the MRO of the module type)
    __props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict()
    # module instance -> its RelProp object
    __instances: Dict[TModule, 'RelProp[TModule]'] = {}
    # target module -> what was captured for it during the latest `capture`
    __target_results: Dict[nn.Module, torch.Tensor] = {}

    @classmethod
    def register(cls, *module_types: Type[TModule]):
        """Class decorator registering the decorated RelProp class for the given module types."""
        def decorator(prop_cls):
            for module_type in module_types:
                cls.__props[module_type] = prop_cls
            return prop_cls
        return decorator

    @classmethod
    def create(cls, module: TModule):
        """Creates the RelProp object of the module, stores it and returns it."""
        prop_cls = cls.__props[type(module)]
        prop = prop_cls(module)
        cls.__instances[module] = prop
        return prop

    @classmethod
    def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None:
        """
        Hook put on the target layers: weights the stored input of the layer by the
        mean of R over dims (2, 3), sums over dim 1 and stores the result for the module.
        """
        r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
        weighted_input = prop.X * r_weight
        cls.__target_results[prop.module] = torch.sum(weighted_input, dim=1, keepdim=True)

    @classmethod
    @contextmanager
    def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]:
        """
        Context manager that clears the previous results, hooks every known module
        and puts the capturing hook on the target layers. Yields the (shared) dict
        that gets filled with the results by module; all the hooks are removed on exit.
        """
        with ExitStack() as stack:
            cls.__target_results.clear()

            for prop in cls.__instances.values():
                stack.enter_context(prop.hook_module())

            for target in target_layers:
                stack.enter_context(cls.get(target).register_hook(cls.__hook))

            yield cls.__target_results

    @classmethod
    def propable(cls, module: TModule) -> bool:
        """Tells whether a RelProp class is registered for the type of the module (or a base of it)."""
        return type(module) in cls.__props

    @classmethod
    def get(cls, module: TModule) -> 'RelProp[TModule]':
        """Returns the RelProp object created for the module; KeyError if there is none."""
        return cls.__instances[module]
@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid)
class RelProp(Generic[TModule]):
    """
    Base relevance propagator of one module. It stores the input and the output
    of the module on every forward pass (while hooked) and, when called, maps the
    relevance of the output to the relevance of the input. The base rule passes
    the relevance through unchanged.
    """

    Hook = Callable[['RelProp', torch.Tensor], None]

    def __init__(self, module: TModule) -> None:
        self.module = module
        # hooks called with the propagated relevance, by registration id
        self.__hooks: Dict[UUID, RelProp.Hook] = {}

    def __forward_hook(self, _, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor):
        """Forward hook: keeps a detached, grad-requiring copy of the first input and the output."""
        first = inputs[0]
        if type(first) in (list, tuple):
            # modules taking a list of tensors: every tensor is stored separately
            self.X = []
            for tensor in first:
                leaf = tensor.detach()
                leaf.requires_grad = True
                self.X.append(leaf)
        else:
            self.X = first.detach()
            self.X.requires_grad = True

        self.Y = output

    @contextmanager
    def hook_module(self):
        """Keeps the forward hook on the module for the duration of the context."""
        handle = self.module.register_forward_hook(self.__forward_hook)
        try:
            yield
        finally:
            handle.remove()

    @contextmanager
    def register_hook(self, hook: 'RelProp.Hook'):
        """Keeps the given relevance hook registered for the duration of the context."""
        hook_id = uuid4()
        self.__hooks[hook_id] = hook
        try:
            yield
        finally:
            self.__hooks.pop(hook_id)

    def grad(self, Z, X, S):
        """Gradient of Z with respect to X, with S as grad_outputs; the graph is retained."""
        return torch.autograd.grad(Z, X, S, retain_graph=True)

    def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagates R with `rel`, shows the result to the registered hooks and returns it."""
        R = self.rel(R, alpha=alpha)
        for hook in self.__hooks.values():
            hook(self, R)
        return R

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """The propagation rule; identity here, overridden by the sub-classes."""
        return R
@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize)
class RelPropSimple(RelProp[TModule]):
    """Rule that re-runs the module on the stored input and redistributes R as input * gradient."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """
        Returns one tensor when the stored input is a tensor; otherwise a list with
        the relevances of the first two stored inputs.
        """
        Z = self.module.forward(self.X)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)

        if torch.is_tensor(self.X):
            return self.X * C[0]

        # stored input is a list of tensors: one relevance per input
        return [self.X[0] * C[0], self.X[1] * C[1]]
@RPProvider.register(nn.Flatten)
class RelPropFlatten(RelProp[TModule]):
    """Rule for nn.Flatten: the relevance is only given back the shape of the stored input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        unflattened = R.reshape_as(self.X)
        return unflattened
@RPProvider.register(nn.AdaptiveAvgPool2d)
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]):
    """Rule for nn.AdaptiveAvgPool2d that uses only the non-negative part of the stored input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # negative inputs are clamped to zero before pooling
        positive_input = torch.clamp(self.X, min=0)

        pooled = F.adaptive_avg_pool2d(positive_input, self.module.output_size)
        scaled = safe_divide(R, pooled)
        gradient = self.grad(pooled, positive_input, scaled)[0]

        return positive_input * gradient
@RPProvider.register(nn.ZeroPad2d)
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]):
    """Rule for nn.ZeroPad2d: re-runs the padding on the stored input and returns input * gradient."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        padded = self.module.forward(self.X)
        scaled = safe_divide(R, padded)
        gradients = self.grad(padded, self.X, scaled)
        return self.X * gradients[0]
@RPProvider.register(M.Multiply)
class MultiplyRelProp(RelProp[M.Multiply]):
    """Rule for M.Multiply, applied on the non-negative parts of the two stored inputs."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Returns a list holding the relevance of the first and of the second input."""
        # both operands are clamped to be non-negative
        operands = [
            torch.clamp(self.X[0], min=0),
            torch.clamp(self.X[1], min=0),
        ]

        product = self.module.forward(operands)
        scaled = safe_divide(R, product)
        gradients = self.grad(product, operands, scaled)

        return [operands[0] * gradients[0], operands[1] * gradients[1]]
@RPProvider.register(M.Cat)
class CatRelProp(RelProp[M.Cat]):
    """Rule for M.Cat: re-runs the concatenation (along the dim the module stored) on the stored inputs."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Returns a list with one relevance (input * gradient) per concatenated input."""
        concatenated = self.module.forward(self.X, self.module.dim)
        scaled = safe_divide(R, concatenated)
        gradients = self.grad(concatenated, self.X, scaled)

        return [x * c for x, c in zip(self.X, gradients)]
@RPProvider.register(nn.Sequential)
class SequentialRelProp(RelProp[nn.Sequential]):
    """Rule for nn.Sequential: sends the relevance through the RelProps of its children, last child first."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        relevance = R
        for child in reversed(self.module):
            child_prop = RPProvider.get(child)
            relevance = child_prop(relevance, alpha=alpha)
        return relevance
@RPProvider.register(nn.BatchNorm2d)
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]):
    """Rule for nn.BatchNorm2d, based on the per-channel scale that the layer applies with its running statistics."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """
        Redistributes R over the stored input through the scale of the layer.

        :param R: relevance of the output of the layer
        :param alpha: accepted for interface compatibility; not used by this rule
        :return: relevance of the input of the layer
        """
        module = self.module

        # the per-channel vectors get singleton dims at positions 0, 2 and 3,
        # so that they broadcast against the stored input
        gamma = module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        running_var = module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        # Batch norm scales by gamma / sqrt(var + eps). The variance used to be
        # squared here before the square root, which is not the scale of the layer.
        weight = gamma / (running_var + module.eps).pow(0.5)

        # the small constant keeps the division away from an exact zero
        Z = self.X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        return self.X * Ca
@RPProvider.register(nn.Linear)
class LinearRelProp(RelProp[nn.Linear]):
    """Rule for nn.Linear that splits the weights and the stored input into positive and negative parts."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Returns alpha * activator relevances - (alpha - 1) * inhibitor relevances."""
        beta = alpha - 1

        weight = self.module.weight
        pos_w = torch.clamp(weight, min=0)
        neg_w = torch.clamp(weight, max=0)
        pos_x = torch.clamp(self.X, min=0)
        neg_x = torch.clamp(self.X, max=0)

        def relevance(w_for_pos_x, w_for_neg_x):
            # the positive and the negative inputs go through their own weights (no bias)
            z_pos = F.linear(pos_x, w_for_pos_x)
            z_neg = F.linear(neg_x, w_for_neg_x)
            scaled = safe_divide(R, z_pos + z_neg)
            c_pos = pos_x * self.grad(z_pos, pos_x, scaled)[0]
            c_neg = neg_x * self.grad(z_neg, neg_x, scaled)[0]
            return c_pos + c_neg

        # same-sign pairs of weights and inputs, then the opposite-sign pairs
        activator_relevances = relevance(pos_w, neg_w)
        inhibitor_relevances = relevance(neg_w, pos_w)

        return alpha * activator_relevances - beta * inhibitor_relevances
@RPProvider.register(nn.Conv2d)
class Conv2dRelProp(RelProp[nn.Conv2d]):
    """
    Rule for nn.Conv2d. Convolutions whose stored input has 3 channels get a
    rule bounded by the extremes of the input; all the others get the
    alpha/beta rule over the positive and negative parts.
    """

    def gradprop2(self, DY, weight):
        """
        Sends DY back to the input resolution with a transposed convolution of the
        given weight; the output padding is derived from the size of dim 2 of the
        stored input and of the re-computed output of the module.
        """
        conv = self.module
        Z = conv.forward(self.X)

        covered = (Z.size()[2] - 1) * conv.stride[0] - 2 * conv.padding[0] + conv.kernel_size[0]
        output_padding = self.X.size()[2] - covered

        return F.conv_transpose2d(DY, weight, stride=conv.stride, padding=conv.padding, output_padding=output_padding)

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        conv = self.module
        weight = conv.weight
        stride, padding = conv.stride, conv.padding

        pos_w = torch.clamp(weight, min=0)
        neg_w = torch.clamp(weight, max=0)

        if self.X.shape[1] == 3:
            X = self.X

            def spread_extreme(reduce):
                # extreme over dims 1, 2 and 3 (one after another), broadcast back to the shape of X
                value = X
                for dim in (1, 2, 3):
                    value = reduce(value, dim=dim, keepdim=True)[0]
                return X * 0 + value

            L = spread_extreme(torch.min)
            H = spread_extreme(torch.max)

            Za = (torch.conv2d(X, weight, bias=None, stride=stride, padding=padding)
                  - torch.conv2d(L, pos_w, bias=None, stride=stride, padding=padding)
                  - torch.conv2d(H, neg_w, bias=None, stride=stride, padding=padding)
                  + 1e-9)

            S = R / Za
            return (X * self.gradprop2(S, weight)
                    - L * self.gradprop2(S, pos_w)
                    - H * self.gradprop2(S, neg_w))

        beta = alpha - 1
        pos_x = torch.clamp(self.X, min=0)
        neg_x = torch.clamp(self.X, max=0)

        def relevance(w_for_pos_x, w_for_neg_x):
            # the bias of the module is passed to both of the convolutions
            z_pos = F.conv2d(pos_x, w_for_pos_x, bias=conv.bias, stride=stride,
                             padding=padding, groups=conv.groups)
            z_neg = F.conv2d(neg_x, w_for_neg_x, bias=conv.bias, stride=stride,
                             padding=padding, groups=conv.groups)
            scaled = safe_divide(R, z_pos + z_neg)
            c_pos = pos_x * self.grad(z_pos, pos_x, scaled)[0]
            c_neg = neg_x * self.grad(z_neg, neg_x, scaled)[0]
            return c_pos + c_neg

        activator_relevances = relevance(pos_w, neg_w)
        inhibitor_relevances = relevance(neg_w, pos_w)

        return alpha * activator_relevances - beta * inhibitor_relevances
@RPProvider.register(nn.ConvTranspose2d)
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]):
    """Rule for nn.ConvTranspose2d that uses only the positive weights and the non-negative input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        deconv = self.module
        pos_w = torch.clamp(deconv.weight, min=0)
        pos_x = torch.clamp(self.X, min=0)

        # the transposed convolution is re-run without the bias
        z = F.conv_transpose2d(pos_x, pos_w, bias=None, stride=deconv.stride,
                               padding=deconv.padding, output_padding=deconv.output_padding)
        scaled = safe_divide(R, z)

        return pos_x * self.grad(z, pos_x, scaled)[0]
| from typing import TYPE_CHECKING, Optional | |||||
| import numpy as np | |||||
| import cv2 | |||||
| from skimage.morphology import dilation, erosion | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
| from . import Interpreter | |||||
def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray:
    """
    Blends the processed interpretation maps over the image.

    image (np.ndarray): The image to overlay interpretations over. Of shape C W H. The numbers are assumed to be in range [0, 1]
        (an image containing negative values is shifted by 0.5). The received array is left untouched.
    int_map (np.ndarray): The map containing interpretation scores. Only its first 3 channels are used.
    config ('Config'): An object containing required information about the configurations.
    Returns:
        (np.ndarray): The overlayed image, of dtype uint8 and with the channels last
    Raises:
        ValueError: if the number of the maps to overlay is not 1, 2 or 3
    """
    int_map = process_interpretations(int_map, config)  # C W H
    int_map = np.moveaxis(int_map, 0, -1)  # W H C

    # if there are more than 3 channels, use the first 3!
    if int_map.shape[-1] > 3:
        int_map = int_map[..., :3]

    # norm by max to make sure!
    int_map = int_map * 1.0 / (1e-6 + np.amax(int_map))

    alpha = 0.5

    if np.amin(image) < 0:
        # out of place on purpose: an in-place add would change the caller's array,
        # which may be a view that the caller uses again
        image = image + 0.5

    image = np.moveaxis(image, 0, -1)  # W H C
    image = image * 255

    # the overlay colour depends on the number of the maps
    o_color_by_c = {
        1: [255, 0, 0],
        2: [255, 255, 0],
        3: [255, 255, 255],
    }

    n_maps = int_map.shape[-1]
    # checked before np.asarray: np.asarray(None) is an array, never None,
    # so testing its result against None could not detect a missing colour
    if n_maps not in o_color_by_c:
        raise ValueError(f"Cannot overlay {n_maps} interpretation maps on the image; expected 1, 2 or 3.")
    o_color = np.asarray(o_color_by_c[n_maps])

    # if int map has two channels, append one zeros to the end!
    if n_maps == 2:
        int_map = np.concatenate((int_map, np.zeros((int_map.shape[-3], int_map.shape[-2], 1))), axis=-1)

    overlayed = (1 - int_map) * image + \
        int_map * (1 - alpha) * image + \
        int_map * alpha * o_color[np.newaxis, np.newaxis, :]

    overlayed = np.round(overlayed).astype(np.uint8)
    return overlayed
def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray:
    """
    Thresholds, normalizes and smooths the interpretation maps.

    int_map (np.ndarray): The map of interpretation scores. Is assumed to have a shape of C ...
        The received array is left untouched; the processing is done on a copy.
    config ('Config'): An object containing the information about the required configurations
        (dynamic_threshold, global_threshold and cut_threshold are read)
    interpreter (Optional['Interpreter']): When given and config.dynamic_threshold is set,
        its dynamic_threshold is applied first
    Returns:
        np.ndarray: Processed interpretation maps, with numbers in the range of [0, 1]
    """
    # Every step below works in place. Copying first keeps the caller's array intact
    # on every path (before, it was modified on some paths and not on others).
    int_map = np.array(int_map)

    if interpreter and config.dynamic_threshold:
        int_map = interpreter.dynamic_threshold(int_map)

    if not config.global_threshold:
        # Treating map as negative=reverse effect; discarding negatives
        int_map[int_map < 0] = 0

        qval = np.amax(int_map)
        if qval > 0:
            # qval is the maximum, so nothing exceeds it and this assignment changes nothing
            int_map[int_map > qval] = qval
            int_map = int_map / qval

    # applying cut threshold!
    int_map[int_map <= config.cut_threshold] = config.cut_threshold
    int_map -= config.cut_threshold

    # the kernel width is a tenth of the size of the last dim
    kw = int(np.round(0.1 * int_map.shape[-1]))

    mv = np.amax(int_map)
    int_map /= (mv + 1e-6)

    # dilation and erosion to make continuous objects and erase noises
    for c in range(int_map.shape[0]):
        int_map[c] = dilation(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3)))
        int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw)))
        int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))
        # blurring is done on the 8-bit version of the map, then scaled back to [0, 1]
        int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0

    return int_map  # [0, 1]
| from typing import List, TYPE_CHECKING, Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class BinaryEvaluator(Evaluator):
    """
    Evaluator of binary classifiers: accumulates the confusion counts and the
    running average of the loss over the received batches, and reports
    loss, accuracy, sensitivity, specificity and their average.
    """

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)

        # for summarized information
        self.tp: int = 0
        self.tn: int = 0
        self.fp: int = 0
        self.fn: int = 0
        self.avg_loss: float = 0
        self.n_received_samples: int = 0

    def reset(self):
        """Puts all the accumulated summaries back to zero."""
        self.tp = self.tn = self.fp = self.fn = 0
        self.avg_loss = 0
        self.n_received_samples = 0

    @property
    def _prob_key(self) -> str:
        """Key of the positive class probability in the output of the model."""
        return 'positive_class_probability'

    @property
    def _loss_key(self) -> str:
        """Key of the loss in the output of the model."""
        return 'loss'

    def _get_current_batch_gt(self) -> np.ndarray:
        """Ground-truth labels of the current batch, as given by the data loader."""
        return self.data_loader.get_current_batch_samples_labels()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        """
        Adds the current batch to the summaries. A sample is predicted positive
        when its probability is at least 0.5; a missing loss counts as 0.0.
        """
        assert self._prob_key in model_output, \
            f"model's output must contain {self._prob_key}"

        gt = self._get_current_batch_gt()
        prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int)

        def count(gt_value: int, predicted_value: int) -> int:
            return int(np.sum(np.logical_and(gt == gt_value, prediction == predicted_value)))

        batch_tp = count(1, 1)
        batch_tn = count(0, 0)
        batch_fp = count(0, 1)
        batch_fn = count(1, 0)

        self.tp += batch_tp
        self.tn += batch_tn
        self.fp += batch_fp
        self.fn += batch_fn

        # running average of the loss, weighted by the number of the samples
        n_before = self.n_received_samples
        n_batch = len(gt)
        n_after = n_before + n_batch
        self.avg_loss = self.avg_loss * (float(n_before) / n_after) + \
            model_output.get(self._loss_key, 0.0) * (float(n_batch) / n_after)
        self.n_received_samples = n_after

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """Titles of the metrics, in the order of get_values_of_evaluation_metrics."""
        return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N']

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """
        Formatted values of the metrics. A metric that cannot be computed
        (no sample to compute it on) is reported as -1.
        """
        n_positives = self.tp + self.fn
        n_negatives = self.tn + self.fp
        n_total = n_negatives + n_positives

        accuracy = (self.tp + self.tn) * 100.0 / n_total if n_positives + n_negatives > 0 else -1
        sensitivity = 100.0 * self.tp / n_positives if n_positives > 0 else -1
        specificity = 100.0 * self.tn / n_negatives if n_negatives > 0 else -1

        # balanced accuracy: the average of the two, or whichever one is available
        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return [
            '%.4f' % self.avg_loss,
            '%.2f' % accuracy,
            '%.2f' % sensitivity,
            '%.2f' % specificity,
            '%.2f' % avg_ss,
            str(self.n_received_samples),
        ]
| from functools import partial | |||||
| from typing import List, TYPE_CHECKING, Dict, Callable | |||||
| import numpy as np | |||||
| from torch import Tensor | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class BinFaithfulnessEvaluator(Evaluator):
    """Scores how well prefixed binary outputs agree with a reference output.

    Every key of the model output that starts with ``kw_prefix`` is thresholded
    at 0.5 and compared with the (equally thresholded) tensor stored under
    ``gt_kw``. Confusion counts are kept separately for each matching key.
    """

    def __init__(self, kw_prefix: str, gt_kw, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super(BinFaithfulnessEvaluator, self).__init__(model, data_loader, conf)
        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()
        self._kw_prefix = kw_prefix
        self._gt_kw = gt_kw

    def reset(self):
        """Forget all keys and counts seen so far."""
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
        """Add the confusion counts of this batch for every prefixed key."""
        gt = (model_output[self._gt_kw].detach().cpu().numpy() >= 0.5).astype(int)

        # looking for prefixes
        for k in model_output.keys():
            if not k.startswith(self._kw_prefix):
                continue

            if k not in self._tp_by_kw:
                self._tp_by_kw[k] = 0
                self._tn_by_kw[k] = 0
                self._fp_by_kw[k] = 0
                self._fn_by_kw[k] = 0

            # detach() like the reference above, so grad-tracking tensors work too
            pred = (model_output[k].detach().cpu().numpy() >= 0.5).astype(int)
            # int(...) keeps the counters plain Python ints, as annotated
            self._tp_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 1)))
            self._fp_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 1)))
            self._tn_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 0)))
            self._fn_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 0)))

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return [f'Loyalty_{k}_{metric}' for k in self._tp_by_kw for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw) -> List[str]:
        """Return [accuracy, sensitivity, specificity, balanced accuracy] for one key.

        Values are percentages formatted with two decimals; -1 marks a metric
        that is undefined because the required samples have not been seen.
        """
        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]

        p = tp + fn
        n = tn + fp

        accuracy = (tp + tn) * 100.0 / (n + p) if p + n > 0 else -1
        sensitivity = 100.0 * tp / p if p > 0 else -1
        specificity = 100.0 * tn / n if n > 0 else -1

        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss]

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return [
            v
            for k in self._tp_by_kw
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, prefix_kw: str, pred_kw: str = 'positive_class_probability') -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinFaithfulnessEvaluator']:
        """Return a ``(model, data_loader, conf)`` constructor with the keys pre-bound.

        ``pred_kw`` is the key of the reference output the prefixed outputs are
        compared against. Binding ``cls`` (not the base class) lets subclasses
        reuse this factory.
        """
        return partial(cls, prefix_kw, pred_kw)

    def print_evaluation_metrics(self, title: str) -> None:
        """ For more readable printing! """
        print(f'{title}:')
        for k in self._tp_by_kw:
            print(
                f'\t{k}:: ' +
                ', '.join([f'{m_name}: {m_val}'
                           for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))]))
| from typing import List, TYPE_CHECKING, Dict, Callable | |||||
| from functools import partial | |||||
| import numpy as np | |||||
| from torch import Tensor | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class BinForetellerEvaluator(Evaluator):
    """Scores every prefixed binary output of the model against the batch labels.

    Each key of the model output that starts with ``kw_prefix`` is thresholded
    at 0.5 and compared with the labels returned by the data loader. Confusion
    counts are kept separately for each matching key.
    """

    def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super(BinForetellerEvaluator, self).__init__(model, data_loader, conf)
        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()
        self._kw_prefix = kw_prefix

    def reset(self):
        """Forget all keys and counts seen so far."""
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
        """Add the confusion counts of this batch for every prefixed key."""
        gt = self.data_loader.get_current_batch_samples_labels()

        # looking for prefixes
        for k in model_output.keys():
            if not k.startswith(self._kw_prefix):
                continue

            if k not in self._tp_by_kw:
                self._tp_by_kw[k] = 0
                self._tn_by_kw[k] = 0
                self._fp_by_kw[k] = 0
                self._fn_by_kw[k] = 0

            # detach() keeps this usable on tensors that still require grad
            pred = (model_output[k].detach().cpu().numpy() >= 0.5).astype(int)
            # int(...) keeps the counters plain Python ints, as annotated
            self._tp_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 1)))
            self._fp_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 1)))
            self._tn_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 0)))
            self._fn_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 0)))

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return [f'{k}_{metric}' for k in self._tp_by_kw for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw) -> List[str]:
        """Return [accuracy, sensitivity, specificity, balanced accuracy] for one key.

        Values are percentages formatted with two decimals; -1 marks a metric
        that is undefined because the required samples have not been seen.
        """
        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]

        p = tp + fn
        n = tn + fp

        accuracy = (tp + tn) * 100.0 / (n + p) if p + n > 0 else -1
        sensitivity = 100.0 * tp / p if p > 0 else -1
        specificity = 100.0 * tn / n if n > 0 else -1

        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss]

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return [
            v
            for k in self._tp_by_kw
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']:
        """Return a ``(model, data_loader, conf)`` constructor with the prefix pre-bound.

        Binding ``cls`` (not the base class) lets subclasses reuse this factory.
        """
        return partial(cls, prefix_kw)

    def print_evaluation_metrics(self, title: str) -> None:
        """ For more readable printing! """
        print(f'{title}:')
        for k in self._tp_by_kw:
            print(
                f'\t{k}:: ' +
                ', '.join([f'{m_name}: {m_val}'
                           for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))]))
| from typing import TYPE_CHECKING, Type | |||||
| import numpy as np | |||||
| from ..data.data_loader import DataLoader | |||||
| from .binary_evaluator import BinaryEvaluator | |||||
| from ..models.model import Model | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
| from ..data.content_loaders.content_loader import ContentLoader | |||||
class TagBinaryEvaluator(BinaryEvaluator):
    """Binary evaluator for one named tag.

    Reads ``'<tag>_positive_class_probability'`` and ``'<tag>_loss'`` from the
    model output, and fetches the ground truth for ``<tag>`` through the fill
    function that the selected content loader registers under that name.
    """

    # 'ContentLoader' is quoted because it is imported only under TYPE_CHECKING;
    # an unquoted name is evaluated when the class body runs and raises NameError.
    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', tag: str, cl_type: Type['ContentLoader']):
        super().__init__(model, data_loader, conf)
        self._tag = tag
        self._content_loader = data_loader.get_content_loader_of_interest(cl_type)

    @property
    def _prob_key(self) -> str:
        return f'{self._tag}_positive_class_probability'

    @property
    def _gt_key(self) -> str:
        return self._tag

    @property
    def _loss_key(self) -> str:
        return f'{self._tag}_loss'

    def _get_current_batch_gt(self) -> np.ndarray:
        """Return the tag's ground truth for the samples of the current batch."""
        fill_function = self._content_loader.get_placeholder_name_to_fill_function_dict()[self._gt_key]
        return fill_function(self.data_loader.get_current_batch_sample_indices(), None)
| """ The evaluator """ | |||||
| from abc import abstractmethod, ABC | |||||
| from typing import List, TYPE_CHECKING, Dict | |||||
| import torch | |||||
| from tqdm import tqdm | |||||
| from ..data.dataflow import DataFlow | |||||
| if TYPE_CHECKING: | |||||
| from ..models.model import Model | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..configs.base_config import BaseConfig | |||||
class Evaluator(ABC):
    """ Base class of all evaluators.

    An evaluator runs the model over the batches produced by a data flow and
    keeps running summaries from which its evaluation metrics are computed.
    """

    def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'):
        """ Model is a predictor, data_loader contains information about the samples
        and how to iterate over them, conf provides the device to run on. """
        self.model = model
        self.data_loader = data_loader
        self.conf = conf
        self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device)

    def evaluate(self, max_iters: int = None):
        """ Runs one evaluation round.

        Puts the model in evaluation mode, resets the held summaries, then
        iterates over the data flow with gradients disabled, updating the
        summaries after every batch and showing the current metrics on the
        progress bar.

        :param max_iters: stop after this many batches; None means iterate
            until the data flow is exhausted. At least one batch is processed
            whenever the data flow yields one.
        """
        max_iters = float('inf') if max_iters is None else max_iters

        with torch.no_grad():
            # Running the model in evaluation mode
            self.model.eval()

            # initiating variables for running evaluations
            self.reset()

            with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar:
                for iters, model_output in pbar:
                    self.update_summaries_based_on_model_output(model_output)
                    # drop the reference so the batch output can be freed early
                    del model_output

                    pbar.set_description(self.get_evaluation_metrics())

                    if iters + 1 >= max_iters:
                        break

    @abstractmethod
    def update_summaries_based_on_model_output(
            self, model_output: Dict[str, torch.Tensor]) -> None:
        """ Updates the inner variables responsible of keeping a summary over data
        based on the new outputs of the model, so when needed evaluation metrics
        can be calculated based on these summaries. Mostly used in train and validation phase
        or eval phase in which we only need evaluation metrics."""

    @abstractmethod
    def reset(self):
        """ Resets the held information for a new evaluation round!"""

    @abstractmethod
    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """ Returns a list of strings containing the titles of the evaluation metrics"""

    @abstractmethod
    def get_values_of_evaluation_metrics(self) -> List[str]:
        """ Returns a list of values for the calculated evaluation metrics,
        converted to string with the desired format!"""

    def get_evaluation_metrics(self) -> str:
        """ Returns the metrics as one 'title: value, title: value, ...' string"""
        return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in
                          zip(
                              self.get_titles_of_evaluation_metrics(),
                              self.get_values_of_evaluation_metrics())])

    def print_evaluation_metrics(self, title: str) -> None:
        """ Prints the values of the evaluation metrics"""
        print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True)
| from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict | |||||
| from collections import OrderedDict | |||||
| import torch | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
class LossEvaluator(Evaluator):
    """Tracks sample-weighted running means of the model's losses.

    The entry stored under ``'loss'`` is the main loss. Every other key of the
    model output whose name contains ``'loss'`` is averaged separately and
    reported under its own name.
    """

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super(LossEvaluator, self).__init__(model, data_loader, conf)
        self.phase = conf.phase_type

        # for summarized information
        self.avg_loss: float = 0
        self._avg_other_losses: OrdDict[str, float] = OrderedDict()
        self.n_received_samples: int = 0
        self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict()

    def reset(self):
        """Clear all running means and sample counters."""
        self.avg_loss = 0
        self._avg_other_losses = OrderedDict()
        self.n_received_samples = 0
        self._n_received_samples_other_losses = OrderedDict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        """Fold the losses of the current batch into the running means.

        A missing ``'loss'`` entry counts as 0.0 for this batch.
        """
        n_batch = self.data_loader.get_current_batch_size()
        if n_batch == 0:
            # an empty batch carries no weight and would give 0/0 on the first update
            return

        seen = self.n_received_samples
        total = seen + n_batch
        # float(...) stores a plain number instead of a graph-attached tensor
        self.avg_loss = self.avg_loss * (float(seen) / total) + \
            float(model_output.get('loss', 0.0)) * (float(n_batch) / total)
        # Without this the counter stays 0 and "Loss" is just the last batch's loss.
        self.n_received_samples = total

        for kw in model_output.keys():
            if 'loss' in kw and kw != 'loss':
                old_avg = self._avg_other_losses.get(kw, 0.0)
                old_n = self._n_received_samples_other_losses.get(kw, 0)
                kw_total = old_n + n_batch
                self._avg_other_losses[kw] = \
                    old_avg * (float(old_n) / kw_total) + \
                    float(model_output[kw].detach()) * (float(n_batch) / kw_total)
                self._n_received_samples_other_losses[kw] = kw_total

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return ['Loss'] + list(self._avg_other_losses.keys())

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()]
| from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional | |||||
| from functools import partial | |||||
| import numpy as np | |||||
| import torch | |||||
| from sklearn.metrics import confusion_matrix | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
| def _precision_score(cm: np.ndarray): | |||||
| numinator = np.diag(cm) | |||||
| denominator = cm.sum(axis=0) | |||||
| return (numinator[denominator != 0] / denominator[denominator != 0]).mean() | |||||
| def _recall_score(cm: np.ndarray): | |||||
| numinator = np.diag(cm) | |||||
| denominator = cm.sum(axis=1) | |||||
| return (numinator[denominator != 0] / denominator[denominator != 0]).mean() | |||||
| def _accuracy_score(cm: np.ndarray): | |||||
| return np.diag(cm).sum() / cm.sum() | |||||
class MulticlassEvaluator(Evaluator):
    """Evaluator for multi-class classifiers.

    Accumulates a confusion matrix, per-class hit/miss counts, an optional
    top-5 hit count and a sample-weighted running mean of the loss. Class
    scores are read from ``model_output[class_probability_key]`` and the
    predicted class is their argmax along axis 1.
    """

    @classmethod
    def standard_creator(cls, class_probability_key: str = 'categorical_probability',
                         include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']:
        """Return a ``(model, data_loader, conf)`` constructor with the options pre-bound.

        Binding ``cls`` (not the base class) lets subclasses reuse this factory.
        """
        return partial(cls, class_probability_key=class_probability_key, include_top5=include_top5)

    def __init__(self, model: Model,
                 data_loader: DataLoader,
                 conf: 'BaseConfig',
                 class_probability_key: str = 'categorical_probability',
                 include_top5: bool = False):
        super().__init__(model, data_loader, conf)

        # for summarized information
        self._avg_loss: float = 0
        self._n_received_samples: int = 0
        self._trues = {}
        self._falses = {}
        self._top5_trues: int = 0
        self._top5_falses: int = 0
        self._num_iters = 0
        self._predictions: Dict[int, np.ndarray] = {}
        self._prediction_weights: Dict[int, np.ndarray] = {}
        self._sample_details: Dict[str, Dict[str, Any]] = {}
        self._cm: Optional[np.ndarray] = None
        self._c = None
        self._class_probability_key = class_probability_key
        self._include_top5 = include_top5

    def reset(self):
        """Clear the running loss, the counters and the confusion matrix."""
        # for summarized information
        self._avg_loss = 0
        self._n_received_samples = 0
        self._trues = {}
        self._falses = {}
        self._top5_trues = 0
        self._top5_falses = 0
        self._num_iters = 0
        self._cm = None

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        """Fold one batch of model output into the running summaries.

        A missing ``'loss'`` entry counts as 0.0 for this batch.
        """
        assert self._class_probability_key in model_output, \
            f"model's output dictionary must contain {self._class_probability_key}"

        # detach() keeps this usable on tensors that still require grad
        prediction: np.ndarray = model_output[self._class_probability_key]\
            .detach().cpu().numpy()  # B C
        c = prediction.shape[1]
        ground_truth = self.data_loader.get_current_batch_samples_labels()  # B

        if self._include_top5:
            top5_p = prediction.argsort(axis=1)[:, -5:]
            trues = int((top5_p == ground_truth[:, None]).any(axis=1).sum())
            self._top5_trues += trues
            self._top5_falses += len(ground_truth) - trues

        prediction = prediction.argmax(axis=1)  # B

        for i in range(c):
            nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i)))
            nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i)))
            self._trues[i] = self._trues.get(i, 0) + nt
            self._falses[i] = self._falses.get(i, 0) + nf

        self._cm = (0 if self._cm is None else self._cm)\
            + confusion_matrix(ground_truth, prediction,
                               labels=np.arange(c)).astype(float)
        self._num_iters += 1

        n_batch = len(ground_truth)
        new_n = self._n_received_samples + n_batch
        if new_n == 0:
            # nothing received yet (empty first batch): avoid a 0/0 below
            return

        # float(...) stores a plain number instead of a graph-attached tensor
        batch_loss = float(model_output.get('loss', 0.0))
        self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \
            batch_loss * (float(n_batch) / new_n)
        self._n_received_samples = new_n

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \
            (['Top5Acc'] if self._include_top5 else [])

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return the formatted metrics; scores are percentages.

        Before any batch has been received the scores are reported as ``nan``
        instead of raising.
        """
        if self._cm is None:
            accuracy = precision = recall = float('nan')
        else:
            accuracy = _accuracy_score(self._cm) * 100
            precision = _precision_score(self._cm) * 100
            recall = _recall_score(self._cm) * 100

        values = [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}']

        if self._include_top5:
            n_top5 = self._top5_trues + self._top5_falses
            top5_acc = self._top5_trues * 100.0 / n_top5 if n_top5 > 0 else float('nan')
            values.append(f'{top5_acc:8.4f}')

        return values
| from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type | |||||
| from collections import OrderedDict as ODict | |||||
| import torch | |||||
| from ..data.data_loader import DataLoader | |||||
| from ..models.model import Model | |||||
| from ..model_evaluation.evaluator import Evaluator | |||||
| if TYPE_CHECKING: | |||||
| from ..configs.base_config import BaseConfig | |||||
| # If you want your model to be analyzed from different points of view by different evaluators, you can use this holder! | |||||
class MultiEvaluatorEvaluator(Evaluator):
    """Composite evaluator that forwards every call to several named sub-evaluators."""

    def __init__(
            self, model: Model, data_loader: DataLoader, conf: 'BaseConfig',
            evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """
        evaluators_cls_by_name: maps an arbitrary instance name to a constructor
        that is called with (model, data_loader, conf).
        """
        super().__init__(model, data_loader, conf)

        # Instantiate the sub-evaluators in the order they were given.
        self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict(
            (name, make_evaluator(model, data_loader, conf))
            for name, make_evaluator in evaluators_cls_by_name.items())

    def reset(self):
        for sub_evaluator in self._evaluators_by_name.values():
            sub_evaluator.reset()

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """Return every sub-evaluator's titles, each prefixed with '<name>_'."""
        return [
            name + '_' + title
            for name, sub_evaluator in self._evaluators_by_name.items()
            for title in sub_evaluator.get_titles_of_evaluation_metrics()]

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return every sub-evaluator's values, concatenated in evaluator order."""
        return [
            value
            for sub_evaluator in self._evaluators_by_name.values()
            for value in sub_evaluator.get_values_of_evaluation_metrics()]

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        for sub_evaluator in self._evaluators_by_name.values():
            sub_evaluator.update_summaries_based_on_model_output(model_output)

    @staticmethod
    def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """ Builds a (model, data_loader, conf) constructor, consistent with the known Evaluator"""

        def typical_maker(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator:
            return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name)

        return typical_maker

    def print_evaluation_metrics(self, title: str) -> None:
        """ Prints the title, then one tab-indented report per sub-evaluator. """
        print(f'{title}:')
        for name, sub_evaluator in self._evaluators_by_name.items():
            sub_evaluator.print_evaluation_metrics('\t' + name)
| import typing | |||||
| from typing import Dict, Iterable | |||||
| from collections import OrderedDict | |||||
| import torch | |||||
| from torch.nn import functional as F | |||||
| import torchvision | |||||
| from ..lap_inception import LAPInception | |||||
class CelebALAPInception(LAPInception):
    """Single-output LAP Inception whose ground truth is passed under the name ``tag``."""

    def __init__(self, tag:str, aux_weight: float, pool_factory, adaptive_pool_factory):
        super().__init__(
            aux_weight,
            n_classes=1,
            pool_factory=pool_factory,
            adaptive_pool_factory=adaptive_pool_factory)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Maps the names of the additional `kwargs` to their optionality """
        return OrderedDict([(f'{self._tag}', True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the flattened network output and, when the tag is given, the loss.

        The loss is the mean over the distinct label values in the batch of the
        binary cross-entropy computed on the samples carrying that value.
        """
        # x: B 3 224 224
        if self.training:
            out, aux = torchvision.models.Inception3.forward(self, x)  # B 1
            out, aux = out.flatten(), aux.flatten()  # B
        else:
            out = torchvision.models.Inception3.forward(self, x).flatten()  # B
            aux = None

        output = {'positive_class_probability': out}

        gt_name = f'{self._tag}'
        if gt_name in gts:
            gt = gts[gt_name]

            # Class weighted loss
            per_class_losses = tuple(
                F.binary_cross_entropy(out[gt == label], gt[gt == label])
                for label in gt.unique())
            output['loss'] = torch.mean(torch.stack(per_class_losses))

        return output

    # INTERPRETATION

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        :return: names of the inputs to be interpreted
        """
        return ['x']
| from collections import OrderedDict | |||||
| from typing import Dict, List | |||||
| import typing | |||||
| import torch | |||||
| from ...modules import LAP, AdaptiveLAP | |||||
| from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||||
def pool_factory(channels, sigmoid_scale=1.0):
    """Build a LAP for ``channels`` with 3 attention heads and discriminative attention."""
    lap_options = dict(
        sigmoid_scale=sigmoid_scale,
        n_attention_heads=3,
        discriminative_attention=True)
    return LAP(channels, **lap_options)
def adaptive_pool_factory(channels, sigmoid_scale=1.0):
    """Build an AdaptiveLAP for ``channels`` with 3 attention heads and discriminative attention."""
    lap_options = dict(
        sigmoid_scale=sigmoid_scale,
        n_attention_heads=3,
        discriminative_attention=True)
    return AdaptiveLAP(channels, **lap_options)
class CelebALAPResNet18(LapResNet):
    """Binary LapResNet with [2, 2, 2, 2] basic blocks; its ground truth is passed as ``tag``."""

    def __init__(self, tag: str,
                 pool_factory: PoolFactory = pool_factory,
                 adaptive_factory: PoolFactory = adaptive_pool_factory,
                 sigmoid_scale: float = 1.0):
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(
            LapBasicBlock, blocks_per_stage,
            pool_factory=pool_factory,
            adaptive_factory=adaptive_factory,
            sigmoid_scale=sigmoid_scale,
            binary=True)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Maps the names of the additional `kwargs` to their optionality """
        return OrderedDict([(f'{self._tag}', True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Look up the ground truth under the tag name and delegate to the parent forward.

        Raises KeyError when the tag is not among the keyword arguments.
        """
        return super().forward(x, gts[f'{self._tag}'])

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """The parent's attention layers with the '4_overall' entry removed."""
        layers = super().attention_layers
        layers.pop('4_overall')
        return layers
| from collections import OrderedDict | |||||
| from typing import Dict | |||||
| import typing | |||||
| import torch | |||||
| from ..tv_resnet import BasicBlock, ResNet | |||||
class CelebAORGResNet18(ResNet):
    """Binary ResNet with [2, 2, 2, 2] basic blocks; its ground truth is passed as ``tag``."""

    def __init__(self, tag: str):
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(BasicBlock, blocks_per_stage, binary=True)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Maps the names of the additional `kwargs` to their optionality """
        return OrderedDict([(f'{self._tag}', True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Look up the ground truth under the tag name and delegate to the parent forward.

        Raises KeyError when the tag is not among the keyword arguments.
        """
        return super().forward(x, gts[f'{self._tag}'])
| import typing | |||||
| from typing import Dict, List, Iterable | |||||
| from collections import OrderedDict | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.nn import functional as F | |||||
| import torchvision | |||||
| from ..tv_inception import Inception3 | |||||
class CelebAORGInception(Inception3):
    """Single-output Inception3 whose ground truth is passed under the name ``tag``."""

    def __init__(self, tag: str, aux_weight: float):
        super().__init__(aux_weight, n_classes=1)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Maps the names of the additional `kwargs` to their optionality """
        return OrderedDict([(f'{self._tag}', True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the flattened network output and, when the tag is given, the loss.

        The loss is the mean over the distinct label values in the batch of the
        binary cross-entropy computed on the samples carrying that value.
        """
        if self.training:
            out, aux = torchvision.models.Inception3.forward(self, x)  # B 1
            out, aux = out.flatten(), aux.flatten()  # B
        else:
            out = torchvision.models.Inception3.forward(self, x).flatten()  # B
            aux = None

        output = {'positive_class_probability': out}

        gt_name = f'{self._tag}'
        if gt_name in gts:
            gt = gts[gt_name]

            # Class weighted loss
            per_class_losses = tuple(
                F.binary_cross_entropy(out[gt == label], gt[gt == label])
                for label in gt.unique())
            output['loss'] = torch.mean(torch.stack(per_class_losses))

        return output

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """The branch modules of ``Mixed_7c``, listed in a fixed order."""
        last_block = self.Mixed_7c
        return [
            last_block.branch1x1,
            last_block.branch3x3_2a, last_block.branch3x3_2b,
            last_block.branch3x3dbl_3a, last_block.branch3x3dbl_3b,
            last_block.branch_pool,
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """Run forward and stack ``1 - p`` and ``p`` along dim 1, where p is the network output."""
        positive = self.forward(*inputs, **kwargs)['positive_class_probability']
        negative = 1 - positive
        return torch.stack([negative, positive], dim=1)
| from typing import Dict, List | |||||
| import torch | |||||
| from torchvision import transforms | |||||
| from ...modules import LAP, AdaptiveLAP, ScaledSigmoid | |||||
| from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory | |||||
def pool_factory(channels, sigmoid_scale=1.0):
    """Build a LAP for ``channels`` with one hidden layer of 1000 channels.

    The hidden activation comes from ``ScaledSigmoid.get_factory`` with the same scale.
    """
    hidden_activation = ScaledSigmoid.get_factory(sigmoid_scale)
    return LAP(
        channels,
        hidden_channels=[1000],
        sigmoid_scale=sigmoid_scale,
        hidden_activation=hidden_activation)
class ImagenetLAPResNet50(LapResNet):
    """LapResNet with [3, 4, 6, 3] bottleneck blocks that normalizes its input in forward."""

    def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0):
        blocks_per_stage = [3, 4, 6, 3]
        super().__init__(
            LapBottleneck, blocks_per_stage,
            pool_factory=pool_factory,
            sigmoid_scale=sigmoid_scale,
            lap_positions=[4])
        # per-channel mean and std applied to the input before the backbone
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)
        self.normalize = transforms.Normalize(channel_mean, channel_std)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Normalize ``x`` and delegate to the parent forward with ``y``."""
        return super().forward(self.normalize(x), y)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Only the parent's '2_layer4' attention layers."""
        kept = '2_layer4'
        return {kept: super().attention_layers[kept]}
| from typing import List, Optional, Callable, Any, Iterable, Dict | |||||
| import torch | |||||
| from torch import Tensor, nn | |||||
| import torchvision | |||||
| from ..modules.lap import LAP | |||||
| from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||||
| from ..interpreting.interpretable import CamInterpretableModel | |||||
| from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
| from ..interpreting.relcam import modules as M | |||||
class InceptionB(nn.Module):
    """Inception block with two stride-1 conv branches and an identity branch.

    The three branch outputs are concatenated along dim 1 and passed through the
    module built by ``pool_factory``.
    """

    def __init__(
        self,
        in_channels: int,
        pool_factory,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        conv_block = BasicConv2d if conv_block is None else conv_block

        # Both 3x3 convolutions below use stride 1 with padding 1 (an earlier
        # variant, previously kept here as commented-out code, used stride 2).
        # NOTE(review): presumably the spatial reduction is left to self.pool - confirm.
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=1, padding=1)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=1, padding=1)

        # channel count of the concatenation: 384 + 96 + the untouched input
        self.pool = pool_factory(384 + 96 + in_channels)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return [3x3 branch, double-3x3 branch, input] without concatenating them."""
        single = self.branch3x3(x)

        double = x
        for stage in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
            double = stage(double)

        return [single, double, x]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        merged = self.cat(branches, 1)
        return self.pool(merged)
@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):
    """Relevance propagation for ``InceptionB``: walks the forward graph backwards."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` through pool, cat and both conv branches; sum the input relevances."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        r_single, r_double, r_identity = back(block.cat, back(block.pool, R))

        for layer in (block.branch3x3dbl_3, block.branch3x3dbl_2, block.branch3x3dbl_1):
            r_double = back(layer, r_double)
        r_single = back(block.branch3x3, r_single)

        return r_identity + r_double + r_single
class InceptionD(nn.Module):
    """Inception-D block whose output goes through a ``pool_factory`` layer.

    The closing 3x3 convolution of each branch uses ``stride=1, padding=1``.
    The two conv branches and the unmodified input are concatenated on
    dim 1 and handed to ``self.pool``.
    """

    def __init__(
        self,
        in_channels: int,
        pool_factory,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # 1x1 -> 3x3 branch.
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=1, padding=1)

        # 1x1 -> 1x7 -> 7x1 -> 3x3 branch.
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=1, padding=1)

        # The pool sees both branches plus the raw input.
        self.pool = pool_factory(320 + 192 + in_channels)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return ``[branch3x3, branch7x7x3, x]`` before concatenation."""
        short = self.branch3x3_2(self.branch3x3_1(x))

        long = self.branch7x7x3_1(x)
        for stage in (self.branch7x7x3_2, self.branch7x7x3_3, self.branch7x7x3_4):
            long = stage(long)

        return [short, long, x]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs on dim 1 and pool the result."""
        merged = self.cat(self._forward(x), 1)
        return self.pool(merged)
@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):
    """Relevance propagation for ``InceptionD``: walks the forward graph backwards."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` through pool, cat and both conv branches; sum the input relevances."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        r_short, r_long, r_identity = back(block.cat, back(block.pool, R))

        for layer in (block.branch7x7x3_4, block.branch7x7x3_3,
                      block.branch7x7x3_2, block.branch7x7x3_1):
            r_long = back(layer, r_long)

        for layer in (block.branch3x3_2, block.branch3x3_1):
            r_short = back(layer, r_short)

        return r_identity + r_long + r_short
def inception_b_maker(pool_factory):
    """Return an ``InceptionB`` constructor with ``pool_factory`` already bound.

    The returned callable takes ``(in_channels, conv_block=None)``.
    """
    def make_block(in_channels, conv_block=None):
        return InceptionB(in_channels, pool_factory, conv_block)

    return make_block
def inception_d_maker(pool_factory):
    """Return an ``InceptionD`` constructor with ``pool_factory`` already bound.

    The returned callable takes ``(in_channels, conv_block=None)``.
    """
    def make_block(in_channels, conv_block=None):
        return InceptionD(in_channels, pool_factory, conv_block)

    return make_block
class InceptionA(nn.Module):
    """Inception-A block: four parallel branches concatenated on dim 1.

    Branches: 1x1; 1x1 -> 5x5; 1x1 -> 3x3 -> 3x3; 3x3 average pool -> 1x1
    (with ``pool_features`` output channels).
    """

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return ``[branch1x1, branch5x5, branch3x3dbl, branch_pool]``."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_dbl = self.branch3x3dbl_1(x)
        for stage in (self.branch3x3dbl_2, self.branch3x3dbl_3):
            out_dbl = stage(out_dbl)

        out_pool = self.branch_pool(self.avg_pool(x))
        return [out_1x1, out_5x5, out_dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the four branch outputs on dim 1."""
        return self.cat(self._forward(x), 1)
@RPProvider.register(InceptionA)
class InceptionARelProp(RelProp[InceptionA]):
    """Relevance propagation for ``InceptionA``: walks the four branches backwards."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Split ``R`` at the concatenation, propagate each branch, sum the input relevances."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        r_1x1, r_5x5, r_dbl, r_pool = back(block.cat, R)

        r_pool = back(block.avg_pool, back(block.branch_pool, r_pool))

        for layer in (block.branch3x3dbl_3, block.branch3x3dbl_2, block.branch3x3dbl_1):
            r_dbl = back(layer, r_dbl)

        r_5x5 = back(block.branch5x5_1, back(block.branch5x5_2, r_5x5))
        r_1x1 = back(block.branch1x1, r_1x1)

        return r_pool + r_dbl + r_5x5 + r_1x1
class InceptionC(nn.Module):
    """Inception-C block: four parallel branches concatenated on dim 1.

    Branches: 1x1; a factorized 7x7 (1x1 -> 1x7 -> 7x1); a doubly factorized
    7x7 (1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7); 3x3 average pool -> 1x1. Each
    branch ends with 192 output channels.
    """

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        mid = channels_7x7  # width of the intermediate 7x7 layers

        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

        self.branch7x7_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make_conv(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make_conv(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return ``[branch1x1, branch7x7, branch7x7dbl, branch_pool]``."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_1(x)
        for stage in (self.branch7x7_2, self.branch7x7_3):
            out_7x7 = stage(out_7x7)

        out_dbl = self.branch7x7dbl_1(x)
        for stage in (self.branch7x7dbl_2, self.branch7x7dbl_3,
                      self.branch7x7dbl_4, self.branch7x7dbl_5):
            out_dbl = stage(out_dbl)

        out_pool = self.branch_pool(self.avg_pool(x))
        return [out_1x1, out_7x7, out_dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the four branch outputs on dim 1."""
        return self.cat(self._forward(x), 1)
@RPProvider.register(InceptionC)
class InceptionCRelProp(RelProp[InceptionC]):
    """Relevance propagation for ``InceptionC``: walks the four branches backwards."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Split ``R`` at the concatenation, propagate each branch, sum the input relevances."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        r_1x1, r_7x7, r_dbl, r_pool = back(block.cat, R)

        r_pool = back(block.avg_pool, back(block.branch_pool, r_pool))

        for layer in (block.branch7x7dbl_5, block.branch7x7dbl_4, block.branch7x7dbl_3,
                      block.branch7x7dbl_2, block.branch7x7dbl_1):
            r_dbl = back(layer, r_dbl)

        for layer in (block.branch7x7_3, block.branch7x7_2, block.branch7x7_1):
            r_7x7 = back(layer, r_7x7)

        r_1x1 = back(block.branch1x1, r_1x1)

        return r_pool + r_dbl + r_7x7 + r_1x1
class InceptionE(nn.Module):
    """Inception-E block: four branches, two of which fan out into 1x3 / 3x1 pairs.

    ``cat1`` and ``cat2`` merge the 1x3 / 3x1 pairs inside the 3x3 and the
    double-3x3 branches; ``cat3`` merges the four branch outputs on dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

        self.cat1 = M.Cat()
        self.cat2 = M.Cat()
        self.cat3 = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return ``[branch1x1, branch3x3, branch3x3dbl, branch_pool]``."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        pair_3x3 = [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)]
        out_3x3 = self.cat1(pair_3x3, 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        pair_dbl = [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)]
        out_dbl = self.cat2(pair_dbl, 1)

        out_pool = self.branch_pool(self.avg_pool(x))
        return [out_1x1, out_3x3, out_dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the four branch outputs on dim 1."""
        return self.cat3(self._forward(x), 1)
@RPProvider.register(InceptionE)
class InceptionERelProp(RelProp[InceptionE]):
    """Relevance propagation for ``InceptionE``, including its nested concatenations."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Split ``R`` at ``cat3``, propagate each branch, sum the input relevances."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        r_1x1, r_3x3, r_dbl, r_pool = back(block.cat3, R)

        r_pool = back(block.avg_pool, back(block.branch_pool, r_pool))

        # Double-3x3 branch: undo cat2, merge the 1x3 / 3x1 pair, then the stem.
        r_dbl_a, r_dbl_b = back(block.cat2, r_dbl)
        r_dbl = back(block.branch3x3dbl_3a, r_dbl_a) + back(block.branch3x3dbl_3b, r_dbl_b)
        r_dbl = back(block.branch3x3dbl_1, back(block.branch3x3dbl_2, r_dbl))

        # 3x3 branch: undo cat1, merge the 1x3 / 3x1 pair, then the stem.
        r_3x3_a, r_3x3_b = back(block.cat1, r_3x3)
        r_3x3 = back(block.branch3x3_2a, r_3x3_a) + back(block.branch3x3_2b, r_3x3_b)
        r_3x3 = back(block.branch3x3_1, r_3x3)

        r_1x1 = back(block.branch1x1, r_1x1)

        return r_pool + r_dbl + r_3x3 + r_1x1
class InceptionAux(nn.Module):
    """Auxiliary classifier head: avg-pool, two conv blocks, global pool, linear.

    ``stddev`` attributes are attached to ``conv1`` and ``fc``; nothing in
    this class reads them.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv0 = make_conv(in_channels, 128, kernel_size=1)
        self.conv1 = make_conv(128, 768, kernel_size=5, padding=2)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        """Run the head; the result has ``num_classes`` features per sample."""
        stages = (
            self.avgpool1,  # 5x5 average pooling, stride 3
            self.conv0,     # 1x1 conv -> 128 channels
            self.conv1,     # 5x5 conv (padding 2) -> 768 channels
            self.avgpool2,  # adaptive average pooling to 1x1
            self.flatten,   # N x 768
            self.fc,        # N x num_classes
        )
        for stage in stages:
            x = stage(x)
        return x
@RPProvider.register(InceptionAux)
class InceptionAuxRelProp(RelProp[InceptionAux]):
    """Relevance propagation for ``InceptionAux``: its layers in reverse forward order."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` from the linear layer back to the head's input."""
        head = self.module
        reversed_layers = (
            head.fc, head.flatten, head.avgpool2,
            head.conv1, head.conv0, head.avgpool1,
        )
        for layer in reversed_layers:
            R = RPProvider.get(layer)(R, alpha=alpha)
        return R
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d(eps=0.001)`` and ``ReLU``.

    Extra keyword arguments (kernel size, stride, padding, ...) are passed
    straight to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.act = torch.nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> batch norm -> ReLU."""
        normalized = self.bn(self.conv(x))
        return self.act(normalized)
@RPProvider.register(BasicConv2d)
class BasicConv2dRelProp(RelProp[BasicConv2d]):
    """Relevance propagation for ``BasicConv2d``: activation, batch norm, then conv."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` back through act -> bn -> conv."""
        unit = self.module
        for layer in (unit.act, unit.bn, unit.conv):
            R = RPProvider.get(layer)(R, alpha=alpha)
        return R
class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3):
    """Inception-v3 with ``pool_factory`` layers and a single sigmoid output.

    The stem convolutions and both stem pools are replaced after the
    torchvision constructor runs, the custom Inception blocks defined in this
    module are injected through ``inception_blocks``, and both classifier
    heads (``fc`` and ``AuxLogits.fc``) become ``Linear(..., 1) + Sigmoid``.
    """

    def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory):
        """Build the network.

        Args:
            aux_weight: stored on ``self.aux_weight``; not used in this class.
            n_classes: accepted but not used here (both heads output 1 value).
            pool_factory: callable ``channels -> module`` used for the two stem
                pools and inside ``InceptionB`` / ``InceptionD``.
            adaptive_pool_factory: optional callable ``channels -> module``;
                when given, ``self.avgpool`` becomes ``Sequential(factory(2048))``.
        """
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False, init_weights=False,
            inception_blocks=[
                BasicConv2d, InceptionA, inception_b_maker(pool_factory),
                InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux
            ])

        # Stem: padded convolutions, with pool_factory layers as the two pools.
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = pool_factory(64)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1,)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1)
        self.maxpool2 = pool_factory(192)

        if adaptive_pool_factory is not None:
            self.avgpool = nn.Sequential(
                adaptive_pool_factory(2048))

        # Single-probability heads.
        self.fc = nn.Sequential(
            nn.Linear(2048, 1),
            nn.Sigmoid()
        )
        self.AuxLogits.fc = nn.Sequential(
            nn.Linear(768, 1),
            nn.Sigmoid()
        )
        self.aux_weight = aux_weight

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """Final conv blocks of every ``Mixed_7c`` branch."""
        return [
            self.Mixed_7c.branch1x1,
            self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
            self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
            self.Mixed_7c.branch_pool
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """Names of the inputs to interpret: only ``'x'``."""
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """Stack ``[1 - p, p]`` along dim 1, where ``p`` is the
        ``'positive_class_probability'`` entry returned by ``self.forward``."""
        p = self.forward(*inputs, **kwargs)['positive_class_probability']
        return torch.stack([1 - p, p], dim=1)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Attention groups keyed by name; ``'4_all'`` lists every grouped layer.

        ``self.avgpool`` is only a ``Sequential`` wrapping the adaptive pool
        when ``adaptive_pool_factory`` was given. Otherwise it is the plain
        torchvision pooling module, which cannot be indexed, so the
        ``'3_avgpool'`` group is omitted instead of raising ``TypeError``.
        """
        every_layer = [
            self.maxpool2,
            self.Mixed_6a.pool,
            self.Mixed_7a.pool,
        ]
        attention_groups = {
            '0_layer2': [self.maxpool2],
            '1_layer6': [self.Mixed_6a.pool],
            '2_layer7': [self.Mixed_7a.pool],
        }
        if isinstance(self.avgpool, nn.Sequential):
            adaptive_pool = self.avgpool[0]
            attention_groups['3_avgpool'] = [adaptive_pool]
            every_layer.append(adaptive_pool)
        attention_groups['4_all'] = every_layer
        return attention_groups
@RPProvider.register(LAPInception)
class Inception3Mo4RelProp(RelProp[LAPInception]):
    """Relevance propagation for ``LAPInception``, from the classifier back to the input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` through every layer of the network in reverse order."""
        net = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        # Single-output head: keep only the last relevance column.
        if RPProvider.get(net.fc).Y.shape[1] == 1:
            R = R[:, -1:]
        R = back(net.fc, R)

        # Restore the shape the dropout layer produced before it was flattened.
        R = R.reshape_as(RPProvider.get(net.dropout).Y)

        reversed_layers = (
            net.dropout, net.avgpool,
            net.Mixed_7c, net.Mixed_7b, net.Mixed_7a,
            net.Mixed_6e, net.Mixed_6d, net.Mixed_6c, net.Mixed_6b, net.Mixed_6a,
            net.Mixed_5d, net.Mixed_5c, net.Mixed_5b,
            net.maxpool2, net.Conv2d_4a_3x3, net.Conv2d_3b_1x1,
            net.maxpool1, net.Conv2d_2b_3x3, net.Conv2d_2a_3x3, net.Conv2d_1a_3x3,
        )
        for layer in reversed_layers:
            R = back(layer, R)
        return R
| from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar | |||||
| from typing_extensions import Protocol | |||||
| import warnings | |||||
| import torch | |||||
| from torch import nn | |||||
| from ..modules import LAP, AdaptiveLAP | |||||
| from .tv_resnet import BasicBlock, Bottleneck, ResNet | |||||
| from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple | |||||
| from ..interpreting.relcam import modules as M | |||||
| from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||||
class PoolFactory(Protocol):
    """Structural type of a callable that builds a pooling layer.

    Implementations take the channel count and an optional ``sigmoid_scale``
    (default ``1.0``) and are annotated as returning a ``LAP``.
    """

    def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ...
def lap_factory(channels: int, sigmoid_scale: float = 1.0) -> LAP:
    """Default pool factory: a ``LAP`` over ``channels`` with the given ``sigmoid_scale``."""
    return LAP(channels, sigmoid_scale=sigmoid_scale)
def adaptive_lap_factory(channels: int, sigmoid_scale: float = 1.0) -> AdaptiveLAP:
    """Pool factory producing an ``AdaptiveLAP`` over ``channels`` with the given ``sigmoid_scale``."""
    return AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale)
class LapBasicBlock(BasicBlock):
    """``BasicBlock`` that pools with a ``pool_factory`` layer when ``stride != 1``.

    For ``stride != 1`` the first convolution and the first module of the
    shortcut are rebuilt without a stride argument, and one pooling layer is
    applied to the concatenation of the main path and the shortcut.
    """

    def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0):
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            # Replace the convolutions built by the parent with unstrided ones
            # (nn.Conv2d defaults to stride 1).
            self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False)
            # One pool for main path + shortcut, hence 2 * planes channels.
            self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale)
        self.relu3 = nn.ReLU()
        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block; shape comments use B=batch, P=planes, S=stride."""
        out = self.conv1(x)                                 # B  P  H   W
        out = self.bn1(out)                                 # B  P  H   W
        out = self.relu(out)                                # B  P  H   W
        if self.downsample is not None:
            x = self.downsample(x)                          # B  P  H   W
            # NOTE(review): relu2 is taken to belong to the downsample path,
            # mirroring BlockRelProp.rel - confirm.
            x = self.relu2(x)
        if self.pool is not None:
            # Pool both paths jointly, then split them apart again.
            poolin = self.cat([out, x], 1)                  # B 2P  H   W
            poolout: torch.Tensor = self.pool(poolin)       # B 2P H/S W/S
            out, x = poolout.chunk(2, dim=1)                # B  P H/S W/S
        out = self.conv2(out)                               # B  P H/S W/S
        out = self.bn2(out)                                 # B  P H/S W/S
        out = self.add([out, x])                            # B  P H/S W/S
        out = self.relu3(out)                               # B  P H/S W/S
        return out
@RPProvider.register(LapBasicBlock)
class BlockRelProp(RelProp[LapBasicBlock]):
    """Relevance propagation for ``LapBasicBlock``: its forward pass in reverse."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` through both paths and return the summed input relevance."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        main, shortcut = back(block.add, back(block.relu3, R))
        main = back(block.conv2, back(block.bn2, main))

        if block.pool is not None:
            # Forward pooled the two paths jointly; re-join them before undoing it.
            pooled = torch.cat([main, shortcut], dim=1)
            main, shortcut = back(block.cat, back(block.pool, pooled))

        if block.downsample is not None:
            shortcut = back(block.downsample, back(block.relu2, shortcut))

        for layer in (block.relu, block.bn1, block.conv1):
            main = back(layer, main)

        return shortcut + main
class LapBottleneck(Bottleneck):
    """``Bottleneck`` that pools with a ``pool_factory`` layer when ``stride != 1``.

    For ``stride != 1`` the 3x3 convolution and the first module of the
    shortcut are rebuilt without a stride argument, and one pooling layer is
    applied to the concatenation of the main path and the shortcut.
    """

    def __init__(self, pool_factory: PoolFactory,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 sigmoid_scale: float = 1.0) -> None:
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            width = int(planes * (base_width / 64.)) * groups
            # Replace the convolutions built by the parent with unstrided ones
            # (nn.Conv2d defaults to stride 1).
            self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False)
            # One pool for main path (planes) + shortcut (planes * expansion).
            self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale)
        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block; shape comments use B=batch, I=inplanes, P=planes, S=stride."""
        # x                                                 # B  I  H   W
        out = self.conv1(x)                                 # B  P  H   W
        out = self.bn1(out)                                 # B  P  H   W
        out = self.relu(out)                                # B  P  H   W
        out = self.conv2(out)                               # B  P  H   W
        out = self.bn2(out)                                 # B  P  H   W
        out = self.relu2(out)                               # B  P  H   W
        if self.downsample is not None:
            x = self.downsample(x)                          # B 4P  H   W
        if self.pool is not None:
            # Pool both paths jointly, then split by their original widths.
            poolin = self.cat([out, x], 1)                  # B 5P  H   W
            poolout: torch.Tensor = self.pool(poolin)       # B 5P H/S W/S
            out, x = poolout.split(                         # out: B  P H/S W/S
                [out.shape[1], x.shape[1]],                 # x:   B 4P H/S W/S
                dim=1)
        out = self.conv3(out)                               # B 4P H/S W/S
        out = self.bn3(out)                                 # B 4P H/S W/S
        out = self.add([out, x])                            # B 4P H/S W/S
        out = self.relu3(out)                               # B 4P H/S W/S
        return out
@RPProvider.register(LapBottleneck)
class BlockRP(RelProp[LapBottleneck]):
    """Relevance propagation for ``LapBottleneck``: its forward pass in reverse."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` through both paths and return the summed input relevance."""
        block = self.module

        def back(layer, relevance):
            # Relevance propagation of one sub-module, always with this alpha.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        main, shortcut = back(block.add, back(block.relu3, R))
        main = back(block.conv3, back(block.bn3, main))

        if block.pool is not None:
            # Forward pooled the two paths jointly; re-join them before undoing it.
            pooled = torch.cat([main, shortcut], dim=1)
            main, shortcut = back(block.cat, back(block.pool, pooled))

        if block.downsample is not None:
            shortcut = back(block.downsample, shortcut)

        for layer in (block.relu2, block.bn2, block.conv2,
                      block.relu, block.bn1, block.conv1):
            main = back(layer, main)

        return shortcut + main
# Type variable constrained to the two LAP residual block classes.
TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck)
class BlockFactory(Generic[TLapBlock]):
    """Callable that builds LAP blocks with a fixed pool factory and sigmoid scale.

    It exposes the wrapped block class's ``expansion`` attribute under the
    same name.
    """

    def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None:
        self._block = block
        self._pool_factory = pool_factory
        self._sigmoid_scale = sigmoid_scale
        self.expansion = block.expansion

    def __call__(self, *args, **kwargs) -> TLapBlock:
        """Instantiate the block: pool factory first, then the caller's
        arguments, with ``sigmoid_scale`` supplied by this factory."""
        block_cls = self._block
        return block_cls(
            self._pool_factory,
            *args,
            **kwargs,
            sigmoid_scale=self._sigmoid_scale,
        )
class Stride(nn.Module):
    """Parameter-free subsampling: keeps every second index of dims 2 and 3."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with dims 2 and 3 subsampled by a step of 2."""
        every_other = slice(None, None, 2)
        return x[:, :, every_other, every_other]
@RPProvider.register(Stride)
class StrideRelProp(RelPropSimple[Stride]):
    """Relevance propagation for ``Stride``; adds nothing to ``RelPropSimple``."""
    pass
class LapResNet(ResNet, AttentionInterpretableModel):
    """ResNet built from LAP blocks, optionally with an adaptive pool before the classifier."""

    def __init__(self,
                 block: Type[TLapBlock],
                 layers: List[int],
                 pool_factory: PoolFactory = lap_factory,
                 lap_positions: List[int] = [2, 3, 4],
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 binary: bool = False):
        """Build the network.

        Args:
            block: LAP block class (``LapBasicBlock`` or ``LapBottleneck``).
            layers: number of blocks per stage, forwarded to ``ResNet``.
            pool_factory: builds the pooling layer inside strided blocks.
            lap_positions: stages (2, 3, 4) that keep their pool; the first
                block of every other stage gets a plain ``Stride`` instead.
                The default list is only read, never mutated.
            adaptive_factory: when given, ``self.avgpool`` is replaced by a
                ``Sequential`` holding the layer it builds.
            sigmoid_scale: forwarded to every pool / adaptive factory call.
            binary: forwarded to ``ResNet``.
        """
        super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary)

        if adaptive_factory is not None:
            # The last stage has 512 * expansion channels (512 for basic
            # blocks, 2048 for bottlenecks); a fixed 512 only fits basic blocks.
            self.avgpool = nn.Sequential(
                adaptive_factory(512 * block.expansion, sigmoid_scale=sigmoid_scale),
            )

        for i in range(2, 5):
            if i not in lap_positions:
                warnings.warn(f'Putting stride on layer {i}')
                getattr(self, 'layer{}'.format(i))[0].pool = Stride()

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        List of attention groups

        Every referenced layer must be a ``LAP``; otherwise the assertion
        below fails. That happens for stages that received a ``Stride``.
        """
        attention_groups = {
            '0_layer2': [self.layer2[0].pool],
            '1_layer3': [self.layer3[0].pool],
            '2_layer4': [self.layer4[0].pool],
            '3_avgpool': [self.avgpool[0]],
            '4_overall': [
                self.layer2[0].pool,
                self.layer3[0].pool,
                self.layer4[0].pool,
                self.avgpool[0],
            ]
        }
        assert all(
            all(
                isinstance(layer, LAP)
                for layer in attention_group)
            for attention_group in attention_groups.values()), \
            "Only LAP is supported for this interpretation method"
        return attention_groups
def lap_resnet18(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: PoolFactory = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Construct a LAP-ResNet-18: ``LapBasicBlock`` with two blocks per stage.

    All keyword arguments are passed through to ``LapResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return LapResNet(
        LapBasicBlock,
        blocks_per_stage,
        pool_factory=pool_factory,
        adaptive_factory=adaptive_factory,
        sigmoid_scale=sigmoid_scale,
        lap_positions=lap_positions,
        binary=binary,
    )
def lap_resnet50(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: PoolFactory = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Construct a LAP-ResNet-50: ``LapBottleneck`` with stages of 3, 4, 6 and 3 blocks.

    All keyword arguments are passed through to ``LapResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return LapResNet(
        LapBottleneck,
        blocks_per_stage,
        pool_factory=pool_factory,
        adaptive_factory=adaptive_factory,
        sigmoid_scale=sigmoid_scale,
        lap_positions=lap_positions,
        binary=binary,
    )
| """ | |||||
| The Model base class | |||||
| """ | |||||
| from collections import OrderedDict | |||||
| import inspect | |||||
| import typing | |||||
| from typing import Any, Dict, List, Tuple | |||||
| import re | |||||
| from abc import ABC | |||||
| import torch | |||||
| from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d | |||||
# Type alias: a mapping from string names to tensors.
ModelIO = Dict[str, torch.Tensor]
class Model(torch.nn.Module, ABC):
    """
    The Model base class.

    On top of ``torch.nn.Module`` it provides:

    * best-effort weight initialization from a saved state dict,
    * regex-based parameter freezing that also keeps the BatchNorm layers
      owning frozen parameters in eval mode,
    * introspection of the arguments accepted by ``forward``.
    """

    # Class-level default so that `train()` works before `freeze_parameters`
    # has ever run; `freeze_parameters` rebinds this attribute on the instance,
    # so the shared list itself is never mutated.
    _frozen_bns_list: List[torch.nn.Module] = []

    def init_weights_from_other_model(self, pretrained_model_dir=None):
        """ Initializes the weights from a saved state dict at the given address.

        Every entry whose name exists in this model's state dict with the same
        shape is copied in place; all other entries are counted as skipped.
        Can be rewritten in subclasses.

        :param pretrained_model_dir: Anything accepted by ``torch.load``;
            None leaves the model untouched.
        :return: None
        """
        if pretrained_model_dir is None:
            print('The model was not preinitialized.', flush=True)
            return

        # NOTE(review): torch.load unpickles its input -- only load checkpoints
        # from trusted sources.
        # Loading onto the CPU first lets checkpoints saved on a GPU be read on
        # hosts without one; copy_ below moves the values to each target's device.
        other_model_dict = torch.load(pretrained_model_dir, map_location='cpu')
        own_state = self.state_dict()

        cnt = 0
        skip_cnt = 0

        for name, param in other_model_dict.items():
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            if name in own_state and own_state[name].data.shape == param.shape:
                own_state[name].copy_(param)
                cnt += 1
            else:
                skip_cnt += 1

        print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, {skip_cnt} of the given set were skipped.' , flush=True)

    def freeze_parameters(self, regex_list: List[str] = None) -> None:
        """
        An auxiliary function for freezing the parameters of the model
        whose full names match one of the given regexes.

        :param regex_list: All the parameters fully matching at least one of the
            given regexes would be frozen. None means no parameter would be frozen.
        :return: None
        """
        # One solution for batchnorms is registering a preforward hook
        # to call eval everytime before running them
        # But that causes a problem, we can't dynamically change the frozen batchnorms during training
        # e.g. in layerwise unfreezing.
        # So it's better to set a dynamic attribute that keeps their list, rather than having an init
        # Init causes other problems in multi-inheritance

        if regex_list is None:
            print('No parameter was freezeed!', flush=True)
            return

        compiled_patterns = [re.compile(the_pattern) for the_pattern in regex_list]

        def check_freeze_conditions(param_name):
            return any(the_regex.fullmatch(param_name) is not None
                       for the_regex in compiled_patterns)

        frozen_cnt = 0
        total_cnt = 0
        frozen_modules_names = set()

        for n, p in self.named_parameters():
            total_cnt += 1
            if check_freeze_conditions(n):
                p.requires_grad = False
                frozen_cnt += 1
                # The owning module's name is everything before the last dot.
                # rpartition yields '' (the root module's name) when there is
                # no dot, whereas slicing with rfind's -1 would drop the last
                # character of the parameter name.
                frozen_modules_names.add(n.rpartition('.')[0])

        # Rebind on the instance: the class-level default list stays untouched.
        self._frozen_bns_list = [
            module
            for module_name, module in self.named_modules()
            if module_name in frozen_modules_names
            and isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d))
        ]

        print('********************* NOTICE *********************')
        print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt))
        print(
            '%d BatchNorm layers will be frozen by being kept in the val mode, IF YOU HAVE NOT OVERRIDEN IT IN YOUR MAIN MODEL!' % len(
                self._frozen_bns_list))
        print('***************************************************', flush=True)

    def train(self, mode: bool = True):
        """Set the training mode, then put the recorded frozen BatchNorms back in eval mode.

        :param mode: Passed to ``torch.nn.Module.train``.
        :return: self
        """
        super(Model, self).train(mode)
        for bn in self._frozen_bns_list:
            bn.train(False)
        return self

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Returns a dictionary from additional `kwargs` names to their optionality (empty by default) """
        return OrderedDict({})

    def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]:
        """Returns the names of the keyword arguments required in the forward function and the list of the default values for the ones having them.
        The second list might be shorter than the first and it represents the default values from the end of the arguments.

        Returns:
            Tuple[List[str], List[Any], Dict[str, Any]]: The list of args names and default values for the ones having them (from some point to the end) + the dictionary of kwargs
            (``additional_kwargs`` when ``forward`` accepts ``**kwargs``, otherwise an empty dict)
        """
        model_argspec = inspect.getfullargspec(self.forward)

        # skipping self arg
        model_forward_args = list(model_argspec.args)[1:]

        # `defaults` is None when no positional argument has a default.
        args_default_values = list(model_argspec.defaults or ())

        additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {}

        return model_forward_args, args_default_values, additional_kwargs
| from typing import Dict | |||||
| import torch | |||||
| from torch.nn import functional as F | |||||
| import torchvision | |||||
| from ..lap_inception import LAPInception | |||||
class RSNALAPInception(LAPInception):
    """``LAPInception`` configured with a single output class (``n_classes=1``)."""

    def __init__(self, aux_weight: float, pool_factory, adaptive_pool_factory):
        super().__init__(
            aux_weight,
            n_classes=1,
            pool_factory=pool_factory,
            adaptive_pool_factory=adaptive_pool_factory,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Run the network and, when labels are given, compute the loss.

        Args:
            x: Input batch; every entry along dim 1 is repeated three times
                before the network runs (the original note gives the result
                as B x 3 x 224 x 224).
            y: Optional targets for binary cross entropy.

        Returns:
            ``{'positive_class_probability': out}`` plus ``'loss'`` when ``y``
            is given. In training mode the loss is
            ``(main + aux_weight * aux) / (1 + aux_weight)``; otherwise it is
            the main loss alone.
        """
        tripled = x.repeat_interleave(3, dim=1)  # B 3 224 224
        prediction = torchvision.models.Inception3.forward(self, tripled)

        if self.training:
            # In training mode the torchvision forward yields (main, aux).
            main_head, aux_head = prediction  # B 1
            out, aux = main_head.flatten(), aux_head.flatten()  # B
        else:
            out, aux = prediction.flatten(), None  # B

        result = {'positive_class_probability': out}
        if y is None:
            return result

        main_loss = F.binary_cross_entropy(out, y)
        if aux is None:
            result['loss'] = main_loss
        else:
            aux_loss = F.binary_cross_entropy(aux, y)
            result['loss'] = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
        return result
| from typing import Dict | |||||
| import torch | |||||
| from ...modules import LAP, AdaptiveLAP | |||||
| from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||||
def get_pool_factory(discriminative_attention=True):
    """Return a ``pool_factory(channels, sigmoid_scale=1.0)`` callable building ``LAP`` modules.

    Each built ``LAP`` gets ``hidden_channels=[8]`` and the
    ``discriminative_attention`` flag captured here.
    """
    def pool_factory(channels, sigmoid_scale=1.0):
        # A fresh option dict (and hidden_channels list) per created module.
        lap_options = dict(
            hidden_channels=[8],
            sigmoid_scale=sigmoid_scale,
            discriminative_attention=discriminative_attention,
        )
        return LAP(channels, **lap_options)

    return pool_factory
def get_adaptive_pool_factory(discriminative_attention=True):
    """Return an ``adaptive_pool_factory(channels, sigmoid_scale=1.0)`` callable building ``AdaptiveLAP`` modules.

    Each built module receives the ``discriminative_attention`` flag captured here.
    """
    def adaptive_pool_factory(channels, sigmoid_scale=1.0):
        lap_options = dict(
            sigmoid_scale=sigmoid_scale,
            discriminative_attention=discriminative_attention,
        )
        return AdaptiveLAP(channels, **lap_options)

    return adaptive_pool_factory
class RSNALAPResNet18(LapResNet):
    """``LapResNet`` with ``LapBasicBlock`` and [2, 2, 2, 2] blocks, built with ``binary=True``."""

    def __init__(self, pool_factory: PoolFactory = get_pool_factory(),
                 adaptive_factory: PoolFactory = get_adaptive_pool_factory(),
                 sigmoid_scale: float = 1.0):
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(
            LapBasicBlock,
            blocks_per_stage,
            pool_factory=pool_factory,
            adaptive_factory=adaptive_factory,
            sigmoid_scale=sigmoid_scale,
            binary=True,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Repeat every entry along dim 1 three times, then delegate to the parent forward."""
        tripled = x.repeat_interleave(3, dim=1)  # B 3 224 224
        return super().forward(tripled, y)
| from typing import Dict | |||||
| import torch | |||||
| from torch.nn import functional as F | |||||
| import torchvision | |||||
| from ..tv_inception import Inception3 | |||||
class RSNAORGInception(Inception3):
    """``Inception3`` configured with a single output class (``n_classes=1``)."""

    def __init__(self, aux_weight: float):
        super().__init__(aux_weight, n_classes=1)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Run the network and, when labels are given, compute the loss.

        Args:
            x: Input batch; every entry along dim 1 is repeated three times
                before the network runs (the original note gives the result
                as B x 3 x 224 x 224).
            y: Optional targets for binary cross entropy.

        Returns:
            ``{'positive_class_probability': out}`` plus ``'loss'`` when ``y``
            is given. In training mode the loss is
            ``(main + aux_weight * aux) / (1 + aux_weight)``; otherwise it is
            the main loss alone.
        """
        tripled = x.repeat_interleave(3, dim=1)  # B 3 224 224
        prediction = torchvision.models.Inception3.forward(self, tripled)

        if self.training:
            # In training mode the torchvision forward yields (main, aux).
            main_head, aux_head = prediction  # B 1
            out, aux = main_head.flatten(), aux_head.flatten()  # B
        else:
            out, aux = prediction.flatten(), None  # B

        result = {'positive_class_probability': out}
        if y is None:
            return result

        main_loss = F.binary_cross_entropy(out, y)
        if aux is None:
            result['loss'] = main_loss
        else:
            aux_loss = F.binary_cross_entropy(aux, y)
            result['loss'] = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
        return result
| from typing import Dict | |||||
| import torch | |||||
| from ..tv_resnet import BasicBlock, ResNet | |||||
class RSNAORGResNet18(ResNet):
    """``ResNet`` with ``BasicBlock`` and [2, 2, 2, 2] blocks, built with ``binary=True``."""

    def __init__(self):
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(BasicBlock, blocks_per_stage, binary=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Repeat every entry along dim 1 three times, then delegate to the parent forward."""
        tripled = x.repeat_interleave(3, dim=1)  # B 3 224 224
        return super().forward(tripled, y)
| from typing import Iterable, Optional, Callable, List | |||||
| import torch | |||||
| from torch import nn, Tensor | |||||
| import torchvision | |||||
| from ..interpreting.interpretable import CamInterpretableModel | |||||
| from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
| from ..interpreting.relcam import modules as M | |||||
| from .lap_inception import ( | |||||
| InceptionA, | |||||
| InceptionC, | |||||
| InceptionE, | |||||
| BasicConv2d, | |||||
| InceptionAux as BaseInceptionAux, | |||||
| ) | |||||
class InceptionB(nn.Module):
    """Inception block with three parallel branches joined by an ``M.Cat`` module.

    The branches are a strided 3x3 conv, a 1x1 -> 3x3 -> strided 3x3 conv
    stack, and a strided 3x3 max-pool.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Single strided 3x3 convolution.
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)

        # 1x1 -> 3x3 -> strided 3x3 stack.
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs in order: 3x3, double 3x3, pool."""
        direct = self.branch3x3(x)
        stacked = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = self.maxpool(x)
        return [direct, stacked, pooled]

    def forward(self, x: Tensor) -> Tensor:
        """Join the branch outputs with ``self.cat`` along dim 1."""
        return self.cat(self._forward(x), 1)
@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):
    """Relevance propagation rule registered for ``InceptionB``."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Split ``R`` through ``cat``, walk each branch backwards, and sum the three results."""
        block = self.module

        def back(layer, relevance):
            # Apply the rule registered for `layer` to `relevance`.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        rel_direct, rel_stacked, rel_pool = back(block.cat, R)

        from_pool = back(block.maxpool, rel_pool)

        rel_stacked = back(block.branch3x3dbl_3, rel_stacked)
        rel_stacked = back(block.branch3x3dbl_2, rel_stacked)
        from_stacked = back(block.branch3x3dbl_1, rel_stacked)

        from_direct = back(block.branch3x3, rel_direct)

        return from_pool + from_stacked + from_direct
class InceptionD(nn.Module):
    """Inception block with three parallel branches joined by an ``M.Cat`` module.

    The branches are a 1x1 -> strided 3x3 conv pair, a 1x1 -> 1x7 -> 7x1 ->
    strided 3x3 conv stack, and a strided 3x3 max-pool.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # 1x1 -> strided 3x3.
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)

        # 1x1 -> 1x7 -> 7x1 -> strided 3x3.
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs in order: 3x3, 7x7x3, pool."""
        short = self.branch3x3_2(self.branch3x3_1(x))

        deep = self.branch7x7x3_1(x)
        for stage in (self.branch7x7x3_2, self.branch7x7x3_3, self.branch7x7x3_4):
            deep = stage(deep)

        pooled = self.maxpool(x)
        return [short, deep, pooled]

    def forward(self, x: Tensor) -> Tensor:
        """Join the branch outputs with ``self.cat`` along dim 1."""
        return self.cat(self._forward(x), 1)
@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):
    """Relevance propagation rule registered for ``InceptionD``."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Split ``R`` through ``cat``, walk each branch backwards, and sum the three results."""
        block = self.module

        def back(layer, relevance):
            # Apply the rule registered for `layer` to `relevance`.
            return RPProvider.get(layer)(relevance, alpha=alpha)

        rel_short, rel_deep, rel_pool = back(block.cat, R)

        from_pool = back(block.maxpool, rel_pool)

        rel_deep = back(block.branch7x7x3_4, rel_deep)
        rel_deep = back(block.branch7x7x3_3, rel_deep)
        rel_deep = back(block.branch7x7x3_2, rel_deep)
        from_deep = back(block.branch7x7x3_1, rel_deep)

        rel_short = back(block.branch3x3_2, rel_short)
        from_short = back(block.branch3x3_1, rel_short)

        return from_pool + from_deep + from_short
class InceptionAux(BaseInceptionAux):
    """``BaseInceptionAux`` whose ``conv1`` is replaced by a 128 -> 768 conv with kernel size 5."""

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__(in_channels, num_classes, conv_block=conv_block)
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Overwrite the conv1 created by the parent constructor.
        self.conv1 = make_conv(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
class Inception3(CamInterpretableModel, torchvision.models.Inception3):
    """torchvision ``Inception3`` built from this project's block classes.

    Both the main and the auxiliary classifier heads are replaced by a
    ``Linear`` followed by a ``Sigmoid``.
    """

    def __init__(self, aux_weight: float, n_classes=1):
        # The torchvision constructor is called directly with the custom blocks.
        custom_blocks = [
            BasicConv2d, InceptionA, InceptionB, InceptionC,
            InceptionD, InceptionE, InceptionAux
        ]
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False,
            init_weights=False,
            inception_blocks=custom_blocks,
        )

        main_head = nn.Sequential(nn.Linear(2048, n_classes), nn.Sigmoid())
        self.fc = main_head

        aux_head = nn.Sequential(nn.Linear(768, n_classes), nn.Sigmoid())
        self.AuxLogits.fc = aux_head

        self.aux_weight = aux_weight

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """The six branch modules of ``Mixed_7c`` listed for interpretation."""
        last_mixed = self.Mixed_7c
        return [
            last_mixed.branch1x1,
            last_mixed.branch3x3_2a, last_mixed.branch3x3_2b,
            last_mixed.branch3x3dbl_3a, last_mixed.branch3x3dbl_3b,
            last_mixed.branch_pool
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """Names of the inputs to interpret: only ``'x'``."""
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """Stack ``1 - p`` and ``p`` along dim 1, where ``p`` is the forward's ``'positive_class_probability'``."""
        positive = self.forward(*inputs, **kwargs)['positive_class_probability']
        negative = 1 - positive
        return torch.stack([negative, positive], dim=1)
@RPProvider.register(Inception3)
class Inception3Mo4RelProp(RelProp[Inception3]):
    """Relevance propagation rule registered for ``Inception3``."""

    # Modules visited after `fc`, in backward order. The trailing shape notes
    # are carried over from the original per-line comments.
    _BACKWARD_ORDER = (
        'dropout',        # B 2048 1 1
        'avgpool',        # B 2048 8 8
        'Mixed_7c',       # B 2048 8 8
        'Mixed_7b',       # B 1280 8 8
        'Mixed_7a',       # B 768 17 17
        'Mixed_6e',       # B 768 17 17
        'Mixed_6d',       # B 768 17 17
        'Mixed_6c',       # B 768 17 17
        'Mixed_6b',       # B 768 17 17
        'Mixed_6a',       # B 288 35 35
        'Mixed_5d',       # B 288 35 35
        'Mixed_5c',       # B 256 35 35
        'Mixed_5b',       # B 192 35 35
        'maxpool2',       # B 192 71 71
        'Conv2d_4a_3x3',  # B 80 73 73
        'Conv2d_3b_1x1',  # B 64 73 73
        'maxpool1',       # B 64 147 147
        'Conv2d_2b_3x3',  # B 32 147 147
        'Conv2d_2a_3x3',  # B 32 149 149
        'Conv2d_1a_3x3',  # B 3 299 299
    )

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` from the classifier head back through every stage of the network."""
        net = self.module

        # When the recorded fc output has one column, keep only R's last column.
        if RPProvider.get(net.fc).Y.shape[1] == 1:
            R = R[:, -1:]

        R = RPProvider.get(net.fc)(R, alpha=alpha)  # B 2048
        R = R.reshape_as(RPProvider.get(net.dropout).Y)  # B 2048 1 1

        for layer_name in self._BACKWARD_ORDER:
            R = RPProvider.get(getattr(net, layer_name))(R, alpha=alpha)
        return R
| from typing import Dict, Iterable, Type, Any, Callable, Union, List, Optional | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch import Tensor | |||||
| from torch.nn import functional as F | |||||
| from ..interpreting.interpretable import CamInterpretableModel | |||||
| from ..interpreting.relcam.relprop import RPProvider, RelProp | |||||
| from ..interpreting.relcam import modules as M | |||||
# Names exported by a star-import of this module.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Bias-free 1x1 convolution."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions.

    Each ReLU and the residual addition is its own module instance
    (``relu``, ``relu2``, ``add``).
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride
        self.add = M.Add()

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv-bn-relu-conv-bn, add the (optionally downsampled) input, then ReLU."""
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))

        # The shortcut is computed after the residual branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu2(self.add([residual, shortcut]))
@RPProvider.register(BasicBlock)
class BasicBlockRelProp(RelProp[BasicBlock]):
    """Relevance propagation rule registered for ``BasicBlock``."""

    def rel(self, R, alpha):
        """Walk the block backwards and sum the shortcut and residual results."""
        block = self.module

        def back(layer, relevance):
            # Apply the rule registered for `layer` to `relevance`.
            return RPProvider.get(layer)(relevance, alpha)

        after_relu = back(block.relu2, R)
        residual_rel, shortcut_rel = back(block.add, after_relu)

        if block.downsample is not None:
            shortcut_rel = back(block.downsample, shortcut_rel)

        for layer in (block.bn2, block.conv2, block.relu, block.bn1):
            residual_rel = back(layer, residual_rel)
        residual_input_rel = back(block.conv1, residual_rel)

        return shortcut_rel + residual_input_rel
class Bottleneck(nn.Module):
    """Residual block of a 1x1, a 3x3 and a 1x1 convolution.

    Each ReLU and the residual addition is its own module instance
    (``relu``, ``relu2``, ``relu3``, ``add``).
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.add = M.Add()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the three conv stages, add the (optionally downsampled) input, then ReLU."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu2(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        # The shortcut is computed after the residual branch, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu3(self.add([residual, shortcut]))
@RPProvider.register(Bottleneck)
class BottleneckRelProp(RelProp[Bottleneck]):
    """Relevance propagation rule registered for ``Bottleneck``."""

    def rel(self, R, alpha):
        """Walk the block backwards and sum the shortcut and residual results.

        Args:
            R: Relevance arriving at the block's output.
            alpha: Passed unchanged to every nested propagation rule.

        Returns:
            The sum of the relevance propagated through the shortcut and
            through the residual branch.
        """
        out = RPProvider.get(self.module.relu3)(R, alpha)

        out, x = RPProvider.get(self.module.add)(out, alpha)

        # The downsample layer belongs to the wrapped block (`self.module`),
        # not to this rule object -- same lookup as BasicBlockRelProp.
        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha)

        out = RPProvider.get(self.module.bn3)(out, alpha)
        out = RPProvider.get(self.module.conv3)(out, alpha)

        out = RPProvider.get(self.module.relu2)(out, alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha)

        out = RPProvider.get(self.module.relu)(out, alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha)

        return x + x1
class ResNet(CamInterpretableModel):
    """ResNet backbone with an optional single-output sigmoid head.

    With ``binary=True`` the classifier is ``Linear -> Sigmoid -> Flatten(0)``
    and ``forward`` reports ``positive_class_probability``; otherwise the
    classifier is a ``Linear`` over ``num_classes`` and ``forward`` reports
    ``categorical_probability``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        binary: bool = False,
    ) -> None:
        """
        Args:
            block: Residual block class used in every stage.
            layers: Number of blocks in each of the four stages.
            num_classes: Output size of the classifier when ``binary`` is False.
            zero_init_residual: Zero the weight of the last norm layer of every block.
            groups: Forwarded to every block.
            width_per_group: Forwarded to every block as ``base_width``.
            replace_stride_with_dilation: Per-stage flags (stages 2-4) replacing
                the stride with dilation; None means all False.
            norm_layer: Normalization layer factory; None means ``nn.BatchNorm2d``.
            binary: Build the single-output sigmoid head instead of the linear one.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self._binary = binary
        if binary:
            # Flatten(0) collapses the head's output to one dimension.
            self.fc = nn.Sequential(
                nn.Linear(512 * block.expansion, 1),
                nn.Sigmoid(),
                nn.Flatten(0)
            )
        else:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

        Only the first block receives the stride and the downsample shortcut.
        When ``dilate`` is set, the stride is folded into ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the block changes the
        # spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, four stages, pooling and the classifier; return the classifier output."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """Compute the model outputs and, when labels are given, the loss.

        Args:
            x: Input batch.
            y: Optional targets. When omitted no loss is computed, so the
                model can be called for inference or interpretation alone.

        Returns:
            ``positive_class_probability`` (binary head) or
            ``categorical_probability`` (softmax over dim 1), plus ``loss``
            when ``y`` is given: binary cross entropy for the binary head,
            cross entropy against ``y.long()`` otherwise.
        """
        p = self._forward_impl(x)

        if self._binary:
            outputs = dict(positive_class_probability=p)
            if y is not None:
                outputs['loss'] = F.binary_cross_entropy(p, y)
            return outputs

        outputs = dict(categorical_probability=p.softmax(dim=1))
        if y is not None:
            outputs['loss'] = F.cross_entropy(p, y.long())
        return outputs

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """
        Returns:
            The convolutional layers to be interpreted. The result of
            the interpretation will be sum of the grad-cam of these layers
        """
        return [self.layer4]

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """
        A method to get probabilities assigned to all the classes in the model's forward,
        with shape (B, C), in which B is the batch size and C is number of classes

        Args:
            *inputs: Inputs to the model
            **kwargs: Additional arguments

        Returns:
            Tensor of categorical probabilities
        """
        if self._binary:
            p = self.forward(*inputs, **kwargs)['positive_class_probability']
            return torch.stack([1 - p, p], dim=1)
        return self.forward(*inputs, **kwargs)['categorical_probability']

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        Returns:
            Input module for interpretation
        """
        return ['x']
@RPProvider.register(ResNet)
class ResNetRelProp(RelProp[ResNet]):
    """Relevance propagation rule registered for ``ResNet``."""

    # Modules visited after the reshape, in backward order.
    _BACKWARD_ORDER = (
        'avgpool',
        'layer4',
        'layer3',
        'layer2',
        'layer1',
        'maxpool',
        'relu',
        'bn1',
        'conv1',
    )

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        """Propagate ``R`` from the classifier back to the first convolution."""
        net = self.module

        # When the recorded fc output is one-dimensional, keep only R's last column.
        if RPProvider.get(net.fc).Y.ndim == 1:
            R = R[:, -1]

        R = RPProvider.get(net.fc)(R, alpha=alpha)
        R = R.reshape_as(RPProvider.get(net.avgpool).Y)

        for layer_name in self._BACKWARD_ORDER:
            R = RPProvider.get(getattr(net, layer_name))(R, alpha=alpha)
        return R
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """
    Build a ``ResNet`` from a block type and a layer configuration.

    Args:
        arch: Architecture name. Not used to build the model.
        block: Block class passed to ``ResNet`` as its first argument.
        layers: Layer configuration passed to ``ResNet`` as its second
            argument.
        pretrained: Loading pre-trained weights is not supported here. A
            ``UserWarning`` is emitted when this is truthy, and the model is
            returned with whatever initialization ``ResNet`` performs.
        progress: Unused, since nothing is downloaded.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The newly constructed model.
    """
    import warnings

    if pretrained:
        # Previously the flag was dropped silently, so callers could believe
        # they had received pre-trained weights. Warn instead of raising so
        # existing callers keep working.
        warnings.warn(
            f"pretrained=True is ignored for '{arch}': "
            "loading pre-trained weights is not supported, "
            "the model is returned without them.",
            UserWarning,
            stacklevel=2,
        )
    return ResNet(block, layers, **kwargs)
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNet-18: ``BasicBlock`` with layer configuration ``[2, 2, 2, 2]``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    layer_config = [2, 2, 2, 2]
    return _resnet("resnet18", BasicBlock, layer_config, pretrained, progress, **kwargs)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNet-34: ``BasicBlock`` with layer configuration ``[3, 4, 6, 3]``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    layer_config = [3, 4, 6, 3]
    return _resnet("resnet34", BasicBlock, layer_config, pretrained, progress, **kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNet-50: ``Bottleneck`` with layer configuration ``[3, 4, 6, 3]``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    layer_config = [3, 4, 6, 3]
    return _resnet("resnet50", Bottleneck, layer_config, pretrained, progress, **kwargs)
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNet-101: ``Bottleneck`` with layer configuration ``[3, 4, 23, 3]``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    layer_config = [3, 4, 23, 3]
    return _resnet("resnet101", Bottleneck, layer_config, pretrained, progress, **kwargs)
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNet-152: ``Bottleneck`` with layer configuration ``[3, 8, 36, 3]``.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    layer_config = [3, 8, 36, 3]
    return _resnet("resnet152", Bottleneck, layer_config, pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNeXt-50 32x4d: ``Bottleneck`` with layer configuration ``[3, 4, 6, 3]``.

    Architecture from `"Aggregated Residual Transformation for Deep Neural
    Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``groups`` is always set to 32 and ``width_per_group`` to 4; values
    supplied for these two keys by the caller are overwritten.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    kwargs.update(groups=32, width_per_group=4)
    layer_config = [3, 4, 6, 3]
    return _resnet("resnext50_32x4d", Bottleneck, layer_config, pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a ResNeXt-101 32x8d: ``Bottleneck`` with layer configuration ``[3, 4, 23, 3]``.

    Architecture from `"Aggregated Residual Transformation for Deep Neural
    Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``groups`` is always set to 32 and ``width_per_group`` to 8; values
    supplied for these two keys by the caller are overwritten.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    kwargs.update(groups=32, width_per_group=8)
    layer_config = [3, 4, 23, 3]
    return _resnet("resnext101_32x8d", Bottleneck, layer_config, pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a Wide ResNet-50-2: ``Bottleneck`` with layer configuration ``[3, 4, 6, 3]``.

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same as ResNet except that the bottleneck channel count is doubled in
    every block, while the outer 1x1 convolutions keep their channel count:
    the last block of ResNet-50 has 2048-512-2048 channels, that of
    Wide ResNet-50-2 has 2048-1024-2048.

    ``width_per_group`` is always set to ``64 * 2``; a value supplied by the
    caller is overwritten.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    doubled_width = 64 * 2
    kwargs.update(width_per_group=doubled_width)
    layer_config = [3, 4, 6, 3]
    return _resnet("wide_resnet50_2", Bottleneck, layer_config, pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    Construct a Wide ResNet-101-2: ``Bottleneck`` with layer configuration ``[3, 4, 23, 3]``.

    Architecture from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same as ResNet except that the bottleneck channel count is doubled in
    every block, while the outer 1x1 convolutions keep their channel count:
    the last block of ResNet-50 has 2048-512-2048 channels, that of
    Wide ResNet-50-2 has 2048-1024-2048.

    ``width_per_group`` is always set to ``64 * 2``; a value supplied by the
    caller is overwritten.

    Args:
        pretrained (bool): Passed through to ``_resnet``, which does not load
            any weights, so this flag does not yield a pre-trained model.
        progress (bool): Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The model built by ``_resnet``.
    """
    doubled_width = 64 * 2
    kwargs.update(width_per_group=doubled_width)
    layer_config = [3, 4, 23, 3]
    return _resnet("wide_resnet101_2", Bottleneck, layer_config, pretrained, progress, **kwargs)
| from .gaussianmax import GaussianMax2d | |||||
| from .activated_conv import ActivatedConv2d | |||||
| from .scaled_sigmoid import ScaledSigmoid | |||||
| from .lap import LAP | |||||
| from .adaptive_lap import AdaptiveLAP |