@@ -0,0 +1,313 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm+all,macos

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.
.idea/*
!.idea/codeStyles
!.idea/runConfigurations

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos
@@ -0,0 +1,218 @@
# LAP
An Attention-Based Module for Faithful Interpretation and Knowledge Injection in Convolutional Neural Networks

This supplementary material provides a PDF with more details for some of the sections of the paper, along with the code of our work.
# Data Preparation
## Metadata
First, download the metadata from [here](https://mega.nz/file/MqhXiLgC#NK1vd9ZLGU-3a182x-BXfvGCSuqHgq9hflI6tddh4Wc)
## RSNA
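The following is a sketch based on the bundled RSNA resizing script; the script path below is an assumption, so substitute the actual location of the resizing script in this repository.
1. Download the dataset from [here](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data) and extract it.
2. Resize the images to the resolution required by your model (224 for ResNet, 299 for vanilla Inception, 256 for the modified Inception), e.g. using 8 worker processes; the resized images are written to `data/RSNA-Kaggle_R<resolution>`:
```bash
python data_preparation/rsna/resize_images.py <kaggle-dataset-dir> 224 8
```
3. Extract the RSNA metadata (the `dataset_metadata/RSNA` folder used by the RSNA configs) from [this](#metadata) metadata file.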
## CelebA
1. Download the dataset from [here](https://www.kaggle.com/jessicali9530/celeba-dataset)
2. Extract the data
```bash
unzip -q -j celeba-dataset.zip '**/*.jpg' -d data/celeba
```
3. Extract `celeba.tsv` from [this](#metadata) metadata file to `dataset_metadata/celeba.tsv`
## ImageNet
1. Download the ImageNet ILSVRC2012 dataset from [here](https://www.image-net.org/challenges/LSVRC/index.php).
1. Extract `ILSVRC2012_img_train_t3.tar` to `data/imagenet/tars`.
1. Run the following command to extract the train images per class:
```bash
python data_preparation/imagenet/extract_train.py -s data/imagenet/tars/ -n <num-threads>
```
1. Run the following command to extract the validation images per class (replace directory with the directory containing `ILSVRC2012_img_val.tar`, and extract `imagenet_val_maps.pklz` from [this](#metadata) metadata file):
```bash
python data_preparation/imagenet/extract_val.py -s <directory> -m <imagenet_val_maps.pklz>
```
1. Download the localization bounding boxes from [here](https://www.kaggle.com/c/imagenet-object-localization-challenge/data). Put `LOC_train_solution.csv` and `LOC_val_solution.csv` in `data/imagenet/extracted`.
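A quick sanity check of the resulting layout (assuming the default target directory): `data/imagenet/extracted` should now contain one folder per training class, a `val` folder with the same per-class structure, and the two CSV files.
```bash
ls data/imagenet/extracted | wc -l  # expect 1003 entries: 1000 class folders + val + 2 CSV files
```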
# Model checkpoints
Download all the checkpoints from [here](https://mega.nz/file/16pWWA5L#b1SwLEv-YO5lL9wLVTcgykn2LmUBHm1exPFNB0JBoZo).
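The commands below assume the checkpoints are extracted into a `checkpoints` folder in the repository root, so that files such as `checkpoints/rsna.org_resnet.pth` exist; the archive name in this sketch is an assumption:
```bash
unzip -q checkpoints.zip -d checkpoints
```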
# Run
## Evaluate the checkpoints
### RSNA
#### Vanilla ResNet 18
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_resnet.pth
```
#### Vanilla Inception V3
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_inception.pth
```
#### WS ResNet 18
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_resnet_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_resnet_interpretations
```
#### WS Inception V3
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_inception_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_inception_interpretations
```
#### BB ResNet 18
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_resnet_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_resnet_interpretations
```
#### BB Inception V3
##### Evaluate the model performance on the test set
```bash
python evaluate.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_inception_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_inception_interpretations
```
### CelebA
#### Vanilla ResNet 18
##### Evaluate the model performance on the test set
```bash
python evaluate.py celeba.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_resnet.pth
```
#### Vanilla Inception V3
##### Evaluate the model performance on the test set
```bash
python evaluate.py celeba.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_inception.pth
```
#### WS ResNet 18
##### Evaluate the model performance on the test set
```bash
python evaluate.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_resnet_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_resnet_interpretations
```
#### WS Inception V3
##### Evaluate the model performance on the test set
```bash
python evaluate.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth
```
##### Generate the LAP interpretations for positive samples
```bash
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_inception_interpretations
```
##### Generate the LAP interpretations for negative samples
```bash
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_inception_interpretations
```
### ImageNet
#### WS ResNet 50 (Fine-tuned)
##### Evaluate the model performance on the validation set
```bash
``` |
@@ -0,0 +1,32 @@
import argparse
import glob
import os
import tarfile
from multiprocessing import Pool

parser = argparse.ArgumentParser()
parser.add_argument('-s', dest='source', help='Class tars directory', required=True)
parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted')
parser.add_argument('-n', dest='num_threads', help='number of worker processes', default=1, type=int)
args = parser.parse_args()

class_tars = glob.glob(os.path.join(args.source, '*.tar'))
assert len(class_tars) == 1000, f"class_tars length: {len(class_tars)}"

def extract(class_tar):
    """Extracts one per-class tar into its own subdirectory of the target directory."""
    filename = os.path.basename(class_tar)
    class_name = filename.replace('.tar', '')
    print('Extract ' + filename)
    class_dir = os.path.join(args.target, class_name)
    os.makedirs(class_dir, exist_ok=True)
    with tarfile.open(class_tar) as f:
        f.extractall(class_dir)

pool = Pool(args.num_threads)
pool.map(extract, class_tars)
pool.close()
pool.join()
@@ -0,0 +1,59 @@
"""Prepare the ImageNet validation set"""
import os
import argparse
import tarfile
import pickle
import gzip

_VAL_TAR = 'ILSVRC2012_img_val.tar'

def parse_args():
    parser = argparse.ArgumentParser(
        description='Setup the ImageNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-s', dest='download_dir', required=True,
                        help="The directory that contains downloaded tar files")
    parser.add_argument('-m', dest='mapping', required=True,
                        help="The mapping file for validation set")
    parser.add_argument('--target-dir', default='data/imagenet/extracted',
                        help="The directory to store extracted images")
    args = parser.parse_args()
    return args

def check_file(filename):
    if not os.path.exists(filename):
        raise ValueError('File not found: ' + filename)

def extract_val(tar_fname, target_dir, val_maps_file):
    os.makedirs(target_dir)
    print('Extracting ' + tar_fname)
    with tarfile.open(tar_fname) as tar:
        tar.extractall(target_dir)
    # move images to their proper per-class subfolders
    with gzip.open(val_maps_file, 'rb') as f:
        dirs, mappings = pickle.load(f)
    for d in dirs:
        os.makedirs(os.path.join(target_dir, d))
    for m in mappings:
        os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))

def main():
    args = parse_args()
    target_dir = os.path.expanduser(args.target_dir)
    # only the val subdirectory must not exist yet; the target directory itself
    # may already contain the extracted training classes
    val_dir = os.path.join(target_dir, 'val')
    if os.path.exists(val_dir):
        raise ValueError('Target dir [' + val_dir + '] exists. Remove it first')
    download_dir = os.path.expanduser(args.download_dir)
    val_tar_fname = os.path.join(download_dir, _VAL_TAR)
    check_file(val_tar_fname)
    extract_val(val_tar_fname, val_dir, args.mapping)

if __name__ == '__main__':
    main()
@@ -0,0 +1,41 @@
import argparse
from os import makedirs, path, listdir
from multiprocessing import Pool
import numpy as np
import cv2

def resize_image(img_path: str, img_save_path: str, res: int) -> None:
    """Reads an image, resizes it to res x res, and writes it to img_save_path."""
    img = cv2.imread(img_path)
    img = cv2.resize(img, dsize=(res, res))
    cv2.imwrite(img_save_path, img)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('kaggle_dataset_dir', type=str, help='download the dataset from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, extract it, and pass its path as this argument')
    parser.add_argument('resolution', type=int, help='The resolution required for your model, 224 for resnet, 299 for vanilla inception, 256 for modified inception')
    parser.add_argument('cores', type=int, help='The number of cores for multiprocessing.')
    args = parser.parse_args()

    save_dir = f'data/RSNA-Kaggle_R{args.resolution}'
    makedirs(save_dir, exist_ok=True)

    kaggle_dataset_dir = args.kaggle_dataset_dir
    assert path.exists(kaggle_dataset_dir), f'{kaggle_dataset_dir} does not exist!'

    # reading the names of the rsna images
    rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images')
    assert path.exists(rsna_imgs_path), 'Make sure there is a folder named stage_2_train_images in the passed kaggle_directory!'
    imgs_names = np.asarray(listdir(rsna_imgs_path))
    imgs_src = np.vectorize(lambda x: path.join(rsna_imgs_path, x))(imgs_names)
    imgs_dst = np.vectorize(lambda x: path.join(save_dir, x))(imgs_names)

    pool = Pool(args.cores)
    pool.starmap(resize_image, zip(imgs_src, imgs_dst, np.full((len(imgs_src),), args.resolution)))
    pool.close()
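
# Example invocation (assumed; the script filename is not shown in this diff).
# Run from the repository root so the default save_dir 'data/RSNA-Kaggle_R224'
# resolves correctly:
#   python <path-to-this-script> <kaggle-dataset-dir> 224 8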
@@ -0,0 +1,33 @@
import traceback

from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.data.data_loader import DataLoader, RunType
from torchlap.experiments._model_loading import load_model
from torchlap.utils.hooker import Hooker

def main() -> None:
    ep = get_entry(PhaseType.EVAL)
    conf = ep.conf
    model = ep.model
    model = load_model(model, conf)
    model.eval()

    for test_group_info in conf.samples_dir.split(','):
        try:
            print('')
            print('>> Running evaluations for %s' % test_group_info, flush=True)
            test_data_loader = DataLoader(conf, test_group_info, RunType.TEST)
            evaluator = conf.evaluator_cls(model, test_data_loader, conf)
            with Hooker(*conf.hooks_by_phase[conf.phase_type]):
                evaluator.evaluate()
        except Exception:
            print('Problem in %s' % test_group_info, flush=True)
            track = traceback.format_exc()
            print(track, flush=True)

if __name__ == '__main__':
    main()
@@ -0,0 +1,21 @@
from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.experiments._model_loading import load_model
from torchlap.interpreting.interpretation_evaluation_runner import InterpretingEvalRunner
from torchlap.utils.hooker import Hooker

def main() -> None:
    ep = get_entry(PhaseType.EVALINTERPRET)
    conf = ep.conf
    model = ep.model
    model = load_model(model, conf)
    model.eval()
    runner = InterpretingEvalRunner(conf, model)
    with Hooker(*conf.hooks_by_phase[conf.phase_type]):
        runner.evaluate()

if __name__ == '__main__':
    main()
@@ -0,0 +1,21 @@
from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.experiments._model_loading import load_model
from torchlap.interpreting.interpreting_runner import InterpretingRunner
from torchlap.utils.hooker import Hooker

def main() -> None:
    ep = get_entry(PhaseType.INTERPRET)
    conf = ep.conf
    model = ep.model
    model = load_model(model, conf)
    model.eval()
    runner = InterpretingRunner(conf, model)
    with Hooker(*conf.hooks_by_phase[conf.phase_type]):
        runner.interpret()

if __name__ == '__main__':
    main()
@@ -0,0 +1,8 @@
numpy
pandas
torch>=1.7.0
torchvision>=0.8.1
scikit-image>=0.14.3
scikit-learn>=0.19.1
captum>=0.3.0
opencv-python>=4.4.0.46
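# install with: pip install -r requirements.txt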
@@ -0,0 +1 @@
__version__ = 'master'
@@ -0,0 +1,304 @@
from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional
import os
import random
import numpy as np
from sys import stderr
from enum import Enum
import torch

from ..interpreting.interpreter_maker import InterpretationType
from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser
from ..data.batch_choosing.sequential import SequentialBatchChooser
from ..utils.hooker import HookPair

if TYPE_CHECKING:
    from ..model_evaluation.evaluator import Evaluator
    from ..data.batch_choosing.batch_chooser import BatchChooser
    from ..data.content_loaders.content_loader import ContentLoader

class PhaseType(Enum):
    TRAIN = 'train'
    EVAL = 'eval'
    INTERPRET = 'interpret'
    EVALINTERPRET = 'evalinterpret'

class BaseConfig:
    def __init__(self,
                 try_name: str, try_num: int, data_separation: str,
                 input_size: int,
                 phase_type: PhaseType,
                 content_loader_cls: Type['ContentLoader'],
                 evaluator_cls: Type['Evaluator']) -> None:
        # -------------------- ESSENTIAL CONFIG --------------------
        self.data_separation = data_separation
        self.try_name = try_name
        self.try_num = try_num
        self.phase_type = phase_type

        # classes
        self.evaluator_cls = evaluator_cls
        self.content_loader_cls = content_loader_cls
        self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
        self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser

        # Setups
        self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
        self.batch_size = 64
        self.random_seed: int = 17
        self.samples_dir: Union[str, None] = None

        # Saving and loading directories!
        self.save_dir: Union[str, None] = None
        self.report_dir: Union[str, None] = None
        self.final_model_dir: Union[str, None] = None
        self.epoch: Union[str, None] = None

        # Device setting
        self.dev_name = None
        self.device = None
        # -------------------- ESSENTIAL CONFIG --------------------

        # -------------------- TRAINING CONFIG --------------------
        self.pretrained_model_file = None
        self.big_batch_size: int = 1
        self.max_epochs: int = 10
        self.iters_per_epoch: Union[int, None] = None
        self.val_iters_per_epoch: Union[int, None] = None

        # cleaning log while training
        self.keep_best_and_last_epochs_only = False

        # auto-stopping the train
        self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100

        self.different_augmentation_per_batch = True
        self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None

        self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

        # for dataloader
        self.mapped_labels_to_use: Union[List[int], None] = None  # None means all; a list of ints = the labels to consider

        # hooking
        self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
            phasetype: [] for phasetype in PhaseType
        }
        # -------------------- TRAINING CONFIG --------------------

        # ------------------ OPTIMIZATION CONFIG ------------------
        self.init_lr: float = 1e-4
        self.lr_decay: float = 1e-6
        self.freezing_regexes: Union[None, List[str]] = None
        self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = get_adam_creator(self.init_lr, self.lr_decay)
        # ------------------ OPTIMIZATION CONFIG ------------------

        # ----------------- INTERPRETATION CONFIG -----------------
        self.save_by_file_name = False
        self.interpretation_method = InterpretationType.GuidedBackprop
        self.class_label_for_interpretation: Union[int, None] = None
        self.interpret_predictions_vs_gt: bool = True

        # saving interpretations:
        self.skip_overlay = False
        self.skip_raw = False

        # FOR INTERPRETATION EVALUATION
        # A threshold to cut all interpretations lower than
        self.cut_threshold: float = 0
        # whether the threshold is applied to un-normalized interpretation map
        self.global_threshold: bool = False
        # whether to use dynamic threshold
        self.dynamic_threshold: bool = False

        self.prediction_key_in_model_output_dict = 'positive_class_probability'

        # for when interpretation masks as continuous objects are compared with GT
        self.acceptable_min_intersection_threshold: float = 0.5

        # the key for interpretation to be evaluated, None = the first one available
        self.interpretation_tag_to_evaluate = None

        self.n_interpretation_samples: Optional[int] = None
        # ----------------- INTERPRETATION CONFIG -----------------

    def __repr__(self):
        """
        Represent config as string
        """
        return str(self.__dict__)

    def _update_based_on_args(self, args) -> None:
        """ Updates the values based on the received arguments from the input! """
        args_dict = args.__dict__
        for arg in args_dict:
            if arg in self.__dict__ and args_dict[arg] is not None:
                setattr(self, arg, args_dict[arg])

    def update(self, args):
        """
        updates the config from args
        """
        # intentionally applied twice, so values that depend on other
        # freshly-updated arguments are refreshed as well
        self._update_based_on_args(args)
        self._update_based_on_args(args)
        self.dev_name, self.device = _get_device(args)
        self._set_random_seeds()
        self._set_save_dir()
        self._update_report_dir()
    def _set_random_seeds(self):
        # Set the seed for hash based operations in python
        os.environ['PYTHONHASHSEED'] = '0'
        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        np.random.seed(self.random_seed)
        random.seed(self.random_seed)

        if self.phase_type != PhaseType.TRAIN:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
    def _set_save_dir(self):
        if self.save_dir is None:
            self.save_dir = f"../Results/{self.try_num}_{self.try_name}"
        os.makedirs(self.save_dir, exist_ok=True)

    def _update_report_dir(self):
        if self.report_dir is None:
            self.report_dir = self.save_dir + '/' + self.phase_type.value

    def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str:
        if extra_subdir is None:
            extra_subdir = ''
        else:
            extra_subdir = '/' + extra_subdir

        # Making sure of keeping version information!
        if self.final_model_dir is not None:
            report_dir = '%s/%s%s' % (
                self.report_dir,
                self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')],
                extra_subdir)
        else:
            epoch = self.epoch
            if epoch is None:
                epoch = 'best'
            report_dir = '%s/epoch,%s%s' % (
                self.report_dir,
                epoch, extra_subdir)

        # adding sample info
        if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification):
            if ':' in samples_specification:
                base_info, _ = tuple(samples_specification.split(':'))
            else:
                base_info = samples_specification
            if not os.path.isdir(base_info):
                report_dir = report_dir + '/' + \
                    samples_specification.replace(':', '_').replace('../', '')
            else:
                report_dir = report_dir + '/' + samples_specification

        return self.get_unique_save_dir(report_dir)

    @staticmethod
    def get_unique_save_dir(sd):
        if os.path.exists(sd):
            report_num = 2
            base_sd = sd
            while os.path.exists(sd):
                sd = base_sd + '_%d' % report_num
                report_num += 1
        return sd

    def get_save_dir_for_sample(self, save_dir, sample_name):
        if self.save_by_file_name and '/' in sample_name:
            return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:]
        else:
            return save_dir + '/' + sample_name
def _get_device(args):
    """ Parses args.device: 'cuda' means all available GPUs, 'cuda:N' a specific GPU, and anything else means CPU; falls back to CPU if the requested GPU does not exist. """
    dev_name = args.device
    if 'cuda' in dev_name:
        gpu_num = 0
        if ':' in dev_name:
            gpu_num = int(dev_name.split(':')[1])
        if gpu_num >= torch.cuda.device_count():
            print('No %s, using CPU instead!' % dev_name, file=stderr)
            dev_name = 'cpu'
    if dev_name == 'cuda':
        dev_name = None
        the_device = torch.device('cuda')
        print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True)
    elif 'cuda:' in dev_name:
        the_device = torch.device(dev_name)
        print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True)
    else:
        dev_name = 'cpu'
        the_device = torch.device(dev_name)
        print('@ Running on CPU @', flush=True)

    return dev_name, the_device

def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay)
    return create_adam

def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay)
    return create_sgd
@@ -0,0 +1,35 @@
from typing import Union

from torch import nn
from torchvision import transforms

from .base_config import BaseConfig, PhaseType
from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag
from ..model_evaluation.binary_evaluator import BinaryEvaluator

class CelebAConfigs(BaseConfig):
    def __init__(self,
                 try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        super().__init__(try_name, try_num, 'default', input_size, phase_type, CelebALoader, BinaryEvaluator)

        # replaced configs!
        self.batch_size = 64
        self.iters_per_epoch = 200

        self.tags = [tag.name for tag in CelebATag if tag.name.endswith('Tag')]
        self.main_tag: Union[CelebATag, None] = None

        self.data_root: str = 'data/celeba'
        self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

        self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # augmentation
        self.augmentations_dict = {
            'x': nn.Sequential(
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5),
                transforms.RandomPerspective())
        }
@@ -0,0 +1,17 @@
from .base_config import BaseConfig, PhaseType
from ..data.content_loaders.imagenet_loader import ImagenetLoader
from ..model_evaluation.multiclass_evaluator import MulticlassEvaluator

class ImagenetConfigs(BaseConfig):
    def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        super().__init__(try_name, try_num, 'data/imagenet/extracted', input_size, phase_type, ImagenetLoader, MulticlassEvaluator.standard_creator('categorical_probability', include_top5=True))

        # replaced configs!
        self.batch_size = 64
        self.big_batch_size = 4
        self.max_epochs = 10

        self.title_of_reference_metric_to_choose_best_epoch = 'Accuracy'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True
@@ -0,0 +1,47 @@
from typing import Dict, List, Optional

from torch import nn
from torchvision import transforms

from .base_config import BaseConfig, PhaseType
from ..utils.random_brightness_augmentation import ClippedBrightnessAugment
from ..data.content_loaders.rsna_loader import RSNALoader
from ..model_evaluation.binary_evaluator import BinaryEvaluator

class RSNAConfigs(BaseConfig):
    def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        super().__init__(try_name, try_num, f'dataset_metadata/RSNA/DataSeparation_R{input_size}', input_size, phase_type, RSNALoader, BinaryEvaluator)

        # replaced configs!
        self.batch_size = 64
        self.max_epochs = 200

        self.label_map_dict: Dict[str, int] = {
            'healthy': 0, 'pneumonia': 1
        }

        self.augmentations_dict = {
            'x': nn.Sequential(
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.4),
                transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)),
                ClippedBrightnessAugment(0.5, 1.5, 0, 1)),
            'infection': nn.Sequential(
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.4),
                transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),
            'interpretations': nn.Sequential(
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.4),
                transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),
        }

        self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        self.infection_map_size = self.input_size
        self.receptive_field_radii: List[int] = []
        self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'
@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Tuple

import torch
from torch import nn

from ..configs.base_config import BaseConfig, PhaseType

class AuxLoss(ABC):
    def __init__(self, title: str, model: nn.Module,
                 layer_by_name: Dict[str, nn.Module],
                 loss_weight: float):
        """This class is used for calculating a loss based on the middle layers of the network.

        Args:
            title (str): A unique title; the required inputs are kept under this title in the output dictionary, and the loss term is added as title_loss.
            model (nn.Module): The base model.
            layer_by_name (Dict[str, nn.Module]): A dictionary mapping layer names to layers. The names used for one instance of this class should be unique.
            loss_weight (float): The weight of this term in the final loss. The loss term is automatically added to the model's loss in the model_output dictionary with this factor.
        """
        self._title = title
        self._model = model
        self._loss_weight = loss_weight
        self._layers_names = layer_by_name.keys()
        self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()
        self.hook_pairs = [
            (module, partial(self._save_layer_val_hook, name))
            for name, module in layer_by_name.items()
        ] + [(model, self._add_loss)]

    def _save_layer_val_hook(self, layer_name, _, __, layer_out):
        self._layers_outputs_by_name[layer_name] = layer_out

    def configure(self, config: BaseConfig) -> None:
        """Must be called to add the necessary configs and hooks so the loss will be calculated!

        Args:
            config (BaseConfig): The base config!
        """
        for phase in PhaseType:
            config.hooks_by_phase[phase] += self.hook_pairs

    def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
        """A hook for calculating the loss and adding it to the model's output dictionary.

        Args:
            _ ([type]): Model, unused (just the hook signature).
            __ ([type]): Model input, also unused (just the hook signature).
            model_out (Dict[str, torch.Tensor]): The output dictionary of the model.
        """
        # popping values of tensors
        layers_values = []
        for layer_name in self._layers_names:
            assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.'
            layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))
        loss = self._calculate_loss(layers_values, model_out)

        # adding loss to output dictionary
        assert f'{self._title}_loss' not in model_out, f'Trying to add {self._title}_loss to model\'s output multiple times!'
        model_out[f'{self._title}_loss'] = loss.clone()

        # adding loss to the main loss
        model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss
        return model_out

    @abstractmethod
    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates the loss based on the output of the layers given in the constructor.

        Args:
            layers_values (List[Tuple[str, torch.Tensor]]): List of layer names and their output values.
            model_output (Dict[str, torch.Tensor]): The output dictionary of the model, in case something else is required for the loss calculation; it can also be used to add extra terms to the output.

        Returns:
            torch.Tensor: The loss.
        """
@@ -0,0 +1,123 @@
from typing import List, Tuple, Dict, Union
import math

import torch
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .aux_loss import AuxLoss

class BBLoss(AuxLoss):
    def __init__(
            self, title: str, model: torch.nn.Module,
            layer_by_name: Dict[str, torch.nn.Module],
            loss_weight: float,
            inf_mask_kw: str,
            layers_loss_weights: Union[float, List[float]],
            bb_fill_ratio: float):
        super().__init__(title, model, layer_by_name, loss_weight)
        self._inf_mask_kw = inf_mask_kw
        self._bb_fill_ratio = bb_fill_ratio
        if isinstance(layers_loss_weights, list):
            assert len(layers_loss_weights) == len(layer_by_name), f'There should be as many weights as layers, one per layer. Expected {len(layer_by_name)}, found {len(layers_loss_weights)}'
            self._layers_loss_weights = layers_loss_weights
        else:
            self._layers_loss_weights = [layers_loss_weights for _ in range(len(layer_by_name))]

    def _calculate_loss(
            self,
            layers_values: List[Tuple[str, torch.Tensor]],
            model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        # gathering extra information
        dl: DataLoader = DataloaderContext.instance.dataloader
        xray_y = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)
        infection_mask = dl.get_current_batch_data(keyword=self._inf_mask_kw).to(dl._device)

        xray_y = xray_y.bool()
        infection_mask = infection_mask[xray_y]
        PB, _, H, W = infection_mask.shape

        loss = torch.zeros([], dtype=torch.float32,
                           device=xray_y.device, requires_grad=True)
        for li, (ln, lo) in enumerate(layers_values):
            out = F.interpolate(lo,
                                size=(H, W),
                                mode='bilinear',
                                align_corners=True)
            out = out.flatten(start_dim=1)
            neg_out = out[~xray_y]
            pos_out = out[xray_y]
            NB = neg_out.shape[0]

            infection_mask = infection_mask.flatten(start_dim=1)
            pos_out_infection = torch.where(infection_mask < 1,
                                            torch.zeros_like(pos_out, requires_grad=True),
                                            pos_out)

            neg_losses = []
            pos_losses = []
            if PB > 0:
                bb_area = infection_mask.sum(dim=-1)  # B
                # a plain Python int, as expected by topk and repeat_interleave
                k = int((self._bb_fill_ratio * bb_area.quantile(q=0.5)).ceil().item())

                # Top k in-bb pixels of positive samples must be positive
                pos_infection_topk, pos_infection_indices = pos_out_infection\
                    .topk(k, dim=-1, sorted=False)
                pos_infection_batch_index = torch.arange(PB)\
                    .to(pos_infection_indices.device)\
                    .unsqueeze(1)\
                    .repeat_interleave(k, dim=1)
                pos_weight = infection_mask[pos_infection_batch_index, pos_infection_indices].floor()  # zero out weights of pixels whose mask value is below one
                pos_losses.append(F.binary_cross_entropy(
                    pos_infection_topk,
                    torch.ones_like(pos_infection_topk),
                    pos_weight
                ))

                if (infection_mask == 0).any():
                    # All out-of-bb pixels of positive samples must be negative
                    pos_out_non_infection = pos_out[infection_mask == 0]
                    neg_losses.append(F.binary_cross_entropy(
                        pos_out_non_infection,
                        torch.zeros_like(pos_out_non_infection)
                    ))
            if NB > 0:
                if PB > 0:
                    # Top negative pixels must be negative
                    neg_k = int(math.ceil(PB * k * 1.0 / NB))
                    neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out_topk,
                        torch.zeros_like(neg_out_topk)
                    ))
                else:
                    # All negative pixels must be negative
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out,
                        torch.zeros_like(neg_out)
                    ))

            losses = []
            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())
            l_loss = torch.stack(losses).mean()
            model_output[f'{self._title}_{ln}_loss'] = l_loss
            loss = loss + self._layers_loss_weights[li] * l_loss
        return loss
@@ -0,0 +1,152 @@
from typing import Dict, Union, List, Tuple

import numpy as np
import torch
from torch.nn import functional as F

from .aux_loss import AuxLoss
from ..data.dataloader_context import DataloaderContext
from ..data.data_loader import DataLoader

class PoolConcordanceLossCalculator(AuxLoss):
    def __init__(self,
                 title: str, model: torch.nn.Module,
                 layer_by_name: Dict[str, torch.nn.Module],
                 loss_weight: float, weights: Union[float, List[float]],
                 diff_thresholds: Union[float, List[float]],
                 labels_by_channel: Dict[int, List[int]]):
        super().__init__(title, model, layer_by_name, loss_weight)
        # at least two pools are needed
        assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'

        self._loss_prefix = f'{title}_JS_loss_'
        self._labels_by_channel = labels_by_channel
        self._pool_names = list(layer_by_name.keys())

        self._weights = weights if isinstance(weights, list) else \
            [weights for _ in range(len(layer_by_name) - 1)]
        if isinstance(weights, list):
            assert len(weights) == len(layer_by_name) - 1, 'Weights must have a length of pool_layers - 1'

        self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \
            [diff_thresholds for _ in range(len(layer_by_name) - 1)]
        if isinstance(diff_thresholds, list):
            assert len(diff_thresholds) == len(layer_by_name) - 1, 'Diff thresholds must have a length of pool_layers - 1'

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        # reading the pool outputs
        pool_vals = [x[1] for x in layers_values]

        # getting samples labels
        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = dl.get_current_batch_samples_labels()
        labels = torch.from_numpy(labels).to(dl._device)

        # sorting by shape from biggest to smallest
        p_shapes = np.asarray([p.shape[-1] for p in pool_vals])
        sorted_inds = np.argsort(-1 * p_shapes)
        pool_vals = [pool_vals[i] for i in sorted_inds]
        pool_names = [self._pool_names[i] for i in sorted_inds]

        loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

        # for each pair of consecutive pools, calculate the loss!
        for i in range(len(pool_names) - 1):
            p_loss = self._cal_pool_pair_loss(pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)
            assert (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') not in model_output, 'Trying to add ' + (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') + ' to model output multiple times'
            model_output[f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'] = p_loss.clone()
            loss = loss + self._weights[i] * p_loss
        return loss

    def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:
        # down-sampling p1 by max-pooling until it reaches the same shape as p2!
        if p1.shape[-1] > p2.shape[-1]:
            p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

        # divergence loss, for each channel -> on the related class!
        loss = torch.tensor(0.0, requires_grad=True, device=p1.device)
        for channel, r_labels in self._labels_by_channel.items():
            on_mask = self._get_inclusion_mask(labels, r_labels)
            ip1 = p1[on_mask, channel, ...]
            ip2 = p2[on_mask, channel, ...]
            if torch.numel(ip1) > 0:
                if ip1.shape != ip2.shape:
                    print(f'Problem in shape of concordance loss in {self._title}')
                loss = loss + jensen_shannon_divergence(
                    ip1[:, None, ...], ip2[:, None, ...], diff_threshold)
        return loss

    def _get_inclusion_mask(
            self,
            samples_labels: torch.Tensor,
            desired_labels: List[int]) -> torch.Tensor:
        with torch.no_grad():
            inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
            aggregation = torch.sum(inclusion_mask.float(), dim=0)
            return torch.greater(aggregation, 0)

def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
    """
    Calculates a symmetrised KL (Jensen-Shannon-style) divergence between the two distributions p1 and p2, restricted to the positions where they differ by at least diff_threshold.

    Args:
        p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1]; the probabilities for each neuron in the 1st distribution.
        p2 (torch.Tensor): A tensor of the same shape as p1, in range [0, 1]; the probabilities for each neuron in the 2nd distribution.
        diff_threshold (float): The minimum difference between p1 and p2 for a position to be considered in the loss.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert p1.shape == p2.shape, 'The tensors must have the same shape'
    assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'

    # Reshaping tensors so the class dimension comes last
    p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
    p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)

    # if binary, append the probability of class 0!
    if p1.shape[1] == 1:
        p1 = torch.cat([1 - p1, p1], dim=1)
        p2 = torch.cat([1 - p2, p2], dim=1)

    with torch.no_grad():
        mask = torch.abs(p1 - p2).detach() >= diff_threshold
        mask = mask.max(dim=-1)[0]

    # to make sure numerical error does not result in log(negative)!
    lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1)))
    lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2)))

    loss = 0.5 * (
        _smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) +
        _smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
    )
    return loss

def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # masked mean that safely returns a differentiable zero when the mask is empty
    if torch.any(mask.bool()):
        return torch.mean(val[mask])
    else:
        return torch.tensor(0.0, requires_grad=True, device=mask.device)
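
# A minimal wiring sketch (hypothetical layer names; the actual attention-pool
# modules depend on the model architecture):
#
# concordance = PoolConcordanceLossCalculator(
#     'concord', model,
#     layer_by_name={'pool2': model.pool2, 'pool3': model.pool3, 'pool4': model.pool4},
#     loss_weight=1.0, weights=0.5, diff_thresholds=0.1,
#     labels_by_channel={0: [1]})  # channel 0 is compared across samples with label 1
# concordance.configure(config)  # registers the hooks; the loss is added on each forward pass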
@@ -0,0 +1,221 @@ | |||
from typing import Dict, List, Tuple, Optional | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from ..data.data_loader import DataLoader | |||
from ..data.dataloader_context import DataloaderContext | |||
from .aux_loss import AuxLoss | |||
class DiscriminativeWeaklySupervisedLoss(AuxLoss): | |||
def __init__( | |||
self, title: str, model: nn.Module, | |||
att_score_layer: nn.Module, | |||
loss_weight: float, | |||
neg_ratio: float, pos_ratio_range: Tuple[float, float], | |||
on_labels_by_channel: Dict[int, List[int]], | |||
discr_score_layer: nn.Module = None, | |||
w_attention_in_ordering: float = 1, | |||
w_discr_in_ordering: float = 1): | |||
""" | |||
Calculates binary weakly supervised score for an attention layer | |||
with extra discrimination head. | |||
Args: | |||
title (str): The title of the loss, must be unique! | |||
model (Model): the base model, so the output can be modified | |||
att_score_layer (torch.nn.Module): A layer that gives attention score (B C ...) | |||
loss_weight (float): The weight of the loss | |||
neg_ratio (float, optional): top ratio to apply loss for negative samples. Defaults to 0.1. | |||
pos_ratio_range (Tuple[float, float], optional): low and top ratios to apply loss for positive samples. Defaults to (0.033, 0.278). Calculated by distribution of positive bounding boxes. | |||
on_labels_by_channel (Dict[int, List[int]]): The dictionary that specifies the samples related to which labels should be on in each channel. | |||
w_attention_in_ordering (float): The weight of the attention score used in ordering the pixels. | |||
w_discr_in_ordering (float): The weight of the reference score used in ordering the pixels. | |||
discr_score_layer (torch.nn.Module): A layer that gives discriminative score (B C ...) | |||
""" | |||
layers = dict( | |||
att=att_score_layer, | |||
) | |||
if discr_score_layer is not None: | |||
layers['discr'] = discr_score_layer | |||
super().__init__(title, model, layers, loss_weight) | |||
self._has_discr = discr_score_layer is not None | |||
self._neg_ratio = neg_ratio | |||
self._pos_ratio_range = pos_ratio_range | |||
self._on_labels_by_channel = on_labels_by_channel | |||
self._w_attention_in_ordering = w_attention_in_ordering | |||
self._w_discr_in_ordering = w_discr_in_ordering | |||
def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor: | |||
discriminative_scores = None | |||
probabilities = None | |||
for ln, lv in layers_values: | |||
if ln == 'att': | |||
probabilities = lv | |||
else: | |||
discriminative_scores = lv | |||
dl: DataLoader = DataloaderContext.instance.dataloader | |||
labels = dl.get_current_batch_samples_labels() | |||
if discriminative_scores is not None: | |||
discrimination_loss = self._calculate_discrimination_loss( | |||
labels, discriminative_scores) | |||
assert (self._title + '_discr_loss') not in model_out, 'Trying to add ' + (self._title + '_discr_loss') + ' to model output multiple times' | |||
model_out[self._title + '_discr_loss'] = discrimination_loss.clone() | |||
else: | |||
discrimination_loss = torch.zeros([], requires_grad=True, device=labels.device) | |||
attention_loss = self._calculate_attention_loss( | |||
labels, probabilities, (discriminative_scores if discriminative_scores is not None else probabilities)) | |||
assert (self._title + '_ws_loss') not in model_out, 'Trying to add ' + (self._title + '_ws_loss') + ' to model output multiple times' | |||
model_out[self._title + '_ws_loss'] = attention_loss.clone() | |||
loss = self._loss_weight * (discrimination_loss + attention_loss) | |||
return loss | |||
def _calculate_discrimination_loss( | |||
self, | |||
samples_labels: torch.Tensor, | |||
discrimination_scores: torch.Tensor) -> torch.Tensor: | |||
losses = [] | |||
for channel, labels in self._on_labels_by_channel.items(): | |||
on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device) | |||
on_ps = discrimination_scores[on_mask, channel, ...] | |||
off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...] | |||
if torch.numel(on_ps) > 0: | |||
losses.append(self._cal_loss(1, True, on_ps)) | |||
if torch.numel(off_ps) > 0: | |||
losses.append(self._cal_loss(1, False, off_ps)) | |||
return torch.mean(torch.stack(losses)) | |||
def _calculate_attention_loss( | |||
self, | |||
samples_labels: torch.Tensor, | |||
attention_scores: torch.Tensor, | |||
discrimination_scores: torch.Tensor) -> torch.Tensor: | |||
losses = [] | |||
for channel, labels in self._on_labels_by_channel.items(): | |||
on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device) | |||
on_atts = attention_scores[on_mask, channel, ...] | |||
on_discr = discrimination_scores[on_mask, channel, ...].detach() | |||
off_atts = attention_scores[torch.logical_not(on_mask), channel, ...] | |||
off_discr = discrimination_scores[torch.logical_not(on_mask), channel, ...].detach() | |||
neg_losses = [] | |||
pos_losses = [] | |||
            # negative samples: push their top-ranked pixels toward zero
            if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
                neg_losses.append(self._cal_loss(
                    self._neg_ratio, False, off_atts, off_discr, largest=True
                ))
if torch.numel(on_atts) > 0: | |||
# Calculate positive top k to be positive | |||
if self._pos_ratio_range[0] > 0: | |||
pos_losses.append(self._cal_loss( | |||
self._pos_ratio_range[0], True, on_atts, on_discr, True | |||
)) | |||
# Calculate positive bottom k to be negative | |||
if self._pos_ratio_range[1] < 1: | |||
neg_losses.append(self._cal_loss( | |||
1 - self._pos_ratio_range[1], False, on_atts, on_discr, False | |||
)) | |||
if len(neg_losses) > 0: | |||
losses.append(torch.stack(neg_losses).mean()) | |||
if len(pos_losses) > 0: | |||
losses.append(torch.stack(pos_losses).mean()) | |||
return torch.stack(losses).mean() | |||
def _get_inclusion_mask( | |||
self, | |||
samples_labels: np.ndarray, | |||
desired_labels: List[int], device: torch.device) -> torch.Tensor: | |||
with torch.no_grad(): | |||
samples_labels = torch.from_numpy(samples_labels).to(device) | |||
inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0) | |||
aggregation = torch.sum(inclusion_mask.float(), dim=0) | |||
return torch.greater(aggregation, 0) | |||
def _cal_loss( | |||
self, | |||
ratio: float, positive_label: bool, | |||
att_scores: torch.Tensor, | |||
discr_scores: Optional[torch.Tensor] = None, | |||
largest: bool = True): | |||
if ratio == 1: | |||
ps = att_scores | |||
else: | |||
k = np.ceil( | |||
ratio * att_scores.shape[-1] * att_scores.shape[-2]).astype(int) | |||
ps = self._get_topk(att_scores, discr_scores, k, largest=largest) | |||
ps = ps.flatten() | |||
if positive_label: | |||
gt = torch.ones_like(ps) | |||
else: | |||
gt = torch.zeros_like(ps) | |||
return F.binary_cross_entropy(ps, gt) | |||
def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor, | |||
k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor: | |||
scores = self._pixels_scores(att_scores, discr_scores) | |||
b = att_scores.shape[0] | |||
top_inds = (scores.flatten(1)).topk(k, dim=dim, largest=largest, sorted=False).indices | |||
# B K | |||
ret_val = att_scores.flatten(1)[ | |||
torch.repeat_interleave( | |||
torch.arange(b, device=att_scores.device), k).reshape(b, k), | |||
top_inds] # B K | |||
if not return_indices: | |||
return ret_val | |||
else: | |||
return ret_val, top_inds | |||
def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: torch.Tensor) -> torch.Tensor: | |||
return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores | |||
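# A minimal sketch (illustration only, not part of the training pipeline) of the
# ordering logic above: pixels are ranked by a weighted combination of attention
# and discrimination scores, and the attention values of the top-k pixels are
# then pushed toward the desired label with binary cross-entropy.
if __name__ == '__main__':
    _att = torch.rand(2, 5, 5)            # B H W attention scores in (0, 1)
    _discr = torch.rand(2, 5, 5)          # B H W discrimination scores
    _scores = 0.2 * _att + 1.0 * _discr   # weighted ordering scores
    _k = int(np.ceil(0.1 * 5 * 5))        # top 10% of the pixels
    _inds = _scores.flatten(1).topk(_k, dim=-1).indices   # B K
    _top_att = _att.flatten(1).gather(1, _inds)           # B K
    # negative-sample case: the selected attention scores should be ~0
    print(F.binary_cross_entropy(_top_att, torch.zeros_like(_top_att)))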
@@ -0,0 +1,61 @@ | |||
""" | |||
The BatchChooser | |||
""" | |||
from abc import abstractmethod | |||
from typing import List, TYPE_CHECKING | |||
import numpy as np | |||
if TYPE_CHECKING: | |||
from ..data_loader import DataLoader | |||
class BatchChooser: | |||
""" | |||
The BatchChooser | |||
""" | |||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||
""" Receives as input a data_loader which contains the information about the samples and the desired batch size. """ | |||
self.data_loader = data_loader | |||
self.batch_size = batch_size | |||
self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int) | |||
self.completed_iteration = False | |||
self.class_samples_indices = self.data_loader.get_class_sample_indices() | |||
self._classes_labels: List[int] = list(self.class_samples_indices.keys()) | |||
def prepare_next_batch(self): | |||
""" | |||
Prepares the next batch based on its strategy | |||
""" | |||
if self.finished_iteration(): | |||
self.reset() | |||
self.current_batch_sample_indices = \ | |||
self.get_next_batch_sample_indices().astype(int) | |||
if len(self.current_batch_sample_indices) == 0: | |||
self.completed_iteration = True | |||
return | |||
@abstractmethod | |||
def get_next_batch_sample_indices(self) -> np.ndarray: | |||
pass | |||
def get_current_batch_sample_indices(self) -> np.ndarray: | |||
""" Returns a list of indices of the samples chosen for the current batch. """ | |||
return self.current_batch_sample_indices | |||
def finished_iteration(self): | |||
""" Returns True if iteration is finished over all the slices of all the samples, False otherwise""" | |||
return self.completed_iteration | |||
def reset(self): | |||
""" Resets the sample iterator. """ | |||
self.completed_iteration = False | |||
self.current_batch_sample_indices = np.asarray([], dtype=int) | |||
def get_current_batch_size(self) -> int: | |||
return len(self.current_batch_sample_indices) |
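# A minimal, hypothetical chooser sketch (illustration only, not part of the
# project) showing the contract: subclasses only implement
# get_next_batch_sample_indices. This one draws a uniformly random batch each
# time; since it never returns an empty batch, iteration never finishes on its own.
class _RandomBatchChooser(BatchChooser):
    def get_next_batch_sample_indices(self) -> np.ndarray:
        n = self.data_loader.get_number_of_samples()
        return np.random.choice(n, size=min(self.batch_size, n), replace=False)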
@@ -0,0 +1,86 @@ | |||
from typing import TYPE_CHECKING | |||
from .batch_chooser import BatchChooser | |||
import numpy as np | |||
if TYPE_CHECKING: | |||
from ..data_loader import DataLoader | |||
class ClassBalancedShuffledSequentialBatchChooser(BatchChooser): | |||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||
super().__init__(data_loader, batch_size) | |||
self._cursor = 0 | |||
        # copy the index lists so shuffling does not mutate the data loader's shared state
for k, v in self.class_samples_indices.items(): | |||
self.class_samples_indices[k] = np.copy(v) | |||
np.random.shuffle(self.class_samples_indices[k]) | |||
self.classes_cursors = np.zeros(len(self.class_samples_indices), dtype=int) | |||
def get_next_batch_sample_indices(self): | |||
""" Returns a list containing the indices of the samples chosen for the next batch.""" | |||
n_samples_per_class = self._get_n_samples_per_class() | |||
def sample_for_each_class(class_index): | |||
class_samples = self.class_samples_indices[ | |||
self._classes_labels[class_index]] | |||
if self.classes_cursors[class_index] + \ | |||
n_samples_per_class[class_index] <= \ | |||
len(class_samples): | |||
ret_val = np.copy(class_samples[ | |||
self.classes_cursors[class_index]: | |||
self.classes_cursors[class_index] + n_samples_per_class[class_index] | |||
]) | |||
self.classes_cursors[class_index] += n_samples_per_class[class_index] | |||
else: | |||
ret_val = np.copy(class_samples)[ | |||
self.classes_cursors[class_index]:] | |||
np.random.shuffle(class_samples) | |||
self.classes_cursors[class_index] = \ | |||
n_samples_per_class[class_index] - len(ret_val) | |||
ret_val = np.concatenate(( | |||
ret_val, | |||
np.copy(class_samples | |||
[: n_samples_per_class[class_index] - len(ret_val)] | |||
)), axis=0) | |||
if self.classes_cursors[class_index] == len(class_samples): | |||
np.random.shuffle(class_samples) | |||
self.classes_cursors[class_index] = 0 | |||
return ret_val | |||
chosen_sample_indices = np.concatenate(tuple([ | |||
sample_for_each_class(ci) for ci in range(len(self._classes_labels)) | |||
]), axis=0) | |||
return chosen_sample_indices | |||
def _get_n_samples_per_class(self): | |||
# determining the number of samples per class | |||
n_samples_per_class = np.full(len(self.class_samples_indices), | |||
int(self.batch_size // len(self.class_samples_indices))) | |||
        remaining_samples_cnt = self.batch_size - np.sum(n_samples_per_class)
        if remaining_samples_cnt:
            if self._cursor + remaining_samples_cnt >= len(self._classes_labels):
                n_samples_per_class[self._cursor: len(self._classes_labels)] += 1
                remaining_samples_cnt -= (len(self._classes_labels) - self._cursor)
                self._cursor = 0
            if remaining_samples_cnt > 0:
                n_samples_per_class[self._cursor: self._cursor + remaining_samples_cnt] += 1
                self._cursor += remaining_samples_cnt
return n_samples_per_class | |||
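    # Worked example (hypothetical numbers): with batch_size=10 and 3 classes,
    # each class first gets 10 // 3 == 3 samples; the 1 remaining slot goes to
    # the class at the rotating cursor, so consecutive batches get per-class
    # counts [4, 3, 3], then [3, 4, 3], then [3, 3, 4], and so on.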
@@ -0,0 +1,38 @@ | |||
from typing import TYPE_CHECKING | |||
from .batch_chooser import BatchChooser | |||
import numpy as np | |||
if TYPE_CHECKING: | |||
from ..data_loader import DataLoader | |||
class SequentialBatchChooser(BatchChooser): | |||
def __init__(self, data_loader: 'DataLoader', batch_size: int): | |||
super().__init__(data_loader, batch_size) | |||
self.cursor = 0 | |||
self.desired_samples_inds = None | |||
def reset(self): | |||
super(SequentialBatchChooser, self).reset() | |||
self.cursor = 0 | |||
self.desired_samples_inds = None | |||
def get_next_batch_sample_indices(self): | |||
""" Returns a list containing the indices of the samples chosen for the next batch.""" | |||
if self.desired_samples_inds is None: | |||
self.desired_samples_inds = \ | |||
np.arange(self.data_loader.get_number_of_samples()) | |||
next_cursor = min( | |||
len(self.desired_samples_inds), | |||
self.cursor + self.batch_size) | |||
next_sample_inds = self.desired_samples_inds[ | |||
self.cursor:next_cursor] | |||
self.cursor = next_cursor | |||
return next_sample_inds | |||
@@ -0,0 +1,154 @@ | |||
from enum import Enum | |||
import os | |||
import glob | |||
from typing import Callable, TYPE_CHECKING, Union | |||
import imageio | |||
import cv2 | |||
import numpy as np | |||
import pandas as pd | |||
from .content_loader import ContentLoader | |||
if TYPE_CHECKING: | |||
from ...configs.celeba_configs import CelebAConfigs | |||
class CelebATag(Enum): | |||
FiveOClockShadowTag = '5_o_Clock_Shadow' | |||
ArchedEyebrowsTag = 'Arched_Eyebrows' | |||
AttractiveTag = 'Attractive' | |||
BagsUnderEyesTag = 'Bags_Under_Eyes' | |||
BaldTag = 'Bald' | |||
BangsTag = 'Bangs' | |||
BigLipsTag = 'Big_Lips' | |||
BigNoseTag = 'Big_Nose' | |||
BlackHairTag = 'Black_Hair' | |||
BlondHairTag = 'Blond_Hair' | |||
BlurryTag = 'Blurry' | |||
BrownHairTag = 'Brown_Hair' | |||
BushyEyebrowsTag = 'Bushy_Eyebrows' | |||
ChubbyTag = 'Chubby' | |||
DoubleChinTag = 'Double_Chin' | |||
EyeglassesTag = 'Eyeglasses' | |||
GoateeTag = 'Goatee' | |||
GrayHairTag = 'Gray_Hair' | |||
HighCheekbonesTag = 'High_Cheekbones' | |||
MaleTag = 'Male' | |||
MouthSlightlyOpenTag = 'Mouth_Slightly_Open' | |||
MustacheTag = 'Mustache' | |||
NarrowEyesTag = 'Narrow_Eyes' | |||
NoBeardTag = 'No_Beard' | |||
OvalFaceTag = 'Oval_Face' | |||
PaleSkinTag = 'Pale_Skin' | |||
PointyNoseTag = 'Pointy_Nose' | |||
RecedingHairlineTag = 'Receding_Hairline' | |||
RosyCheeksTag = 'Rosy_Cheeks' | |||
SideburnsTag = 'Sideburns' | |||
SmilingTag = 'Smiling' | |||
StraightHairTag = 'Straight_Hair' | |||
WavyHairTag = 'Wavy_Hair' | |||
WearingEarringsTag = 'Wearing_Earrings' | |||
WearingHatTag = 'Wearing_Hat' | |||
WearingLipstickTag = 'Wearing_Lipstick' | |||
WearingNecklaceTag = 'Wearing_Necklace' | |||
WearingNecktieTag = 'Wearing_Necktie' | |||
YoungTag = 'Young' | |||
BoundingBoxX = 'x_1' | |||
BoundingBoxY = 'y_1' | |||
BoundingBoxW = 'width' | |||
BoundingBoxH = 'height' | |||
Partition = 'partition' | |||
LeftEyeX = 'lefteye_x' | |||
LeftEyeY = 'lefteye_y' | |||
RightEyeX = 'righteye_x' | |||
RightEyeY = 'righteye_y' | |||
NoseX = 'nose_x' | |||
NoseY = 'nose_y' | |||
LeftMouthX = 'leftmouth_x' | |||
LeftMouthY = 'leftmouth_y' | |||
RightMouthX = 'rightmouth_x' | |||
RightMouthY = 'rightmouth_y' | |||
class SampleField(Enum): | |||
ImageId = 'image_id' | |||
Specification = 'specification' | |||
class CelebALoader(ContentLoader): | |||
def __init__(self, conf: 'CelebAConfigs', data_specification: str): | |||
""" read all directories for scans and annotations, which are split by DataSplitter.py | |||
And then, keeps only those samples which are used for 'usage' """ | |||
super().__init__(conf, data_specification) | |||
self.conf = conf | |||
self._datasep = True | |||
self._samples_metadata = self._load_metadata(data_specification) | |||
self._data_root = conf.data_root | |||
self._warning_count = 0 | |||
def _load_metadata(self, data_specification: str) -> pd.DataFrame: | |||
if '/' in data_specification: | |||
self._datasep = False | |||
return pd.DataFrame({ | |||
SampleField.ImageId.value: glob.glob(os.path.join(data_specification, '*.jpg')) | |||
}) | |||
metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t') | |||
metadata = metadata[metadata[SampleField.Specification.value] == data_specification] | |||
metadata = metadata.drop(SampleField.Specification.value, axis=1) | |||
metadata = metadata.reset_index().drop('index', axis=1) | |||
return metadata | |||
def get_samples_names(self): | |||
return self._samples_metadata[SampleField.ImageId.value].values | |||
def get_samples_labels(self): | |||
r""" Dummy. Because we have multiple labels """ | |||
return np.ones((len(self._samples_metadata),))\ | |||
if self.conf.main_tag is None\ | |||
else self._samples_metadata[self.conf.main_tag.value].values | |||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||
self._samples_metadata = self._samples_metadata[np.logical_not(drop_mask)] | |||
def get_placeholder_name_to_fill_function_dict(self): | |||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||
to the functions used for filling them. The functions must receive as input data_loader, | |||
which is an object of class data_loader that contains information about the current batch | |||
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen | |||
elements) and return an array per placeholder name according to the receives batch information. | |||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||
they belong to! Some sort of having a mark :))!""" | |||
return { | |||
'x': self._get_x, | |||
**{ | |||
tag.name: self._generate_tag_getter(tag) for tag in CelebATag | |||
} | |||
} | |||
def _get_x(self, samples_inds: np.ndarray)\ | |||
-> np.ndarray: | |||
images = tuple(self._read_image(index) for index in samples_inds) | |||
images = np.stack(images, axis=0) # B I I 3 | |||
images = images.transpose((0, 3, 1, 2)) # B 3 I I | |||
images = images.astype(float) / 255. | |||
return images | |||
def _read_image(self, sample_index: int) -> np.ndarray: | |||
filepath = self._samples_metadata.iloc[sample_index][SampleField.ImageId.value] | |||
if self._datasep: | |||
filepath = os.path.join(self._data_root, filepath) | |||
image = imageio.imread(filepath) # H W 3 | |||
image = cv2.resize(image, (self.conf.input_size, self.conf.input_size)) # I I 3 | |||
return image | |||
def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]: | |||
def get_tag(samples_inds: np.ndarray) -> np.ndarray: | |||
return self._samples_metadata.iloc[samples_inds][tag.value].values if self._datasep else np.zeros(len(samples_inds)) | |||
return get_tag |
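# Assumed (hypothetical) layout of conf.dataset_metadata, a tab-separated file
# with one row per image and one column per CelebATag value:
#
#   image_id    specification  Smiling  Male  ...
#   000001.jpg  train          1        0     ...
#
# When data_specification contains a '/', it is treated as a directory and the
# *.jpg files inside it are listed directly instead of reading the metadata.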
@@ -0,0 +1,94 @@ | |||
""" | |||
A class for loading one content type needed for the run. | |||
""" | |||
from abc import abstractmethod, ABC | |||
from typing import TYPE_CHECKING, Dict, Callable, Union, List | |||
import numpy as np | |||
if TYPE_CHECKING: | |||
from .. import BaseConfig | |||
LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray] | |||
class ContentLoader(ABC): | |||
""" A class for loading one content type needed for the run. """ | |||
def __init__(self, conf: 'BaseConfig', data_specification: str): | |||
""" Receives as input conf which is a dictionary containing configurations. | |||
This dictionary can also be used to pass fixed addresses for easing the usage! | |||
prefix_name is the str which all the variables that must be filled with this class have | |||
this prefix in their names so it would be clear that this class should fill it. | |||
data_specification is the string that specifies data and where it is! e.g. train, test, val""" | |||
self.conf = conf | |||
self.data_specification = data_specification | |||
self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None | |||
@abstractmethod | |||
def get_samples_names(self): | |||
""" Returns a list containing names of all the samples of the content loader, | |||
each sample must owns a unique ID, and this function returns all this IDs. | |||
The order of the list must always be the same during one run. | |||
For example, this function can return an ID column of a table for TableLoader | |||
or the dir of images as ID for ImageLoader""" | |||
@abstractmethod | |||
def get_samples_labels(self): | |||
""" Returns list of labels of the whole samples. | |||
The order of the list must always be the same during one run.""" | |||
@abstractmethod | |||
def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]: | |||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||
to the functions used for filling them. The functions must receive as input | |||
batch_samples_inds and batch_samples_elements_inds which defines the current batch, | |||
and return an array per placeholder name according to the receives batch information. | |||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||
they belong to! Some sort of having a mark :))!""" | |||
def fill_placeholders(self, | |||
keys: List[str], | |||
samples_inds: np.ndarray)\ | |||
-> Dict[str, Union[None, np.ndarray]]: | |||
""" Receives as input placeholders which is a dictionary of the placeholders' | |||
names to a torch tensor for filling it to feed the model with and | |||
samples_inds and samples_elements_inds, which contain | |||
information about the current batch. Fills placeholders based on the | |||
function dictionary received in get_placeholder_name_to_fill_function_dict.""" | |||
if self._fill_func_by_var_name is None: | |||
self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict() | |||
# Filling all the placeholders in the received dictionary! | |||
placeholders = dict() | |||
for placeholder_name in keys: | |||
if placeholder_name in self._fill_func_by_var_name: | |||
placeholder_v = self._fill_func_by_var_name[placeholder_name]( | |||
samples_inds) | |||
if placeholder_v is None: | |||
raise Exception('None value for key %s' % placeholder_name) | |||
placeholders[placeholder_name] = placeholder_v | |||
else: | |||
raise Exception(f'Unknown key for content loader: {placeholder_name}') | |||
return placeholders | |||
def get_batch_true_interpretations(self, samples_inds: np.ndarray, | |||
samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\ | |||
-> np.ndarray: | |||
""" | |||
Receives the indices of samples and their elements (if elemented) in the current batch, | |||
returns the true interpretations as expected (Bounding box or a boolean mask) | |||
""" | |||
raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.') | |||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||
""" | |||
Receives a boolean drop mask and drops the samples whose mask are true, | |||
From now on the indices of samples are based on the new samples set (after elimination) | |||
""" | |||
raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.') |
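# A minimal, hypothetical subclass sketch (illustration only, not used anywhere
# in the project) showing the contract: names and labels define the sample set,
# and each supported placeholder maps to a filler that receives the batch
# sample indices. The 'toy_' prefix follows the advice given above.
class _ToyContentLoader(ContentLoader):
    def __init__(self, conf: 'BaseConfig', data_specification: str):
        super().__init__(conf, data_specification)
        self._xs = np.random.rand(10, 3)   # 10 samples, 3 features
        self._ys = np.arange(10) % 2       # alternating binary labels
    def get_samples_names(self):
        return np.array(['sample_%d' % i for i in range(10)])
    def get_samples_labels(self):
        return self._ys
    def get_placeholder_name_to_fill_function_dict(self):
        return {'toy_x': lambda inds: self._xs[inds],
                'toy_y': lambda inds: self._ys[inds]}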
@@ -0,0 +1,146 @@ | |||
import os | |||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||
import pickle | |||
from PIL import Image | |||
from enum import Enum | |||
import pandas as pd | |||
import numpy as np | |||
import torch | |||
from torchvision import transforms | |||
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS | |||
from .content_loader import ContentLoader | |||
from ...utils.bb_generator import generate_bb_map | |||
if TYPE_CHECKING: | |||
from ...configs.imagenet_configs import ImagenetConfigs | |||
def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose: | |||
if data_specification == 'train': | |||
return transforms.Compose([ | |||
transforms.RandomResizedCrop(input_size), | |||
transforms.RandomHorizontalFlip(), | |||
transforms.ToTensor(), | |||
]) | |||
if data_specification in ['test', 'val']: | |||
return transforms.Compose([ | |||
transforms.Resize(256), | |||
transforms.CenterCrop(input_size), | |||
transforms.ToTensor(), | |||
]) | |||
raise ValueError('Unknown data specification: {}'.format(data_specification)) | |||
def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]: | |||
cache_path = os.path.join('.cache/', cache_name) | |||
    print('Looking for a samples cache at {}'.format(cache_path))
if os.path.isfile(cache_path): | |||
print('Loading cached samples from {}'.format(cache_path)) | |||
with open(cache_path, 'rb') as f: | |||
return pickle.load(f) | |||
print('Creating cache for {}'.format(cache_name)) | |||
os.makedirs(os.path.dirname(cache_path), exist_ok=True) | |||
_, class_to_idx = find_classes(imagenet_dir) | |||
samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS) | |||
with open(cache_path, 'wb') as f: | |||
pickle.dump(samples, f) | |||
return samples | |||
class BBoxField(Enum): | |||
ImageId = 'ImageId' | |||
PredictionString = 'PredictionString' | |||
def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame: | |||
return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv')) | |||
class ImagenetLoader(ContentLoader): | |||
def __init__(self, conf: 'ImagenetConfigs', data_specification: str): | |||
super().__init__(conf, data_specification) | |||
imagenet_dir = os.path.join(conf.data_separation, data_specification) | |||
self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir) | |||
self.__transform = _get_image_transform(data_specification, conf.input_size) | |||
self.__bboxes = load_bbox_df(conf.data_separation, data_specification) | |||
self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {} | |||
def get_samples_names(self): | |||
        ''' Sample names must be unique; they can be either scan names or scan paths.
        The image paths are used here as IDs, since there is no practical difference.'''
return np.array([path for path, _ in self.__samples]) | |||
def get_samples_labels(self): | |||
return np.array([label for _, label in self.__samples]) | |||
def drop_samples(self, drop_mask: np.ndarray) -> None: | |||
self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop] | |||
def get_placeholder_name_to_fill_function_dict(self): | |||
""" Returns a dictionary of the placeholders' names (the ones this content loader supports) | |||
to the functions used for filling them. The functions must receive as input data_loader, | |||
which is an object of class data_loader that contains information about the current batch | |||
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen | |||
elements) and return an array per placeholder name according to the receives batch information. | |||
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader | |||
they belong to! Some sort of having a mark :))!""" | |||
return { | |||
'x': self.__get_x, | |||
'y': self.__get_y, | |||
'bbox': self.__get_bbox, | |||
'interpretations': self.__get_bbox, | |||
} | |||
def __get_x(self, samples_inds: np.ndarray)\ | |||
-> np.ndarray: | |||
def get_image(index) -> torch.Tensor: | |||
path, _ = self.__samples[index] | |||
sample = default_loader(path) | |||
self.__handle_random_state(index, 'x') | |||
return self.__transform(sample) | |||
return torch.stack( | |||
tuple([get_image(index) for index in samples_inds]), | |||
dim=0) | |||
def __get_y(self, samples_inds: np.ndarray)\ | |||
-> np.ndarray: | |||
return self.get_samples_labels()[samples_inds] | |||
def __handle_random_state(self, idx: int, label: str) -> None: | |||
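        # Save the RNG state the first time an index is seen (or when the same
        # label is requested again, e.g. in a new epoch), and restore it when a
        # different label of the same index is transformed, so that paired
        # placeholders (e.g. the image and its bounding-box map) receive
        # identical random augmentations.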
if idx not in self.__rng_states or self.__rng_states[idx][0] == label: | |||
self.__rng_states[idx] = (label, torch.get_rng_state()) | |||
else: | |||
torch.set_rng_state(self.__rng_states[idx][1]) | |||
def __get_bbox(self, sample_inds: np.ndarray)\ | |||
-> np.ndarray: | |||
def extract_bb(prediction_string: str) -> np.ndarray: | |||
splitted = prediction_string.split() | |||
n = len(splitted) // 5 | |||
return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\ | |||
.reshape((-1, 2, 2)) | |||
def make_bb(index) -> torch.Tensor: | |||
path, _ = self.__samples[index] | |||
image_id = os.path.basename(path).split('.')[0] | |||
image_size = np.array(Image.open(path).size)[::-1] # 2 | |||
bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value] | |||
if len(bboxes) == 0: | |||
                return torch.zeros(1, self.conf.input_size, self.conf.input_size) * np.nan
bboxes = bboxes.values[0] | |||
bboxes = extract_bb(bboxes)[..., ::-1] / image_size # N 2 2 | |||
start_points = bboxes[:, 0] # N 2 | |||
end_points = bboxes[:, 1] # N 2 | |||
bb_map = generate_bb_map(start_points, end_points, tuple(image_size)) | |||
bb_map = Image.fromarray(bb_map) | |||
self.__handle_random_state(index, 'bb') | |||
return self.__transform(bb_map) | |||
return torch.stack([make_bb(i) for i in sample_inds], dim=0) |
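# Worked example of the PredictionString format assumed above: a class label
# followed by x_min y_min x_max y_max, repeated once per box, e.g.
#
#   'n01440764 10 20 110 220 n01440764 30 40 130 240'
#
# extract_bb drops the labels and returns an array of shape (N, 2, 2) holding
# [[x_min, y_min], [x_max, y_max]] per box; make_bb then flips it to
# (row, col) order and normalizes by the image size.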
@@ -0,0 +1,329 @@ | |||
from typing import Dict, List, TYPE_CHECKING, Optional, Tuple | |||
from os import path, listdir | |||
from sys import stderr | |||
from functools import partial | |||
import cv2 | |||
import numpy as np | |||
import pandas as pd | |||
from .content_loader import ContentLoader | |||
from ...utils.bb_generator import generate_bb_map | |||
if TYPE_CHECKING: | |||
from ...configs.rsna_configs import RSNAConfigs | |||
class RSNALoader(ContentLoader): | |||
def __init__(self, conf: 'RSNAConfigs', data_specification: str): | |||
""" read all directories for scans and annotations, which are split by DataSplitter.py | |||
And then, keeps only those samples which are used for 'usage' """ | |||
super().__init__(conf, data_specification) | |||
self.conf = conf | |||
self.img_size = conf.input_size | |||
self.infection_map_size = conf.infection_map_size | |||
self.samples_info = None | |||
self.load_samples(data_specification) | |||
self.loaded_images = None | |||
        self._infections_bbs: Optional[Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]] = None
self._load_infections() | |||
def load_samples(self, data_specification): | |||
if ':' in data_specification: | |||
data_specification, filter_str = data_specification.split(':') | |||
else: | |||
filter_str = None | |||
if data_specification in ['train', 'val', 'test']: | |||
self.samples_info = \ | |||
pd.read_csv('%s/%s.txt' % | |||
( | |||
self.conf.data_separation, data_specification), sep='\t', header=0) | |||
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):]) | |||
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):]) | |||
elif path.isfile(data_specification) and path.exists(data_specification): | |||
self.samples_info = pd.read_csv(data_specification, sep='\t', header=0) | |||
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):]) | |||
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):]) | |||
elif path.isdir(data_specification) and path.exists(data_specification): | |||
samples_names = np.asarray( | |||
[data_specification + '/' + x for x in listdir(data_specification)]) | |||
print('%d samples discovered in %s' % (len(samples_names), data_specification)) | |||
if filter_str is not None: | |||
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x) | |||
samples_names = samples_names[pass_filter_str(samples_names)] | |||
print('%d samples remained after filtering' % len(samples_names)) | |||
self.samples_info = pd.DataFrame({ | |||
'label': np.full(len(samples_names), -1, dtype=int), | |||
'sample': samples_names, | |||
                'view': np.full(len(samples_names), 'Unknown', dtype=object),
                'dataset': np.full(len(samples_names), 'unknown', dtype=object),
                'Interpretation_dir': np.full(len(samples_names), '', dtype=object),
}) | |||
return | |||
        else:
            # without a recognized specification, samples_info stays None and later code would crash
            raise NotImplementedError('Please implement this part!')
org_cnt = len(self.samples_info) | |||
# filtering by filter string! | |||
if filter_str is not None: | |||
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x) | |||
self.samples_info = self.samples_info[pass_filter_str(self.samples_info['sample'].values)] | |||
print( | |||
'%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification), | |||
flush=True) | |||
org_cnt = len(self.samples_info) | |||
# label mapping/ordering! | |||
label_mapping_dict = self.conf.label_map_dict | |||
v_map_label = np.vectorize(lambda l: label_mapping_dict.get(l.lower(), -1)) | |||
self.samples_info['label'] = v_map_label(self.samples_info['label'].values) | |||
# filtering unmapped labels | |||
self.samples_info = self.samples_info[self.samples_info['label'].values != -1] | |||
print('%d out of %d samples remained after label mapping!' % | |||
(len(self.samples_info), org_cnt), flush=True) | |||
# counting the labels! | |||
flattened_labels = self.samples_info['label'].values | |||
u_labels = np.unique(flattened_labels) | |||
labels_cnt = np.zeros(len(u_labels)) | |||
np.add.at(labels_cnt, np.searchsorted(u_labels, flattened_labels, side='left'), 1) | |||
print('[%s]' % ', '.join(['class-%s: %d' % (str(u_labels[i]), labels_cnt[i]) for i in range(len(labels_cnt))]), | |||
flush=True) | |||
def _load_infections(self) -> None: | |||
""" | |||
If a file address is given as infections_bb_dir in configs | |||
(containing imgs and their bounding boxes as a str) | |||
the information related to bounding boxes will be taken, | |||
otherwise it is expected from the data separation to have a column | |||
indicating the address of the infection mask | |||
""" | |||
if self.conf.infections_bb_dir is not None: | |||
bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0) | |||
imgs_dirs = np.vectorize(self._get_id_by_path)(bbs_info['img_dir'].values) | |||
imgs_bbs = [ | |||
[] if '-' not in str(im_bbs) else [ | |||
tuple([np.clip(float(x), 0, 1) for x in im_bb.split('-')]) | |||
for im_bb in im_bbs.split(',') | |||
] | |||
for im_bbs in bbs_info['img_bb'].values] | |||
# reformatting to tuple of list of starts and list of ends | |||
imgs_bbs_starts = [[ | |||
(r1, c1) for r1, c1, _, _ in img_bbs] | |||
for img_bbs in imgs_bbs] | |||
imgs_bbs_ends = [[ | |||
(r2, c2) for _, _, r2, c2 in img_bbs] | |||
for img_bbs in imgs_bbs] | |||
imgs_bbs = [ | |||
(imgs_bbs_starts[i], imgs_bbs_ends[i]) | |||
for i in range(len(imgs_bbs_starts))] | |||
self._infections_bbs = dict(zip(imgs_dirs, imgs_bbs)) | |||
def get_samples_names(self): | |||
return self.samples_info['sample'].values | |||
def get_samples_labels(self): | |||
return self.samples_info['label'].values | |||
    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """ drop_mask is a boolean array; samples whose mask is True are dropped. """
self.samples_info = self.samples_info[np.logical_not(drop_mask)] | |||
def get_placeholder_name_to_fill_function_dict(self): | |||
ret_val = { | |||
'x': self.get_batch_scaled_images, | |||
'y': self.get_batch_label, | |||
'infection': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if | |||
self._infections_bbs is not None else | |||
self.get_batch_bbs_map_from_mask_file, | |||
'interpretations': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if | |||
self._infections_bbs is not None else | |||
self.get_batch_bbs_map_from_mask_file, | |||
'has_bb': self.get_batch_has_bb_mask_from_bb_file if | |||
self._infections_bbs is not None else | |||
self.get_batch_has_bb_mask_from_mask_file | |||
} | |||
        # if receptive-field radii are specified, bounding boxes must be provided as a file
if len(self.conf.receptive_field_radii) > 0: | |||
assert self._infections_bbs is not None, \ | |||
"When having receptive field radii, " \ | |||
"you should provide bounding boxes as a file in config.infections_bb_dir" | |||
ret_val.update({ | |||
f'ex_infection_{r}': partial(self.get_batch_extended_bbs_map_from_bbs_file, r) | |||
for r in self.conf.receptive_field_radii | |||
}) | |||
return ret_val | |||
def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \ | |||
-> np.ndarray: | |||
assert 'Interpretation_dir' in self.samples_info.columns, \ | |||
            'If you have not specified infections_bb_dir in your config, ' \
            'you need to have an Interpretation_dir column in your data separation'
def has_inf(si): | |||
if self.samples_info.iloc[si]['label'] == 0: | |||
return True | |||
int_dir = self.samples_info.iloc[si]['Interpretation_dir'] | |||
return str(int_dir) != 'nan' | |||
return np.vectorize(has_inf)(samples_inds) | |||
def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \ | |||
-> np.ndarray: | |||
def has_inf(si): | |||
if self.samples_info.iloc[si]['label'] == 0: | |||
return True | |||
sample_key = self.samples_info.iloc[si]['sample'] | |||
sample_key = self._get_id_by_path(sample_key) | |||
return sample_key in self._infections_bbs and \ | |||
len(self._infections_bbs[sample_key][0]) > 0 | |||
return np.vectorize(has_inf)(samples_inds) | |||
def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray)\ | |||
-> np.ndarray: | |||
def read_interpretation(im_ind): | |||
# for healthy, return full 0 | |||
            if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)
            int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir']
            # if it does not exist, return full -1
            if str(int_dir) == 'nan':
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)
if 'npy' in int_dir: | |||
interpretation = np.load(int_dir) | |||
else: | |||
interpretation = np.load(int_dir)['arr_0'] | |||
interpretation = interpretation.astype(float) | |||
if interpretation.shape != (self.infection_map_size, self.infection_map_size): | |||
interpretation = (cv2.resize(np.round(255 * interpretation, 0), | |||
dsize=(self.infection_map_size, self.infection_map_size)) >= 128).astype(float) | |||
return interpretation[np.newaxis, :, :] | |||
batch_interpretations = np.stack(tuple([read_interpretation(si) | |||
for si in samples_inds]), axis=0) | |||
return batch_interpretations | |||
    @staticmethod
    def _get_id_by_path(file_path: str) -> str:
        # named file_path to avoid shadowing os.path, imported as `path` in this module
        return file_path[file_path.index('png_images/'):]
def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray: | |||
def make_map(im_ind): | |||
            if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)
sample_key = self.samples_info.iloc[im_ind]['sample'] | |||
sample_key = self._get_id_by_path(sample_key) | |||
if sample_key in self._infections_bbs and \ | |||
len(self._infections_bbs[sample_key][0]) > 0: | |||
bbs_info = self._infections_bbs[sample_key] | |||
start_points, end_points = bbs_info | |||
mask = generate_bb_map(start_points, end_points, | |||
(self.infection_map_size, self.infection_map_size), radius) | |||
return mask[np.newaxis, :, :] | |||
            else:
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)
batch_interpretations = np.stack(tuple([make_map(si) | |||
for si in samples_inds]), axis=0) | |||
return batch_interpretations | |||
def get_batch_scaled_images(self, | |||
samples_inds: np.ndarray) \ | |||
-> np.ndarray: | |||
def read_img(im_ind): | |||
im = cv2.imread(self.samples_info.iloc[im_ind]['sample']) | |||
if im is None: | |||
print(self.samples_info.iloc[im_ind]['sample'] + ' is missing!', file=stderr) | |||
raise Exception('Missing image') | |||
if len(list(im.shape)) == 3: | |||
ret_val = im[:, :, 0] | |||
else: | |||
ret_val = im | |||
if ret_val.shape != (self.img_size, self.img_size): | |||
ret_val = cv2.resize(ret_val, dsize=(self.img_size, self.img_size)) | |||
return ret_val[np.newaxis, :, :] | |||
batch_imgs = np.stack(tuple([read_img(si) | |||
for si in samples_inds]), axis=0) | |||
batch_imgs = batch_imgs.astype(np.float32) / 255 | |||
return batch_imgs | |||
def get_batch_label(self, | |||
samples_inds: np.ndarray, | |||
) \ | |||
-> np.ndarray: | |||
return self.samples_info['label'].iloc[samples_inds].values.astype(np.float32) | |||
def get_batch_true_interpretations(self, samples_inds: np.ndarray) \ | |||
-> np.ndarray: | |||
def read_interpretation(im_ind): | |||
# for healthy, return full 0 | |||
            if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.img_size, self.img_size), 0, dtype=float)
            int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir']
            if str(int_dir) == 'nan':
                return np.full((1, self.img_size, self.img_size), np.nan, dtype=float)
if 'npy' in int_dir: | |||
interpretation = np.load(int_dir) | |||
else: | |||
interpretation = np.load(int_dir)['arr_0'].astype(int) | |||
if interpretation.shape != (self.img_size, self.img_size): | |||
interpretation = (cv2.resize(np.round(255 * interpretation, 0).astype(np.uint8), | |||
dsize=(self.img_size, self.img_size)) >= 128).astype(float) | |||
return interpretation[np.newaxis, :, :] | |||
batch_interpretations = np.stack(tuple([read_interpretation(si) | |||
for si in samples_inds]), axis=0) | |||
return batch_interpretations |
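# Worked example (hypothetical values) of the bounding-box string parsed in
# _load_infections: each box is four '-'-separated fractions r1-c1-r2-c2 in
# [0, 1], and boxes are ','-separated:
#
#   img_bb = '0.10-0.20-0.50-0.60,0.30-0.30-0.70-0.80'
#
# parses to start points [(0.10, 0.20), (0.30, 0.30)] and end points
# [(0.50, 0.60), (0.70, 0.80)] for the corresponding image id.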
@@ -0,0 +1,246 @@ | |||
""" A class for loading the whole data needed for the run! """ | |||
from typing import Dict, List, Union, TYPE_CHECKING | |||
import random | |||
import numpy as np | |||
import torch | |||
from .batch_choosing.batch_chooser import BatchChooser | |||
from .content_loaders.content_loader import ContentLoader | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class RunType: | |||
TRAIN = 'train' | |||
VAL = 'val' | |||
TEST = 'test' | |||
class DataLoader(): | |||
""" A class for loading the whole data needed for the run! """ | |||
def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: RunType): | |||
""" Conf is a dictionary containing configurations, | |||
sample specification is a string specifying the samples e.g. address of the CTs | |||
run type is one of the strings train/val/test specifying the mode that the data_loader | |||
will be used.""" | |||
self.conf = conf | |||
self._device = conf.device | |||
self.data_specification = data_specification | |||
self.run_type = run_type | |||
self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification) | |||
self.samples_names: np.ndarray = self._content_loader.get_samples_names() | |||
self.samples_labels = self._content_loader.get_samples_labels() | |||
keep_mask = self.get_samples_keep_mask() | |||
if keep_mask is not None: | |||
drop_mask = np.logical_not(keep_mask) | |||
self._content_loader.drop_samples(drop_mask) | |||
self.samples_names = self.samples_names[keep_mask] | |||
self.samples_labels = self.samples_labels[keep_mask] | |||
self.class_samples_indices = dict() | |||
print('%d samples' % len(self.samples_names), flush=True) | |||
for i in range(len(self.samples_names)): | |||
if self.samples_labels[i] not in self.class_samples_indices: | |||
self.class_samples_indices[self.samples_labels[i]] = [] | |||
self.class_samples_indices[self.samples_labels[i]].append(i) | |||
for c in self.class_samples_indices.keys(): | |||
self.class_samples_indices[c] = np.asarray(self.class_samples_indices[c]) | |||
        # for augmentations
        # Only do augmentations in the training phase
self._different_augmentation_per_batch = conf.different_augmentation_per_batch | |||
self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None | |||
if run_type == RunType.TRAIN: | |||
self._augmentations_dict = conf.augmentations_dict | |||
# for preventing reprocessing | |||
self._processed_data_dictionary: Dict[str, torch.Tensor] = dict() | |||
# for preserving same augmentations in one batch | |||
self._transformations_seeds_dict: Dict[torch.nn.Module, Union[int, np.ndarray]] = dict() | |||
if run_type == RunType.TRAIN: | |||
self._batch_chooser: BatchChooser = self.conf.train_batch_chooser_cls(self, conf.batch_size) | |||
else: | |||
self._batch_chooser: BatchChooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size) | |||
def get_samples_names(self) -> np.ndarray: | |||
""" Returns the names of the samples based on the first content loader. | |||
IMPORTANT: These all must be the same for all of the content loaders.""" | |||
return self.samples_names | |||
def get_samples_labels(self): | |||
""" Returns the labels of the samples in a numpy array. """ | |||
return self.samples_labels | |||
def get_number_of_samples(self): | |||
""" Returns the number of samples loaded. """ | |||
return len(self.samples_names) | |||
def get_class_sample_indices(self): | |||
""" Returns a dictionary, containing lists of samples indices belonging to each class label.""" | |||
return self.class_samples_indices | |||
def prepare_next_batch(self): | |||
# Resetting information | |||
self._processed_data_dictionary = dict() | |||
self._transformations_seeds_dict = dict() | |||
self._batch_chooser.prepare_next_batch() | |||
def finished_iteration(self) -> bool: | |||
return self._batch_chooser.finished_iteration() | |||
def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None: | |||
""" Receives as input a dictionary of placeholders and fills them using all the content loaders.""" | |||
missed_placeholders_names: List[str] | |||
if len(self._processed_data_dictionary) == 0: | |||
missed_placeholders_names = list(placeholders_dict.keys()) | |||
else: | |||
missed_placeholders_names = [] | |||
for k, v in placeholders_dict.items(): | |||
if k in self._processed_data_dictionary: | |||
placeholders_dict[k] = self._processed_data_dictionary[k] | |||
else: | |||
missed_placeholders_names.append(k) | |||
if len(missed_placeholders_names) == 0: | |||
return | |||
new_batch_info = self._content_loader.fill_placeholders( | |||
missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices()) | |||
        # filling all the missed keys (guaranteed non-empty by the early return above)
        with torch.no_grad():
            for k in missed_placeholders_names:
                self._fill_placeholder(placeholders_dict[k], new_batch_info[k])
                if self._augmentations_dict is not None and k in self._augmentations_dict:
                    placeholders_dict[k] = self._apply_augmentation(k, placeholders_dict[k])
                # keeping a copy of the data
                self._processed_data_dictionary[k] = placeholders_dict[k]
def get_current_batch_data(self, keyword: str) -> torch.Tensor: | |||
""" Builds placeholder and retrieves the requirement from loaders and returns it! | |||
Args: | |||
keyword (str): The keyword to load for the current batch. | |||
Returns: | |||
torch.Tensor: The information related to the keyword for the current batch. | |||
""" | |||
if keyword in self._processed_data_dictionary: | |||
return self._processed_data_dictionary[keyword] | |||
else: | |||
placeholders_dict = {keyword: create_empty_placeholder(self._device)} | |||
self.fill_placeholders(placeholders_dict) | |||
return placeholders_dict[keyword] | |||
def get_current_batch_sample_indices(self) -> np.ndarray: | |||
""" Returns a list of indices of the samples chosen for the current batch. """ | |||
return self._batch_chooser.get_current_batch_sample_indices() | |||
def get_max_class_samples_num(self): | |||
return max([len(x) for x in self.class_samples_indices.values()]) | |||
def get_classes_num(self): | |||
return len(self.class_samples_indices) | |||
def get_samples_keep_mask(self) -> Union[None, np.ndarray]: | |||
if self.conf.mapped_labels_to_use is None: | |||
return None | |||
        labels_to_use = set(self.conf.mapped_labels_to_use)
        keep_mask = np.vectorize(lambda x: x in labels_to_use)(self.samples_labels)
print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.') | |||
return keep_mask | |||
def get_current_batch_size(self) -> int: | |||
return self._batch_chooser.get_current_batch_size() | |||
def get_current_batch_samples_names(self) -> np.ndarray: | |||
return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()] | |||
def get_current_batch_samples_labels(self) -> np.ndarray: | |||
return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()] | |||
def get_current_batch_samples_interpretations(self) -> np.ndarray: | |||
return self.get_current_batch_data('interpretations') | |||
@staticmethod | |||
def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray): | |||
""" Fills the torch placeholder with the given numpy value, | |||
If shape mismatches, resizes the placeholder so the data would fit. """ | |||
# Resize if the shape mismatches | |||
if list(placeholder.shape) != list(val.shape): | |||
placeholder.resize_(*tuple(val.shape)) | |||
# feeding the value | |||
placeholder.copy_(torch.Tensor(val)) | |||
def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor: | |||
""" This function would be called for the variables we are sure they need augmentation | |||
(they are presented in augmentation dictionary), applies the specified augmentation, | |||
makes sure all augmentations be the same on the same batch elements, | |||
and returns the augmented data""" | |||
def run_single_aug(aug, inp): | |||
if type(aug) in self._transformations_seeds_dict: | |||
seed = self._transformations_seeds_dict[type(aug)] | |||
else: | |||
# setting a seed | |||
seed = np.random.randint(2147483647) # make a seed with numpy generator | |||
self._transformations_seeds_dict[type(aug)] = seed | |||
random.seed(seed) | |||
np.random.seed(seed) | |||
torch.manual_seed(seed) | |||
if self._different_augmentation_per_batch: | |||
ret_val = torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0) | |||
else: | |||
ret_val = aug(inp) | |||
return ret_val | |||
field_transformation = self._augmentations_dict[var_name] | |||
if not isinstance(field_transformation, torch.nn.Sequential): | |||
return run_single_aug(field_transformation, var_val) | |||
else: | |||
aug_val = var_val | |||
for single_transform in field_transformation.children(): | |||
aug_val = run_single_aug( | |||
single_transform, | |||
aug_val) | |||
return aug_val | |||
def create_empty_placeholder(device: torch.device) -> torch.Tensor: | |||
""" Create empty placeholder with shape (1) in the given device. | |||
Args: | |||
device (torch.device): Device to create the placeholder on. | |||
Returns: | |||
torch.Tensor: Empty placeholder. | |||
""" | |||
return torch.zeros(1, dtype=torch.float32, | |||
device=device, | |||
requires_grad=False) |
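# A minimal sketch (illustration only) of the seed-sharing trick used in
# DataLoader._apply_augmentation: reseeding with the same value before
# augmenting two tensors makes both receive the same random transformation.
if __name__ == '__main__':
    # torchvision is assumed available, as it is used elsewhere in the project
    from torchvision.transforms import RandomHorizontalFlip
    _aug = RandomHorizontalFlip(p=0.5)
    _img = torch.rand(2, 1, 4, 4)
    _seed = np.random.randint(2147483647)
    torch.manual_seed(_seed)
    _a = _aug(_img)
    torch.manual_seed(_seed)  # same seed -> same flip decision
    _b = _aug(_img)
    assert torch.equal(_a, _b)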
@@ -0,0 +1,98 @@ | |||
""" | |||
Runs flow of data, from loading to forwarding the model. | |||
""" | |||
from typing import Iterator, Union, TypeVar, Generic, Dict | |||
import torch | |||
from ..data.data_loader import DataLoader, create_empty_placeholder | |||
from ..data.dataloader_context import DataloaderContext | |||
from ..models.model import Model | |||
TReturn = TypeVar('TReturn') | |||
class DataFlow(Generic[TReturn]): | |||
""" | |||
Runs flow of data, from loading to forwarding the model. | |||
""" | |||
    def __init__(self, model: 'Model', dataloader: 'DataLoader',
                 device: torch.device, print_debug_info: bool = False):
self._model = model | |||
self._device = device | |||
self._dataloader = dataloader | |||
self._placeholders: Union[None, Dict[str, torch.Tensor]] = None | |||
self._print_debug_info = print_debug_info | |||
def iterate(self) -> Iterator[TReturn]: | |||
""" | |||
Iterates on data and forwards the model | |||
""" | |||
if self._placeholders is None: | |||
raise Exception("Please use `with` statement before calling iterate method:\n" + | |||
"with dataflow:\n" + | |||
" for model_output in dataflow.iterate():\n" + | |||
" pass\n") | |||
DataloaderContext.instance.dataloader = self._dataloader | |||
while True: | |||
self._dataloader.prepare_next_batch() | |||
if self._print_debug_info: | |||
print('> Next Batch:') | |||
print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names())) | |||
# check if the iteration is done | |||
if self._dataloader.finished_iteration(): | |||
break | |||
# Filling the placeholders for running the model | |||
self._dataloader.fill_placeholders(self._placeholders) | |||
# Running the model | |||
yield self._final_stage(self._placeholders) | |||
def __enter__(self): | |||
# self._dataloader.reset() | |||
self._placeholders = self._build_placeholders() | |||
def __exit__(self, exc_type, exc_value, traceback): | |||
for name in self._placeholders: | |||
self._placeholders[name] = None | |||
self._placeholders = None | |||
def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn: | |||
return self._model(**model_input) | |||
def _build_placeholders(self) -> Dict[str, torch.Tensor]: | |||
""" Creates placeholders for feeding the model. | |||
Returns a dictionary containing placeholders. | |||
The placeholders are created based on the names | |||
of the input variables of the model's forward method. | |||
The variables initialized as None are assumed to be | |||
labels used for training phase only, and won't be | |||
added in phases other than train.""" | |||
model_forward_args, args_default_values, additional_kwargs = self._model.get_forward_required_kws_kwargs_and_defaults() | |||
# Checking which args have None as default | |||
args_default_values = ['Mandatory' for _ in | |||
range(len(model_forward_args) - len(args_default_values))] + \ | |||
args_default_values | |||
optional_args = [x is None for x in args_default_values] | |||
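        # NOTE: optional (None-default) arguments are detected here, but placeholders
        # are currently created for every argument regardless of the phase.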
placeholders = dict() | |||
for argname in model_forward_args: | |||
placeholders[argname] = create_empty_placeholder(self._device) | |||
for argname in additional_kwargs: | |||
placeholders[argname] = create_empty_placeholder(self._device) | |||
return placeholders | |||
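# Typical usage (sketch; model and dataloader construction elided):
#
#   dataflow = DataFlow(model, dataloader, conf.device)
#   with dataflow:
#       for model_output in dataflow.iterate():
#           ...  # consume the model's output for the current batch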
@@ -0,0 +1,28 @@ | |||
from typing import TYPE_CHECKING | |||
if TYPE_CHECKING: | |||
from .data_loader import DataLoader | |||
class classproperty(property): | |||
def __get__(self, obj, objtype=None): | |||
return super(classproperty, self).__get__(objtype) | |||
def __set__(self, obj, value): | |||
super(classproperty, self).__set__(type(obj), value) | |||
def __delete__(self, obj): | |||
super(classproperty, self).__delete__(type(obj)) | |||
class DataloaderContext: | |||
_instance: 'DataloaderContext' = None | |||
def __init__(self) -> None: | |||
self.dataloader: 'DataLoader' = None | |||
@classproperty | |||
def instance(cls) -> 'DataloaderContext': | |||
if cls._instance is None: | |||
cls._instance = DataloaderContext() | |||
return cls._instance |
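# Usage sketch: content-independent code grabs the active loader through the
# singleton, e.g. `dl = DataloaderContext.instance.dataloader`, as done in the
# weakly supervised loss above, instead of threading it through call sites.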
@@ -0,0 +1,10 @@ | |||
from typing import Type | |||
from .entrypoint import BaseEntrypoint | |||
from ..configs.base_config import PhaseType | |||
from importlib import import_module | |||
import sys | |||
def get_entry(phase_type: PhaseType): | |||
EntryPoint: Type[BaseEntrypoint] = import_module(f'torchlap.experiments.{sys.argv[1]}').EntryPoint | |||
return EntryPoint(phase_type) |
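# Usage sketch (hypothetical module name): with sys.argv[1] == 'celeba.org_resnet',
# get_entry imports torchlap.experiments.celeba.org_resnet and instantiates its
# EntryPoint with the requested phase type.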
@@ -0,0 +1,68 @@ | |||
from os import path | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from ..models.model import Model | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
def load_model(the_model: Model, conf: 'BaseConfig') -> Model: | |||
dev_name = conf.dev_name | |||
if conf.final_model_dir is not None: | |||
load_dir = conf.final_model_dir | |||
if not path.exists(load_dir): | |||
raise Exception('No path %s' % load_dir) | |||
else: | |||
load_dir = conf.save_dir | |||
print('>>> loading from: ' + load_dir, flush=True) | |||
if not path.exists(load_dir): | |||
raise Exception('Problem in loading the model. %s does not exist!' % load_dir) | |||
map_location = None | |||
if dev_name == 'cpu': | |||
map_location = 'cpu' | |||
elif dev_name is not None and ':' in dev_name: | |||
map_location = dev_name | |||
the_model.to(conf.device) | |||
    if path.isfile(load_dir):
        # path.isfile already guarantees existence
        the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
print('Loaded the model at %s' % load_dir, flush=True) | |||
else: | |||
if conf.epoch is not None: | |||
epoch = int(conf.epoch) | |||
elif path.exists(load_dir + '/GeneralInfo'): | |||
checkpoint = torch.load(load_dir + '/GeneralInfo', map_location=map_location) | |||
epoch = checkpoint['best_val_epoch'] | |||
print(f'Epoch has not been specified in the config, ' | |||
f'using {epoch} which is best_val_epoch instead') | |||
        else:
            raise Exception('Either an epoch or a final model checkpoint (final_model_dir) must be given, as GeneralInfo was not found')
if not path.exists('%s/%d' % (load_dir, epoch)): | |||
raise Exception('Problem in loading the model. %s/%d does not exist!' % | |||
(load_dir, epoch)) | |||
checkpoint = torch.load('%s/%d' % (load_dir, epoch), | |||
map_location=map_location) | |||
# backward compatibility | |||
if 'model_state_dict' in checkpoint: | |||
checkpoint = checkpoint['model_state_dict'] | |||
the_model.load_state_dict(checkpoint) | |||
print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True) | |||
the_model.to(conf.device) | |||
return the_model |
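# Assumed checkpoint layout, inferred from the code above: save_dir holds one
# file per epoch, named by the epoch number, plus a 'GeneralInfo' file whose
# 'best_val_epoch' entry selects the default epoch when none is configured.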
@@ -0,0 +1,20 @@ | |||
from typing import Tuple | |||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.celeba.single_tag_org_inception import CelebAORGInception | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||
conf = CelebAConfigs('CelebA_Inception_Org', 8, 299, self.phase_type) | |||
conf.max_epochs = 12 | |||
conf.main_tag = CelebATag.SmilingTag | |||
conf.tags = [conf.main_tag.name] | |||
model = CelebAORGInception(conf.main_tag.name, 0.4) | |||
return conf, model | |||
@@ -0,0 +1,20 @@ | |||
from typing import Tuple | |||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.celeba.org_resnet import CelebAORGResNet18 | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||
conf = CelebAConfigs('CelebA_ResNet_Org', 9, 224, self.phase_type) | |||
conf.max_epochs = 12 | |||
conf.main_tag = CelebATag.SmilingTag | |||
conf.tags = [conf.main_tag.name] | |||
model = CelebAORGResNet18(conf.main_tag.name) | |||
return conf, model | |||
@@ -0,0 +1,143 @@ | |||
from typing import Tuple | |||
from collections import OrderedDict | |||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.celeba.lap_inception import CelebALAPInception | |||
from ...modules.lap import LAP | |||
from ...modules.adaptive_lap import AdaptiveLAP | |||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||
from ...utils.aux_output import AuxOutput | |||
from ...utils.output_modifier import OutputModifier | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
def lap_factory(channel): | |||
return LAP( | |||
channel, 2, 2, hidden_channels=[8], n_attention_heads=3, | |||
sigmoid_scale=0.1, discriminative_attention=True) | |||
def adaptive_lap_factory(channel): | |||
return AdaptiveLAP(channel, [], 0.1, | |||
n_attention_heads=3, discriminative_attention=True) | |||
class EntryPoint(BaseEntrypoint): | |||
def __init__(self, phase_type) -> None: | |||
self.active_min_ratio: float = 0.02 | |||
self.active_max_ratio: float = 0.4 | |||
self.inactive_ratio: float = 0.01 | |||
self.common_max_ratio = 0.02 | |||
super().__init__(phase_type) | |||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||
conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type) | |||
conf.max_epochs = 12 | |||
conf.main_tag = CelebATag.SmilingTag | |||
conf.tags = [conf.main_tag.name] | |||
model = CelebALAPInception(conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory) | |||
# aux loss for free head | |||
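        # channel 2 acts as a shared "free" head: it is mapped to both labels and,
        # judging by the analogous ws_losses call below, its active area is merely
        # capped at common_max_ratio with no inactive-ratio penalty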
fws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.025 / 3, 0, | |||
(0, self.common_max_ratio), | |||
{2: [0, 1]}, discr_score_layer, | |||
w_attention_in_ordering=1, w_discr_in_ordering=0) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('FPool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||
('FPool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||
('FPool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||
('FPoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for fws_loss in fws_losses: | |||
fws_loss.configure(conf) | |||
# weakly supervised losses based on discriminative head and attention head | |||
ws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.025 * 2 / 3, | |||
self.inactive_ratio, | |||
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]}, | |||
discr_score_layer, | |||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for ws_loss in ws_losses: | |||
ws_loss.configure(conf) | |||
# concordance loss for attention | |||
concordance_loss = PoolConcordanceLossCalculator( | |||
'AC', model, OrderedDict([ | |||
('att2', model.maxpool2.attention_layer), | |||
('att6', model.Mixed_6a.pool.attention_layer), | |||
('att7', model.Mixed_7a.pool.attention_layer), | |||
('attG', model.avgpool[0].attention_layer), | |||
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [0], 1: [1]}) | |||
concordance_loss.configure(conf) | |||
# concordance loss for discrimination head | |||
concordance_loss2 = PoolConcordanceLossCalculator( | |||
'DC', model, OrderedDict([ | |||
('D-att2', model.maxpool2.discrimination_layer), | |||
('D-att6', model.Mixed_6a.pool.discrimination_layer), | |||
('D-att7', model.Mixed_7a.pool.discrimination_layer), | |||
('D-attG', model.avgpool[0].discrimination_layer), | |||
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [0], 1: [1]}) | |||
concordance_loss2.configure(conf) | |||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.maxpool2.attention_layer, | |||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(conf) | |||
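        # foreteller decision rule: threshold each attention map at 0.5, sum the
        # active area per channel, and predict the larger of the first two channels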
output_modifier = OutputModifier(model, | |||
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1), | |||
'foretell_pool2', | |||
'foretell_pool6', | |||
'foretell_pool7', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(conf) | |||
return conf, model | |||
@@ -0,0 +1,131 @@ | |||
from typing import Tuple | |||
from collections import OrderedDict | |||
from ...configs.celeba_configs import CelebAConfigs, CelebATag | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.celeba.lap_resnet import CelebALAPResNet18 | |||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||
from ...utils.aux_output import AuxOutput | |||
from ...utils.output_modifier import OutputModifier | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
class EntryPoint(BaseEntrypoint): | |||
def __init__(self, phase_type) -> None: | |||
self.active_min_ratio: float = 0.02 | |||
self.active_max_ratio: float = 0.4 | |||
self.inactive_ratio: float = 0.01 | |||
self.common_max_ratio = 0.02 | |||
super().__init__(phase_type) | |||
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]: | |||
conf = CelebAConfigs('CelebA_ResNet_WS', 11, 224, self.phase_type) | |||
conf.max_epochs = 12 | |||
conf.main_tag = CelebATag.SmilingTag | |||
conf.tags = [conf.main_tag.name] | |||
model = CelebALAPResNet18(conf.main_tag.name, sigmoid_scale=0.1) | |||
# aux loss for free head | |||
fws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.025 / 3, 0, | |||
(0, self.common_max_ratio), | |||
{2: [0, 1]}, discr_score_layer, | |||
w_attention_in_ordering=1, w_discr_in_ordering=0) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('Fatt2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||
('Fatt3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||
('Fatt4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||
('Fatt5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for fws_loss in fws_losses: | |||
fws_loss.configure(conf) | |||
# weakly supervised losses based on discriminative head and attention head | |||
ws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.025 * 2 / 3, | |||
self.inactive_ratio, | |||
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]}, | |||
discr_score_layer, | |||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for ws_loss in ws_losses: | |||
ws_loss.configure(conf) | |||
# concordance loss for attention | |||
concordance_loss = PoolConcordanceLossCalculator( | |||
'AC', model, OrderedDict([ | |||
('att2', model.layer2[0].pool.attention_layer), | |||
('att3', model.layer3[0].pool.attention_layer), | |||
('att4', model.layer4[0].pool.attention_layer), | |||
('att5', model.avgpool[0].attention_layer), | |||
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [0], 1: [1]}) | |||
concordance_loss.configure(conf) | |||
# concordance loss for discrimination head | |||
concordance_loss2 = PoolConcordanceLossCalculator( | |||
'DC', model, OrderedDict([ | |||
('att2', model.layer2[0].pool.discrimination_layer), | |||
('att3', model.layer3[0].pool.discrimination_layer), | |||
('att4', model.layer4[0].pool.discrimination_layer), | |||
('att5', model.avgpool[0].discrimination_layer), | |||
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [0], 1: [1]}) | |||
concordance_loss2.configure(conf) | |||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(conf) | |||
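        # same foreteller rule as the Inception variant: per-channel active area
        # after a 0.5 threshold, argmax over the first two channels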
output_modifier = OutputModifier(model, | |||
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1), | |||
'foretell_pool2', | |||
'foretell_pool3', | |||
'foretell_pool4', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(conf) | |||
return conf, model | |||
@@ -0,0 +1,113 @@ | |||
import sys | |||
from typing import Tuple | |||
import argparse | |||
import os | |||
from abc import ABC, abstractmethod | |||
from ..models.model import Model | |||
from ..configs.base_config import BaseConfig, PhaseType | |||
class BaseEntrypoint(ABC): | |||
"""Base class for all entrypoints. | |||
""" | |||
description = '' | |||
def __init__(self, phase_type: PhaseType) -> None: | |||
super().__init__() | |||
self.phase_type = phase_type | |||
self.conf, self.model = self._get_conf_model() | |||
self.parser = self._create_parser(sys.argv[0], sys.argv[1]) | |||
self.conf.update(self.parser.parse_args(sys.argv[2:])) | |||
def _create_parser(self, command: str, entrypoint_name: str) -> argparse.ArgumentParser: | |||
parser = argparse.ArgumentParser( | |||
prog=f'{os.path.basename(command)} {entrypoint_name}', | |||
description=self.description or None | |||
) | |||
        parser.add_argument('--device', default='cpu', type=str,
                            help='The device to run on: cpu, cuda:N, or cuda (for all GPUs)')
        parser.add_argument('--samples-dir', default=None, type=str,
                            help='The directory or data group to be evaluated; a data group can be train/test/val, while a directory must contain 0 and 1 subdirectories. Use :FilterName to evaluate only the samples containing /FilterName/ in their path')
        parser.add_argument('--final-model-dir', default=None, type=str,
                            help='The directory to load the model from; when not given, it will be calculated!')
        parser.add_argument('--save-dir', default=None, type=str,
                            help='The directory to save the model to; when not given, it will be calculated!')
parser.add_argument('--report-dir', default=None, type=str, | |||
help='The dir to save reports per slice per sample in.') | |||
parser.add_argument('--epoch', default=None, type=str, | |||
help='The epoch to load.') | |||
parser.add_argument('--try-name', default=None, type=str, | |||
help='The run name specifying what run is doing') | |||
parser.add_argument('--try-num', default=None, type=int, | |||
help='The try number to load') | |||
parser.add_argument('--data-separation', default=None, type=str, | |||
help='The data_separation to be used.') | |||
parser.add_argument('--batch-size', default=None, type=int, | |||
help='The batch size to be used.') | |||
parser.add_argument('--pretrained-model-file', default=None, type=str, | |||
help='Address of .pt pretrained model') | |||
        parser.add_argument('--max-epochs', default=None, type=int,
                            help='The maximum number of training epochs!')
parser.add_argument('--big-batch-size', default=None, type=int, | |||
help='The big batch size (iteration per optimization)!') | |||
parser.add_argument('--iters-per-epoch', default=None, type=int, | |||
help='The number of big batches per epoch!') | |||
parser.add_argument('--interpretation-method', default=None, type=str, | |||
help='The method used for interpreting the results!') | |||
parser.add_argument('--cut-threshold', default=None, type=float, | |||
help='The threshold for cutting interpretations!') | |||
        parser.add_argument('--global-threshold', action='store_true',
                            help='Whether the given cut threshold should be applied to the global values rather than the relative ones!')
        parser.add_argument('--dynamic-threshold', action='store_true',
                            help='Whether to use a dynamic threshold in interpretation!')
        parser.add_argument('--class-label-for-interpretation', default=None, type=int,
                            help='The class label whose selection should be explained; None means the decision of the model!')
        parser.add_argument('--interpret-predictions-vs-gt', default='1', type=(lambda x: x == '1'),
                            help='Considered only when class-label-for-interpretation is None; if 1, interpretations are done for the predicted label, otherwise for the ground truth.')
parser.add_argument('--mapped-labels-to-use', default=None, type=(lambda x: [int(y) for y in x.split(',')]), | |||
help='The labels to do the analysis on them only (comma separated), default is all the labels.') | |||
parser.add_argument('--skip-overlay', action='store_true', default=None, | |||
help='Passing this flag prevents the interpretation phase from storing overlay images.') | |||
parser.add_argument('--skip-raw', action='store_true', default=None, | |||
help='Passing this flag prevents the interpretation phase from storing the raw interpretation values.') | |||
parser.add_argument('--overlay-only', action='store_true', | |||
help='Passing this flag makes the interpretation phase to store just overlay images ' | |||
'and not the `.npy` files.') | |||
parser.add_argument('--save-by-file-name', action='store_true', | |||
help='Saves sample-specific files by their file names only, not the whole path.') | |||
parser.add_argument('--n-interpretation-samples', default=None, type=int, | |||
help='The max number of samples to be interpreted.') | |||
parser.add_argument('--interpretation-tag-to-evaluate', default=None, type=str, | |||
help='The tag to be used as interpretation for evaluation phase, otherwise the first one will be used.') | |||
return parser | |||
@abstractmethod | |||
def _get_conf_model(self) -> Tuple[BaseConfig, Model]: | |||
pass |
@@ -0,0 +1,17 @@ | |||
from typing import Tuple | |||
from torchlap.configs.imagenet_configs import ImagenetConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]: | |||
config = ImagenetConfigs('ImagenetFT', 2, 224, self.phase_type) | |||
model = ImagenetLAPResNet50(sigmoid_scale=0.1) | |||
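        # freeze everything except layer4.* and fc.*: the regex matches exactly the
        # parameter names that are not under the last block or the classifier head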
config.freezing_regexes = [ | |||
r'(?!^layer4\..*$)(?!^fc\..*$)(^.*$)' | |||
] | |||
return config, model |
@@ -0,0 +1,17 @@ | |||
from typing import Tuple | |||
from torchlap.configs.imagenet_configs import ImagenetConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50 | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]: | |||
config = ImagenetConfigs('ImagenetNFT', 1, 224, self.phase_type) | |||
model = ImagenetLAPResNet50(sigmoid_scale=0.1) | |||
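        # freeze everything except the LAP pooling modules: the regex matches every
        # parameter name that does not contain '.pool.'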
config.freezing_regexes = [ | |||
r'(?!^.*\.pool\..*$)(^.*$)' | |||
] | |||
return config, model |
@@ -0,0 +1,77 @@ | |||
from collections import OrderedDict | |||
from typing import Tuple | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.lap_inception import RSNALAPInception | |||
from ...modules.lap import LAP | |||
from ...modules.adaptive_lap import AdaptiveLAP | |||
from ...criteria.bb_supervised import BBLoss | |||
from ...utils.aux_output import AuxOutput | |||
from ...utils.output_modifier import OutputModifier | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
def lap_factory(channel): | |||
return LAP( | |||
channel, 2, 2, hidden_channels=[8], | |||
sigmoid_scale=0.1, discriminative_attention=False) | |||
def adaptive_lap_factory(channel): | |||
return AdaptiveLAP(channel, [], 0.1, discriminative_attention=False) | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type) | |||
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory) | |||
        bb_losses = BBLoss(
            'BB', model, {
                'Pool2': model.maxpool2.attention_layer,
                'Pool6': model.Mixed_6a.pool.attention_layer,
                'Pool7': model.Mixed_7a.pool.attention_layer,
                'PoolG': model.avgpool[0].attention_layer,
            }, 1,
            'interpretations', 1, 0.5)
bb_losses.configure(conf) | |||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.maxpool2.attention_layer, | |||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(conf) | |||
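        # foreteller decision rule: the maximum attention value over each map serves
        # as the positive-class score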
output_modifier = OutputModifier(model, | |||
lambda x: x.flatten(1).max(dim=-1).values, | |||
'foretell_pool2', | |||
'foretell_pool6', | |||
'foretell_pool7', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(conf) | |||
return conf, model | |||
@@ -0,0 +1,69 @@ | |||
from collections import OrderedDict | |||
from typing import Tuple | |||
from ...utils.output_modifier import OutputModifier | |||
from ...utils.aux_output import AuxOutput | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...criteria.bb_supervised import BBLoss | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
config = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type) | |||
model = RSNALAPResNet18( | |||
pool_factory=get_pool_factory(discriminative_attention=False), | |||
adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False), | |||
sigmoid_scale=0.1, | |||
) | |||
        bb_losses = BBLoss(
            'BB', model, {
                'att2': model.layer2[0].pool.attention_layer,
                'att3': model.layer3[0].pool.attention_layer,
                'att4': model.layer4[0].pool.attention_layer,
                'att5': model.avgpool[0].attention_layer,
            }, 1,
            'interpretations', 1, 0.5)
bb_losses.configure(config) | |||
config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(config) | |||
output_modifier = OutputModifier(model, | |||
lambda x: x.flatten(1).max(dim=-1).values, | |||
'foretell_pool2', | |||
'foretell_pool3', | |||
'foretell_pool4', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(config) | |||
return config, model |
@@ -0,0 +1,14 @@ | |||
from typing import Tuple | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.org_inception import RSNAORGInception | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
config = RSNAConfigs('RSNA_Inception_Org', 2, 299, self.phase_type) | |||
model = RSNAORGInception(0.4) | |||
return config, model | |||
@@ -0,0 +1,14 @@ | |||
from typing import Tuple | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.org_resnet import RSNAORGResNet18 | |||
class EntryPoint(BaseEntrypoint): | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
config = RSNAConfigs('RSNA_ResNet_Org', 1, 224, self.phase_type) | |||
model = RSNAORGResNet18() | |||
return config, model |
@@ -0,0 +1,117 @@ | |||
from typing import Tuple | |||
from collections import OrderedDict | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.lap_inception import RSNALAPInception | |||
from ...modules.lap import LAP | |||
from ...modules.adaptive_lap import AdaptiveLAP | |||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||
from ...utils.aux_output import AuxOutput | |||
from ...utils.output_modifier import OutputModifier | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
def lap_factory(channel): | |||
return LAP( | |||
channel, 2, 2, hidden_channels=[8], | |||
sigmoid_scale=0.1, discriminative_attention=True) | |||
def adaptive_lap_factory(channel): | |||
return AdaptiveLAP(channel, sigmoid_scale=0.1, discriminative_attention=True) | |||
class EntryPoint(BaseEntrypoint): | |||
def __init__(self, phase_type) -> None: | |||
self.active_min_ratio: float = 0.1 | |||
self.active_max_ratio: float = 0.5 | |||
self.inactive_ratio: float = 0.1 | |||
super().__init__(phase_type) | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
conf = RSNAConfigs('RSNA_Inception_WS', 4, 256, self.phase_type) | |||
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory) | |||
# weakly supervised losses based on discriminative head and attention head | |||
ws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.25, | |||
self.inactive_ratio, | |||
(self.active_min_ratio, self.active_max_ratio), {0: [1]}, | |||
discr_score_layer, | |||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer), | |||
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer), | |||
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer), | |||
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for ws_loss in ws_losses: | |||
ws_loss.configure(conf) | |||
# concordance loss for attention | |||
concordance_loss = PoolConcordanceLossCalculator( | |||
'AC', model, OrderedDict([ | |||
('att2', model.maxpool2.attention_layer), | |||
('att6', model.Mixed_6a.pool.attention_layer), | |||
('att7', model.Mixed_7a.pool.attention_layer), | |||
('attG', model.avgpool[0].attention_layer), | |||
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1, | |||
labels_by_channel={0: [1]}) | |||
concordance_loss.configure(conf) | |||
# concordance loss for discrimination head | |||
concordance_loss2 = PoolConcordanceLossCalculator( | |||
'DC', model, OrderedDict([ | |||
('D-att2', model.maxpool2.discrimination_layer), | |||
('D-att6', model.Mixed_6a.pool.discrimination_layer), | |||
('D-att7', model.Mixed_7a.pool.discrimination_layer), | |||
('D-attG', model.avgpool[0].discrimination_layer), | |||
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [1]}) | |||
concordance_loss2.configure(conf) | |||
conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.maxpool2.attention_layer, | |||
foretell_pool6=model.Mixed_6a.pool.attention_layer, | |||
foretell_pool7=model.Mixed_7a.pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(conf) | |||
output_modifier = OutputModifier(model, | |||
lambda x: x.flatten(1).max(dim=-1).values, | |||
'foretell_pool2', | |||
'foretell_pool6', | |||
'foretell_pool7', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(conf) | |||
return conf, model | |||
@@ -0,0 +1,105 @@ | |||
from collections import OrderedDict | |||
from typing import Tuple | |||
from ...utils.output_modifier import OutputModifier | |||
from ...utils.aux_output import AuxOutput | |||
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator | |||
from ...model_evaluation.binary_evaluator import BinaryEvaluator | |||
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator | |||
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator | |||
from ...model_evaluation.loss_evaluator import LossEvaluator | |||
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator | |||
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss | |||
from ...configs.rsna_configs import RSNAConfigs | |||
from ...models.model import Model | |||
from ..entrypoint import BaseEntrypoint | |||
from ...models.rsna.lap_resnet import RSNALAPResNet18 | |||
class EntryPoint(BaseEntrypoint): | |||
def __init__(self, phase_type) -> None: | |||
self.active_min_ratio: float = 0.1 | |||
self.active_max_ratio: float = 0.5 | |||
self.inactive_ratio: float = 0.1 | |||
super().__init__(phase_type) | |||
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]: | |||
config = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type) | |||
model = RSNALAPResNet18() | |||
# weakly supervised losses based on discriminative head and attention head | |||
ws_losses = [ | |||
DiscriminativeWeaklySupervisedLoss( | |||
title, model, att_score_layer, 0.25, | |||
self.inactive_ratio, | |||
(self.active_min_ratio, self.active_max_ratio), {0: [1]}, | |||
discr_score_layer, | |||
w_attention_in_ordering=0.2, w_discr_in_ordering=1) | |||
for title, att_score_layer, discr_score_layer in [ | |||
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer), | |||
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer), | |||
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer), | |||
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer), | |||
] | |||
] | |||
for ws_loss in ws_losses: | |||
ws_loss.configure(config) | |||
# concordance loss for attention | |||
concordance_loss = PoolConcordanceLossCalculator( | |||
'AC', model, OrderedDict([ | |||
('att2', model.layer2[0].pool.attention_layer), | |||
('att3', model.layer3[0].pool.attention_layer), | |||
('att4', model.layer4[0].pool.attention_layer), | |||
('att5', model.avgpool[0].attention_layer), | |||
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1, | |||
labels_by_channel={0: [1]}) | |||
concordance_loss.configure(config) | |||
# concordance loss for discrimination head | |||
concordance_loss2 = PoolConcordanceLossCalculator( | |||
'DC', model, OrderedDict([ | |||
('att2', model.layer2[0].pool.discrimination_layer), | |||
('att3', model.layer3[0].pool.discrimination_layer), | |||
('att4', model.layer4[0].pool.discrimination_layer), | |||
('att5', model.avgpool[0].discrimination_layer), | |||
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0, | |||
labels_by_channel={0: [1]}) | |||
concordance_loss2.configure(config) | |||
config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({ | |||
'b': BinaryEvaluator, | |||
'l': LossEvaluator, | |||
'f': BinForetellerEvaluator.standard_creator('foretell'), | |||
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'), | |||
})) | |||
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc' | |||
################################### | |||
########### Foreteller ############ | |||
################################### | |||
aux = AuxOutput(model, dict( | |||
foretell_pool2=model.layer2[0].pool.attention_layer, | |||
foretell_pool3=model.layer3[0].pool.attention_layer, | |||
foretell_pool4=model.layer4[0].pool.attention_layer, | |||
foretell_avgpool=model.avgpool[0].attention_layer, | |||
)) | |||
aux.configure(config) | |||
output_modifier = OutputModifier(model, | |||
lambda x: x.flatten(1).max(dim=-1).values, | |||
'foretell_pool2', | |||
'foretell_pool3', | |||
'foretell_pool4', | |||
'foretell_avgpool', | |||
) | |||
output_modifier.configure(config) | |||
return config, model |
@@ -0,0 +1,7 @@ | |||
""" | |||
Interpretation Modules | |||
""" | |||
from .interpretable import InterpretableModel, CamInterpretableModel | |||
from .interpreter import Interpreter | |||
from .interpretable_wrapper import InterpretableWrapper | |||
from .interpreter_maker import create_interpreter, InterpretationType |
@@ -0,0 +1,148 @@ | |||
from abc import ABC, abstractmethod | |||
from collections import OrderedDict as OrdDict | |||
from typing import Dict, List, Tuple, Union | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from .interpreter import Interpreter | |||
from .interpretable import InterpretableModel | |||
from ..modules import LAP | |||
from ..utils.hooker import Hooker, Hook | |||
AttentionDict = Dict[torch.Size, torch.Tensor] | |||
class AttentionInterpretableModel(InterpretableModel, ABC): | |||
@property | |||
@abstractmethod | |||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||
""" | |||
List of attention groups | |||
""" | |||
class AttentionInterpreter(Interpreter): | |||
def __init__(self, model: AttentionInterpretableModel): | |||
assert isinstance(model, AttentionInterpretableModel),\ | |||
f"Expected AttentionInterpretableModel, got {type(model)}" | |||
super().__init__(model) | |||
layer_hooks = sum([ | |||
[ | |||
(layer.attention_layer, self._generate_forward_hook(group_name, layer)) | |||
for layer in group_layers | |||
] for group_name, group_layers in model.attention_layers.items() | |||
], []) | |||
self._hooker = Hooker(*layer_hooks) | |||
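        # attention maps are accumulated per group and keyed by their spatial size,
        # so layers producing the same output size are summed together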
self._attention_groups: Dict[str, AttentionDict] = { | |||
group_name: {} for group_name in model.attention_layers} | |||
self._impact = 1. | |||
self._impact_decay = .8 | |||
self._input_size: torch.Size = None | |||
self._device = None | |||
def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook: | |||
def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None: | |||
processed = out.clone() | |||
shape = processed.shape[2:] | |||
if shape not in self._attention_groups[group_name]: | |||
self._attention_groups[group_name][shape] = torch.zeros_like( | |||
processed) | |||
self._attention_groups[group_name][shape] += processed | |||
return forward_hook | |||
def _reset(self) -> None: | |||
self._attention_groups = { | |||
group_name: {} for group_name in self._model.attention_layers} | |||
self._impact = 1. | |||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||
""" Process current attention level. | |||
Args: | |||
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1) | |||
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2) | |||
Returns: | |||
torch.Tensor: New attention map. (B, 1, H2, W2) | |||
""" | |||
        self._impact *= self._impact_decay
        # smallest attention layer: nothing to merge with yet
        if result is None:
            return attention
        # attention is at a finer resolution than result
        scale = torch.ceil(
            torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:])
        ).int().tolist()
        # maximum of the finer attention within each coarse cell, upsampled back
        max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale)
        max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest')
        is_any_active = max_map > 0
        # upsample the coarse result to the attention size
        result = F.interpolate(result, attention.shape[2:], mode='nearest')
        # decayed blend: keep the stronger of the coarse result and the scaled attention
        impacted = torch.max(result, attention * self._impact)
        # zero wherever either the current attention or the coarse result is inactive
        impacted[attention == 0] = 0
        impacted[result == 0] = 0
        # where some finer cell is active, replace the coarse value with the blend
        result[is_any_active] = impacted[is_any_active]
        return result
def _process_attentions(self) -> Dict[str, torch.Tensor]: | |||
sorted_attentions = { | |||
group_name: OrdDict(sorted(group_attention.items(), | |||
key=lambda pair: pair[0])) | |||
for group_name, group_attention in self._attention_groups.items() | |||
} | |||
results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions} | |||
for group_name, group_attention in sorted_attentions.items(): | |||
self._impact = 1.0 | |||
# from smallest to largest | |||
for attention in group_attention.values(): | |||
attention = (attention - 0.5).relu() | |||
results[group_name] = self._process_attention(results[group_name], attention) | |||
interpretations = {} | |||
for group_name, group_result in results.items(): | |||
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True) | |||
interpretations[group_name] = group_result | |||
return interpretations | |||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
self._reset() | |||
self._batch_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||
self._input_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:] | |||
self._device = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].device | |||
with self._hooker: | |||
self._model(**inputs) | |||
return self._process_attentions() |
@@ -0,0 +1,48 @@ | |||
from typing import Dict | |||
from collections import OrderedDict as OrdDict | |||
import torch | |||
import torch.nn.functional as F | |||
from .attention_interpreter import AttentionInterpreter | |||
AttentionDict = Dict[torch.Size, torch.Tensor] | |||
class AttentionInterpreterSmoothIntegrator(AttentionInterpreter): | |||
def _process_attentions(self) -> Dict[str, torch.Tensor]: | |||
sorted_attentions = { | |||
group_name: OrdDict(sorted(group_attention.items(), | |||
key=lambda pair: pair[0])) | |||
for group_name, group_attention in self._attention_groups.items() | |||
} | |||
results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions} | |||
for group_name, group_attention in sorted_attentions.items(): | |||
self._impact = 1.0 | |||
# from smallest to largest | |||
sum_ = 0 | |||
for attention in group_attention.values(): | |||
if torch.is_tensor(sum_): | |||
sum_ = F.interpolate(sum_, size=attention.shape[-2:]) | |||
sum_ += attention | |||
attention = (attention - 0.5).relu() | |||
results[group_name] = self._process_attention(results[group_name], attention) | |||
results[group_name] = torch.where(results[group_name] > 0, sum_, sum_ / (2 * len(group_attention))) | |||
interpretations = {} | |||
for group_name, group_result in results.items(): | |||
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True) | |||
interpretations[group_name] = group_result | |||
return interpretations |
@@ -0,0 +1,31 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel | |||
class AttentionSumInterpreter(AttentionInterpreter): | |||
def __init__(self, model: AttentionInterpretableModel): | |||
assert isinstance(model, AttentionInterpretableModel),\ | |||
f"Expected AttentionInterpretableModel, got {type(model)}" | |||
super().__init__(model) | |||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||
""" Process current attention level. | |||
Args: | |||
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1) | |||
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2) | |||
Returns: | |||
torch.Tensor: New attention map. (B, 1, H2, W2) | |||
""" | |||
self._impact *= self._impact_decay | |||
# smallest attention layer | |||
if result is None: | |||
return attention | |||
# attention is larger than result | |||
result = F.interpolate(result, attention.shape[2:], mode='bilinear', align_corners=True) | |||
return result + attention |
@@ -0,0 +1,166 @@ | |||
from typing import List, TYPE_CHECKING, Tuple | |||
import numpy as np | |||
from .utils import process_interpretations | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
from . import Interpreter | |||
class BinaryInterpretationEvaluator2D: | |||
def __init__(self, n_samples: int, config: 'BaseConfig'): | |||
self._config = config | |||
self._n_samples = n_samples | |||
self._min_intersection_threshold = config.acceptable_min_intersection_threshold | |||
self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
def reset(self): | |||
self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float) | |||
def update_summaries( | |||
self, | |||
m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray, | |||
net_preds: np.ndarray, ground_truth_labels: np.ndarray, | |||
batch_inds: np.ndarray, interpreter: 'Interpreter' | |||
) -> None: | |||
        assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but got {ground_truth_interpretations.shape}'
        assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but got {m_interpretations.shape}'
        # skipping samples without interpretations!
        has_interpretation_mask = np.logical_not(np.any(np.isnan(ground_truth_interpretations), axis=(1, 2)))
        m_interpretations = m_interpretations[has_interpretation_mask]
        ground_truths = ground_truth_interpretations[has_interpretation_mask]
        net_preds = net_preds[has_interpretation_mask]
        # finding class labels: threshold sigmoid outputs, or argmax over class scores
        if len(net_preds.shape) == 1:
            net_preds = (net_preds >= 0.5).astype(int)
        elif net_preds.shape[1] == 1:
            net_preds = (net_preds[:, 0] >= 0.5).astype(int)
        else:
            net_preds = net_preds.argmax(axis=-1)
        ground_truth_labels = ground_truth_labels[has_interpretation_mask]
        batch_inds = batch_inds[has_interpretation_mask]
        # calculating soundness: whether the predicted label matches the ground truth
        soundnesses = (net_preds == ground_truth_labels).astype(int)
c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations)) | |||
b_interpretations = np.stack(tuple([ | |||
process_interpretations(m_interpretations[ind][None, ...], self._config, interpreter) | |||
for ind in range(len(m_interpretations)) | |||
]), axis=0)[:, 0, ...] | |||
b_interpretations = (b_interpretations > 0).astype(bool) | |||
        ground_truths = (ground_truths >= 0.5).astype(bool)  # making sure values are 0/1 even if a resize has been applied
assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}' | |||
norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2] | |||
np.add.at(self._normed_intersection_per_soundness, soundnesses, | |||
np.sum(b_interpretations & ground_truths, axis=(1, 2)) * 1.0 / norm_factor) | |||
np.add.at(self._normed_union_per_soundness, soundnesses, | |||
np.sum(b_interpretations | ground_truths, axis=(1, 2)) * 1.0 / norm_factor) | |||
        for i in range(len(b_interpretations)):
            s = soundnesses[i]
# keeping topK interpretations with k = n_GT! = finding a threshold by quantile! calculating quantile by GT | |||
n_on_gt = np.sum(ground_truths[i]) | |||
q = (1 + n_on_gt) * 1.0 / (ground_truths.shape[-1] * ground_truths.shape[-2]) | |||
# 1 is added because we have > in thresholding not >= | |||
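            # e.g. a 4x4 map with 3 GT-positive pixels gives q = 4/16 = 0.25, so the
            # threshold becomes the 0.75-quantile of the interpretation values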
if q < 1: | |||
tints = c_interpretations[i] | |||
th = max(0, np.quantile(tints.reshape(-1), 1 - q)) | |||
tints = (tints > th) | |||
else: | |||
tints = (c_interpretations[i] > 0) | |||
# TOPK METRICS | |||
tk_intersection = np.sum(tints & ground_truths[i]) | |||
tk_union = np.sum(tints | ground_truths[i]) | |||
self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor | |||
self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor | |||
@staticmethod | |||
def get_titles_of_evaluation_metrics() -> List[str]: | |||
return ['S-IOU', 'S-TK-IOU', | |||
'M-IOU', 'M-TK-IOU', | |||
'A-IOU', 'A-TK-IOU'] | |||
@staticmethod | |||
def _get_eval_metrics(normed_intersection, normed_union, title, | |||
                          tk_normed_intersection, tk_normed_union) -> Tuple[str, str]:
iou = (1e-6 + normed_intersection) / (1e-6 + normed_union) | |||
tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union) | |||
return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,) | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
return \ | |||
list(self._get_eval_metrics( | |||
self._normed_intersection_per_soundness[1], | |||
self._normed_union_per_soundness[1], | |||
'Sounds', | |||
self._tk_normed_intersection_per_soundness[1], | |||
self._tk_normed_union_per_soundness[1], | |||
)) + \ | |||
list(self._get_eval_metrics( | |||
self._normed_intersection_per_soundness[0], | |||
self._normed_union_per_soundness[0], | |||
'Mistakes', | |||
self._tk_normed_intersection_per_soundness[0], | |||
self._tk_normed_union_per_soundness[0], | |||
)) + \ | |||
list(self._get_eval_metrics( | |||
sum(self._normed_intersection_per_soundness), | |||
sum(self._normed_union_per_soundness), | |||
'All', | |||
sum(self._tk_normed_intersection_per_soundness), | |||
sum(self._tk_normed_union_per_soundness), | |||
)) | |||
def print_summaries(self): | |||
titles = self.get_titles_of_evaluation_metrics() | |||
vals = self.get_values_of_evaluation_metrics() | |||
nc = 2 | |||
for r in range(3): | |||
print(', '.join(['%s: %s' % (titles[nc * r + i], vals[nc * r + i]) for i in range(nc)])) |
@@ -0,0 +1,43 @@ | |||
from typing import Dict, Union, Tuple | |||
import torch | |||
from . import Interpreter | |||
from ..data.dataflow import DataFlow | |||
from ..models.model import ModelIO | |||
from ..data.data_loader import DataLoader | |||
class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]): | |||
def __init__(self, interpreter: Interpreter, | |||
dataloader: DataLoader, | |||
phase: str, | |||
device: torch.device, | |||
dtype: torch.dtype, | |||
print_debug_info: bool, | |||
label_for_interpreter: Union[int, None], | |||
give_gt_as_label: bool): | |||
super().__init__(interpreter._model, dataloader, phase, device, | |||
dtype, print_debug_info=print_debug_info) | |||
self._interpreter = interpreter | |||
self._label_for_interpreter = label_for_interpreter | |||
self._give_gt_as_label = give_gt_as_label | |||
self._sample_labels = dataloader.get_samples_labels() | |||
@staticmethod | |||
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
return { | |||
name: o.detach() for name, o in output.items() | |||
} | |||
def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]: | |||
if self._give_gt_as_label: | |||
return model_input, self._detach_output(self._interpreter.interpret( | |||
torch.Tensor(self._sample_labels[ | |||
self._dataloader.get_current_batch_sample_indices()]), | |||
**model_input)) | |||
return model_input, self._detach_output(self._interpreter.interpret( | |||
self._label_for_interpreter, **model_input | |||
)) |
@@ -0,0 +1,52 @@ | |||
""" | |||
DeepLift Interpretation
""" | |||
from typing import Dict, Union | |||
import torch | |||
import torch.nn.functional as F | |||
from captum import attr | |||
from . import Interpreter, InterpretableModel, InterpretableWrapper | |||
class DeepLift(Interpreter): # pylint: disable=too-few-public-methods | |||
""" | |||
    Produces DeepLift attributions for the target class
""" | |||
def __init__(self, model: InterpretableModel): | |||
super().__init__(model) | |||
self._model_wrapper = InterpretableWrapper(model) | |||
self._interpreter = attr.DeepLift(self._model_wrapper) | |||
self.__B = None | |||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
''' | |||
interprets given input and target class | |||
''' | |||
B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||
if self.__B is None: | |||
self.__B = B | |||
        # pad smaller (e.g. final) batches up to the first batch size seen; the
        # attribution result is cropped back to B below
        if B < self.__B:
            padding = self.__B - B
            inputs = {
                k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding))
                if torch.is_tensor(v) and v.shape[0] == B else v
                for k, v in inputs.items()
            }
if torch.is_tensor(labels) and labels.shape[0] == B: | |||
labels = F.pad(labels, (0, padding)) | |||
labels = self._get_target(labels, **inputs) | |||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||
result = self._interpreter.attribute( | |||
separated_inputs.inputs, | |||
target=labels, | |||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||
result = result[:B] | |||
return { | |||
'default': result | |||
} # TODO: must return and support tuple |
@@ -0,0 +1,38 @@ | |||
from typing import Dict, Union | |||
import torch | |||
import torch.nn.functional as F | |||
from captum import attr | |||
from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||
class GradCam(Interpreter): | |||
def __init__(self, model: CamInterpretableModel): | |||
super().__init__(model) | |||
self.__model_wrapper = InterpretableWrapper(model) | |||
self.__interpreters = [ | |||
attr.LayerGradCam(self.__model_wrapper, conv) | |||
for conv in model.target_conv_layers | |||
] | |||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
inp_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:] | |||
labels = self._get_target(labels, **inputs) | |||
separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs) | |||
gradcams = [interpreter.attribute( | |||
separated_inputs.inputs, | |||
target=labels, | |||
additional_forward_args=separated_inputs.additional_inputs) | |||
for interpreter in self.__interpreters] | |||
gradcams = [ | |||
cam if torch.is_tensor(cam) else cam[0] | |||
for cam in gradcams | |||
] | |||
gradcams = torch.stack(gradcams).sum(dim=0) | |||
gradcams = F.interpolate(gradcams, size=inp_shape, mode='bilinear', align_corners=True) | |||
return { | |||
'default': gradcams, | |||
} # TODO: must return and support tuple |
@@ -0,0 +1,35 @@ | |||
""" | |||
Guided Backprop Interpretation | |||
""" | |||
from typing import Dict, Union | |||
import torch | |||
from captum import attr | |||
from . import Interpreter, InterpretableModel, InterpretableWrapper | |||
class GuidedBackprop(Interpreter): # pylint: disable=too-few-public-methods | |||
""" | |||
    Produces guided backpropagation attributions
""" | |||
def __init__(self, model: InterpretableModel): | |||
super().__init__(model) | |||
self._model_wrapper = InterpretableWrapper(model) | |||
self._interpreter = attr.GuidedBackprop(self._model_wrapper) | |||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
''' | |||
interprets given input and target class | |||
''' | |||
labels = self._get_target(labels, **inputs) | |||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||
return { | |||
'default': self._interpreter.attribute( | |||
separated_inputs.inputs, | |||
target=labels, | |||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||
} # TODO: must return and support tuple |
@@ -0,0 +1,45 @@ | |||
""" | |||
Guided Grad Cam Interpretation | |||
""" | |||
from typing import Dict, Union | |||
import torch | |||
from captum import attr | |||
from . import Interpreter, CamInterpretableModel, InterpretableWrapper | |||
class GuidedGradCam(Interpreter): | |||
""" Produces class activation map """ | |||
def __init__(self, model: CamInterpretableModel): | |||
super().__init__(model) | |||
self._model_wrapper = InterpretableWrapper(model) | |||
self._interpreters = [ | |||
attr.GuidedGradCam(self._model_wrapper, conv) | |||
for conv in model.target_conv_layers | |||
] | |||
def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
""" Interprets given input and target class | |||
Args: | |||
            labels (Union[int, torch.Tensor]): target class
**inputs (torch.Tensor): model inputs | |||
Returns: | |||
Dict[str, torch.Tensor]: Interpretation results | |||
""" | |||
labels = self._get_target(labels, **inputs) | |||
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) | |||
gradcams = [interpreter.attribute( | |||
separated_inputs.inputs, | |||
target=labels, | |||
additional_forward_args=separated_inputs.additional_inputs)[0] | |||
for interpreter in self._interpreters] | |||
gradcams = torch.stack(gradcams).sum(dim=0) | |||
return { | |||
'default': gradcams, | |||
} # TODO: must return and support tuple |
@@ -0,0 +1,55 @@ | |||
from typing import Callable, Type | |||
import torch | |||
from ..models.model import ModelIO | |||
from ..utils.hooker import Hook, Hooker | |||
from ..data.data_loader import DataLoader | |||
from ..data.dataloader_context import DataloaderContext | |||
from .interpreter import Interpreter | |||
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter | |||
class ImagenetPredictionInterpreter(Interpreter): | |||
def __init__(self, model: AttentionInterpretableModel, k: int, base_interpreter: Type[AttentionInterpreter]): | |||
super().__init__(model) | |||
self.__base_interpreter = base_interpreter(model) | |||
self.__k = k | |||
self.__topk_classes: torch.Tensor = None | |||
self.__base_interpreter._hooker = Hooker( | |||
(model, self._generate_prediction_hook()), | |||
*[(attention_layer[-2], hook) | |||
for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs]) | |||
@staticmethod | |||
def standard_factory(k: int, base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']: | |||
return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter) | |||
def _generate_prediction_hook(self) -> Hook: | |||
def hook(_, __, output: ModelIO): | |||
topk = output['categorical_probability'].detach()\ | |||
.topk(self.__k, dim=-1) | |||
dataloader: DataLoader = DataloaderContext.instance.dataloader | |||
sample_names = dataloader.get_current_batch_samples_names() | |||
for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values): | |||
print(f'Top classes of ' | |||
f'{name}: ' | |||
f'{top_class.detach().cpu().numpy().tolist()} - ' | |||
f'{top_prob.cpu().numpy().tolist()}', flush=True) | |||
self.__topk_classes = topk\ | |||
.indices\ | |||
.flatten()\ | |||
.cpu() | |||
return hook | |||
def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor: | |||
batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k) | |||
attention = attention[batch_indices, self.__topk_classes] | |||
attention = attention.reshape(-1, self.__k, *attention.shape[-2:]) | |||
return self.__base_interpreter._process_attention(result, attention) | |||
def interpret(self, labels, **inputs): | |||
return self.__base_interpreter.interpret(labels, **inputs) |
@@ -0,0 +1,53 @@ | |||
""" | |||
An Interpretable Pytorch Model | |||
""" | |||
from abc import abstractmethod, ABC | |||
from typing import Iterable, List | |||
import torch | |||
from torch import nn | |||
from ..models.model import Model | |||
class InterpretableModel(Model, ABC): | |||
""" | |||
An Interpretable Pytorch Model | |||
""" | |||
@property | |||
@abstractmethod | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
""" | |||
Returns: | |||
            The ordered names of the input placeholders to be interpreted
""" | |||
@abstractmethod | |||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||
""" | |||
        A method to get the probabilities assigned to all the classes in the
        model's forward pass, with shape (B, C), where B is the batch size and
        C is the number of classes
Args: | |||
*inputs: Inputs to the model | |||
**kwargs: Additional arguments | |||
Returns: | |||
Tensor of categorical probabilities | |||
""" | |||
class CamInterpretableModel(InterpretableModel, ABC): | |||
""" | |||
A model interpretable by gradcam | |||
""" | |||
@property | |||
@abstractmethod | |||
def target_conv_layers(self) -> List[nn.Module]: | |||
""" | |||
Returns: | |||
The convolutional layers to be interpreted. The result of | |||
the interpretation will be sum of the grad-cam of these layers | |||
""" |
@@ -0,0 +1,69 @@ | |||
""" | |||
An adapter wrapper to adapt our models to Captum interpreters | |||
""" | |||
import inspect | |||
from typing import Iterable, Dict, Tuple | |||
from dataclasses import dataclass | |||
from itertools import chain | |||
import torch | |||
from torch import nn | |||
from . import InterpretableModel | |||
@dataclass | |||
class InterpreterArgs: | |||
inputs: Tuple[torch.Tensor, ...] | |||
additional_inputs: Tuple[torch.Tensor, ...] | |||
class InterpretableWrapper(nn.Module): | |||
""" | |||
An adapter wrapper to adapt our models to Captum interpreters | |||
""" | |||
def __init__(self, model: InterpretableModel): | |||
super().__init__() | |||
self._model = model | |||
@property | |||
def _additional_names(self) -> Iterable[str]: | |||
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted | |||
signature = inspect.signature(self._model.forward) | |||
return [name for name in signature.parameters | |||
if name not in to_be_interpreted_names | |||
and signature.parameters[name].default is not None] | |||
def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
""" | |||
Converts an ordered *args to **kwargs | |||
""" | |||
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted | |||
additional_names = self._additional_names | |||
inputs = {} | |||
for i, name in enumerate(chain(to_be_interpreted_names, additional_names)): | |||
inputs[name] = args[i] | |||
return inputs | |||
def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs: | |||
""" | |||
Converts a **kwargs to ordered *args | |||
""" | |||
return InterpreterArgs( | |||
tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted | |||
if name in kwargs), | |||
tuple(kwargs[name] for name in self._additional_names | |||
if name in kwargs) | |||
) | |||
def forward(self, *args: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Forwards the model | |||
""" | |||
inputs = self.convert_inputs_to_kwargs(*args) | |||
return self._model.get_categorical_probabilities(**inputs) |
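# --- Usage sketch (illustrative comment; `my_model` is hypothetical) ---
# Captum interpreters expect positional tensors while our models expect keyword
# placeholders; the wrapper converts between the two representations:
#
#     wrapper = InterpretableWrapper(my_model)
#     args = wrapper.convert_inputs_to_args(x=x_batch, mask=mask_batch)
#     probs = wrapper(*args.inputs, *args.additional_inputs)
#     # equivalent to my_model.get_categorical_probabilities(x=x_batch, mask=mask_batch)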
@@ -0,0 +1,42 @@ | |||
from typing import Dict, Union, Tuple | |||
import torch | |||
from .interpreter import Interpreter | |||
from ..data.dataflow import DataFlow | |||
from ..data.data_loader import DataLoader | |||
class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]): | |||
def __init__(self, interpreter: Interpreter, | |||
dataloader: DataLoader, | |||
device: torch.device, | |||
print_debug_info: bool, | |||
label_for_interpreter: Union[int, None], | |||
give_gt_as_label: bool): | |||
super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info) | |||
self._interpreter = interpreter | |||
self._label_for_interpreter = label_for_interpreter | |||
self._give_gt_as_label = give_gt_as_label | |||
self._sample_labels = dataloader.get_samples_labels() | |||
@staticmethod | |||
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
if output is None: | |||
return None | |||
return { | |||
name: o.detach() for name, o in output.items() | |||
} | |||
def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: | |||
if self._give_gt_as_label: | |||
return model_input, self._detach_output(self._interpreter.interpret( | |||
torch.Tensor(self._sample_labels[ | |||
self._dataloader.get_current_batch_sample_indices()]), | |||
**model_input)) | |||
return model_input, self._detach_output(self._interpreter.interpret( | |||
self._label_for_interpreter, **model_input | |||
)) |
@@ -0,0 +1,123 @@ | |||
from typing import TYPE_CHECKING | |||
import traceback | |||
from time import time | |||
from ..data.data_loader import DataLoader | |||
from .interpretable import InterpretableModel | |||
from .interpreter_maker import create_interpreter | |||
from .interpreter import Interpreter | |||
from .interpretation_dataflow import InterpretationDataFlow | |||
from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class InterpretingEvalRunner: | |||
def __init__(self, conf: 'BaseConfig', model: InterpretableModel): | |||
self.conf = conf | |||
self.model = model | |||
def evaluate(self): | |||
interpreter = create_interpreter( | |||
self.conf.interpretation_method, | |||
self.model) | |||
for test_group_info in self.conf.samples_dir.split(','): | |||
try: | |||
print('>> Evaluating interpretations for %s' % test_group_info, flush=True) | |||
t1 = time() | |||
test_data_loader = \ | |||
DataLoader(self.conf, test_group_info, 'test') | |||
evaluator = BinaryInterpretationEvaluator2D( | |||
test_data_loader.get_number_of_samples(), | |||
self.conf | |||
) | |||
self._eval_interpretations(test_data_loader, interpreter, evaluator) | |||
evaluator.print_summaries() | |||
print('Evaluating Interpretations was done in %.2f secs.' % (time() - t1,), flush=True) | |||
except Exception as e: | |||
print('Problem in %s' % test_group_info, flush=True) | |||
track = traceback.format_exc() | |||
print(track, flush=True) | |||
def _eval_interpretations(self, | |||
data_loader: DataLoader, interpreter: Interpreter, | |||
evaluator: BinaryInterpretationEvaluator2D) -> None: | |||
""" CAUTION: Resets the data_loader, | |||
Iterates over the samples (as much as and how data_loader specifies), | |||
finds the interpretations based on the specified interpretation method | |||
and saves the results in the received save_dir!""" | |||
if not isinstance(self.model, InterpretableModel): | |||
raise Exception('Model has not implemented the requirements of the InterpretableModel') | |||
# Running the model in evaluation mode | |||
self.model.eval() | |||
# initiating variables for running evaluations | |||
evaluator.reset() | |||
label_for_interpreter = self.conf.class_label_for_interpretation | |||
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt | |||
dataflow = InterpretationDataFlow(interpreter, | |||
data_loader, | |||
self.conf.device, | |||
False, | |||
label_for_interpreter, | |||
give_gt_as_label) | |||
n_interpreted_samples = 0 | |||
max_n_samples = self.conf.n_interpretation_samples | |||
with dataflow: | |||
for model_input, interpretation_output in dataflow.iterate(): | |||
n_batch = model_input[ | |||
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||
n_samples_to_save = n_batch | |||
if max_n_samples is not None: | |||
n_samples_to_save = min( | |||
max_n_samples - n_interpreted_samples, | |||
n_batch) | |||
model_outputs = self.model(**model_input) | |||
model_preds = model_outputs[ | |||
self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy() | |||
if self.conf.interpretation_tag_to_evaluate: | |||
interpretations = interpretation_output[ | |||
self.conf.interpretation_tag_to_evaluate | |||
].detach() | |||
else: | |||
interpretations = list(interpretation_output.values())[0].detach() | |||
interpretations = interpretations.detach().cpu().numpy() | |||
evaluator.update_summaries( | |||
interpretations[:n_samples_to_save, 0], | |||
data_loader.get_current_batch_samples_interpretations()[:n_samples_to_save, 0].cpu().numpy(), | |||
model_preds[:n_samples_to_save], | |||
data_loader.get_current_batch_samples_labels()[:n_samples_to_save], | |||
data_loader.get_current_batch_sample_indices()[:n_samples_to_save], | |||
interpreter | |||
) | |||
del interpretation_output | |||
del model_outputs | |||
# if the number of interpreted samples has reached the limit, break | |||
n_interpreted_samples += n_batch | |||
if max_n_samples is not None and n_interpreted_samples >= max_n_samples: | |||
break |
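# --- Usage sketch (illustrative comment) ---
#     InterpretingEvalRunner(conf, model).evaluate()
#     # one evaluation round per comma-separated group in conf.samples_dir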
@@ -0,0 +1,67 @@ | |||
""" | |||
An abstract class to apply different interpretation algorithms on torch models. | |||
""" | |||
from abc import abstractmethod, ABC | |||
from typing import Dict, List, Union | |||
import torch | |||
from torch import nn | |||
import numpy as np | |||
from . import InterpretableModel | |||
class Interpreter(ABC): # pylint: disable=too-few-public-methods | |||
""" | |||
An abstract class to apply different interpretation algorithms on torch models. | |||
""" | |||
def __init__(self, model: InterpretableModel): | |||
self._model = model | |||
def _get_all_leaf_children(self, module: nn.Module = None) -> List[nn.Module]: | |||
""" | |||
Gets all leaf modules recursively | |||
""" | |||
if module is None: | |||
module = self._model | |||
children: List[nn.Module] = [] | |||
for child in module.children(): | |||
if len(list(child.children())) == 0: | |||
children.append(child) | |||
else: | |||
children += self._get_all_leaf_children(child) | |||
return children | |||
@staticmethod | |||
def _get_one_hot_output(output: torch.Tensor, | |||
target: Union[int, torch.Tensor, np.ndarray] = None) -> torch.Tensor: | |||
batch_size = output.shape[0] | |||
if target is None: | |||
target = output.argmax(dim=1) | |||
one_hot_output = torch.zeros_like(output) | |||
one_hot_output[torch.arange(batch_size), target] = 1 | |||
return one_hot_output | |||
def _get_target(self, target: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Union[int, torch.Tensor]: | |||
if target is not None: | |||
return target | |||
return self._model.get_categorical_probabilities(**inputs).argmax(dim=1) | |||
@abstractmethod | |||
def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
        """
        An abstract method to interpret the given inputs with respect to the
        given target class(es). Returns a dictionary of interpretation results.
        """
def dynamic_threshold(self, x: np.ndarray) -> np.ndarray: | |||
""" | |||
A function to dynamically threshold one sample. | |||
""" | |||
return (x - (x.mean() + x.std())).clip(min=0) |
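# --- Runnable sketch (illustration only, not part of the original class) ---
# dynamic_threshold keeps only the part of a map exceeding mean + std; the
# guarded snippet below just reproduces that formula on a random map.
if __name__ == '__main__':
    _x = np.random.rand(8, 8)
    _kept = (_x - (_x.mean() + _x.std())).clip(min=0)
    print('fraction of pixels kept:', float((_kept > 0).mean()))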
@@ -0,0 +1,63 @@ | |||
""" | |||
Creates an interpreter object and returns it | |||
""" | |||
from typing import Callable | |||
from . import InterpretableModel, Interpreter | |||
class InterpretationType: | |||
GuidedBackprop = 'GuidedBackprop' | |||
GuidedGradCam = 'GuidedGradCam' | |||
GradCam = 'GradCam' | |||
RelCam = 'RelCam' | |||
DeepLift = 'DeepLift' | |||
Attention = 'Attention' | |||
AttentionSmooth = 'AttentionSmooth' | |||
AttentionSum = 'AttentionSum' | |||
ImagenetAttention = 'ImagenetAttention' | |||
GT = 'GT' | |||
InterpreterMaker = Callable[[InterpretableModel], Interpreter] | |||
def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter: | |||
""" | |||
Creates an interpreter object and returns it | |||
:param interpretation_method: The method name | |||
:param model: The model to be interpreted | |||
:return: an interpreter object that can run the model and interpret its output | |||
""" | |||
if interpretation_method == InterpretationType.GuidedBackprop: | |||
from .guided_backprop import GuidedBackprop as interpreter_maker | |||
elif interpretation_method == InterpretationType.GuidedGradCam: | |||
from .guided_gradcam import GuidedGradCam as interpreter_maker | |||
elif interpretation_method == InterpretationType.GradCam: | |||
from .gradcam import GradCam as interpreter_maker | |||
elif interpretation_method == InterpretationType.RelCam: | |||
from .relcam import RelCamInterpreter as interpreter_maker | |||
elif interpretation_method == InterpretationType.DeepLift: | |||
from .deep_lift import DeepLift as interpreter_maker | |||
elif interpretation_method == InterpretationType.Attention: | |||
from .attention_interpreter import AttentionInterpreter as interpreter_maker | |||
elif interpretation_method == InterpretationType.AttentionSmooth: | |||
from .attention_interpreter_smooth_integrator import AttentionInterpreterSmoothIntegrator as interpreter_maker | |||
elif interpretation_method == InterpretationType.AttentionSum: | |||
from .attention_sum_interpreter import AttentionSumInterpreter as interpreter_maker | |||
elif interpretation_method == InterpretationType.ImagenetAttention: | |||
from .imagenet_attention_interpreter import ImagenetPredictionInterpreter as interpreter_maker | |||
else: | |||
        raise ValueError(f'Unknown interpretation method: {interpretation_method}')
return interpreter_maker(model) | |||
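# --- Usage sketch (illustrative comment; `my_model` is hypothetical) ---
#     interpreter = create_interpreter(InterpretationType.GradCam, my_model)
#     result = interpreter.interpret(None, x=batch)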
@@ -0,0 +1,162 @@ | |||
from typing import Dict, TYPE_CHECKING | |||
from os import makedirs, path | |||
import traceback | |||
from time import time | |||
import warnings | |||
import imageio | |||
import torch | |||
import numpy as np | |||
from ..data.data_loader import DataLoader | |||
from .interpretable import InterpretableModel | |||
from .interpreter_maker import create_interpreter | |||
from .interpreter import Interpreter | |||
from .interpretation_dataflow import InterpretationDataFlow | |||
from .utils import overlay_interpretation | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class InterpretingRunner: | |||
def __init__(self, conf: 'BaseConfig', model: InterpretableModel): | |||
self.conf = conf | |||
self.model = model | |||
def interpret(self): | |||
interpreter = create_interpreter( | |||
self.conf.interpretation_method, | |||
self.model) | |||
for test_group_info in self.conf.samples_dir.split(','): | |||
try: | |||
print('>> Finding interpretations for %s' % test_group_info, flush=True) | |||
t1 = time() | |||
labels_to_use = self.conf.mapped_labels_to_use | |||
if labels_to_use is None: | |||
labels_to_use = 'All' | |||
else: | |||
labels_to_use = ','.join([str(x) for x in self.conf.mapped_labels_to_use]) | |||
report_dir = self.conf.get_sample_group_specific_report_dir( | |||
test_group_info, | |||
extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}') | |||
makedirs(report_dir, exist_ok=True) | |||
# writing the whole config | |||
                with open(path.join(report_dir, 'conf_info.txt'), 'w') as f:
                    f.write(str(self.conf) + '\n')
test_data_loader = \ | |||
DataLoader(self.conf, test_group_info, 'test') | |||
self._interpret(report_dir, test_data_loader, interpreter) | |||
print('Interpretations were saved in %s' % report_dir) | |||
print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True) | |||
except Exception as e: | |||
print('Problem in %s' % test_group_info, flush=True) | |||
track = traceback.format_exc() | |||
print(track, flush=True) | |||
def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None: | |||
""" Iterates over the samples (as much as and how data_loader specifies), | |||
finds the interpretations based on the specified interpretation method | |||
and saves the results in the received save_dir!""" | |||
makedirs(save_dir, exist_ok=True) | |||
if not isinstance(self.model, InterpretableModel): | |||
raise Exception('Model has not implemented the requirements of the InterpretableModel') | |||
# Running the model in evaluation mode | |||
self.model.eval() | |||
# initiating variables for running evaluations | |||
label_for_interpreter = self.conf.class_label_for_interpretation | |||
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt | |||
dataflow = InterpretationDataFlow(interpreter, | |||
data_loader, | |||
self.conf.device, | |||
False, | |||
label_for_interpreter, | |||
give_gt_as_label) | |||
n_interpreted_samples = 0 | |||
max_n_samples = self.conf.n_interpretation_samples | |||
with dataflow: | |||
for model_input, interpretation_output in dataflow.iterate(): | |||
n_batch = model_input[ | |||
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] | |||
n_samples_to_save = n_batch | |||
if max_n_samples is not None: | |||
n_samples_to_save = min( | |||
max_n_samples - n_interpreted_samples, | |||
n_batch) | |||
self._save_interpretations_of_batch( | |||
model_input[self.model.ordered_placeholder_names_to_be_interpreted[0]], | |||
interpretation_output, save_dir, data_loader, n_samples_to_save, | |||
interpreter) | |||
del interpretation_output | |||
# if the number of interpreted samples has reached the limit, break | |||
n_interpreted_samples += n_batch | |||
if max_n_samples is not None and n_interpreted_samples >= max_n_samples: | |||
break | |||
def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor], | |||
save_dir: str, data_loader: DataLoader, | |||
n_samples_to_save: int, interpreter: Interpreter) -> None: | |||
""" | |||
Receives the output of the interpreter and saves the interpretations in the received directory | |||
in a file named as the sample name. | |||
The behaviour can be overwritten in children if extra stuff are required. | |||
:param interpreter_output: | |||
:param save_dir: | |||
:return: None | |||
""" | |||
batch_samples_names = data_loader.get_current_batch_samples_names() | |||
save_dirs = [self.conf.get_save_dir_for_sample( | |||
save_dir, batch_samples_names[bi].replace('../', '')) | |||
for bi in range(len(batch_samples_names))] | |||
for sd in save_dirs: | |||
makedirs(path.dirname(sd), exist_ok=True) | |||
interpreter_output = {name: output.cpu().numpy() for name, output in interpreter_output.items()} | |||
""" Make inputs grayscale """ | |||
model_input = model_input.mean(dim=1).cpu().numpy() | |||
def save(bis): | |||
for bi in bis: | |||
for name, output in interpreter_output.items(): | |||
output = output[bi] | |||
filename = f'{save_dirs[bi]}_{name}' | |||
if self.conf.dynamic_threshold: | |||
output = interpreter.dynamic_threshold(output) | |||
if not self.conf.skip_raw: | |||
np.save(filename, output) | |||
else: | |||
warnings.warn('Skipping raw interpretations') | |||
if not self.conf.skip_overlay: | |||
if output.shape[0] > 3: | |||
output = output[:3] | |||
overlayed = overlay_interpretation(model_input[bi][np.newaxis, ...], output, self.conf) | |||
imageio.imwrite(f'{filename}_overlay.png', overlayed) | |||
else: | |||
warnings.warn('Skipping overlayed interpretations') | |||
save(np.arange(min(len(model_input), n_samples_to_save))) |
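# --- Usage sketch (illustrative comment) ---
#     InterpretingRunner(conf, model).interpret()
#     # saves raw .npy maps and overlay .png files per sample group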
@@ -0,0 +1 @@ | |||
from .interpreter import RelCamInterpreter |
@@ -0,0 +1,44 @@ | |||
from typing import Dict, Union | |||
import warnings | |||
import numpy as np | |||
import torch.nn.functional as F | |||
import torch | |||
from ..interpreter import Interpreter | |||
from ..interpretable import CamInterpretableModel | |||
from .relprop import RPProvider | |||
class RelCamInterpreter(Interpreter): | |||
def __init__(self, model: CamInterpretableModel): | |||
super().__init__(model) | |||
self.__targets = model.target_conv_layers | |||
for name, module in model.named_modules(): | |||
if not RPProvider.propable(module): | |||
warnings.warn(f"Module {name} of type {type(module)} is not propable! Hope you know what you are doing!") | |||
continue | |||
RPProvider.create(module) | |||
@staticmethod | |||
def __normalize(x: torch.Tensor) -> torch.Tensor: | |||
fx = x.flatten(1) | |||
minx = fx.min(dim=1).values[:, None, None, None] | |||
maxx = fx.max(dim=1).values[:, None, None, None] | |||
return (x - minx) / (1e-6 + maxx - minx) | |||
def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
x_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape | |||
with RPProvider.capture(self.__targets) as Rs: | |||
z = self._model.get_categorical_probabilities(**inputs) | |||
one_hot_y = self._get_one_hot_output(z, labels) | |||
RPProvider.get(self._model)(one_hot_y) | |||
result = list(Rs.values()) | |||
relcam = torch.stack(result).sum(dim=0)\ | |||
if len(result) > 1\ | |||
else result[0] | |||
relcam = self.__normalize(relcam) | |||
relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear', align_corners=False) | |||
return { | |||
'default': relcam, | |||
} |
@@ -0,0 +1,21 @@ | |||
from torch import nn | |||
import torch | |||
class Add(nn.Module): | |||
def forward(self, inputs): | |||
return torch.add(*inputs) | |||
class Multiply(nn.Module): | |||
def forward(self, inputs): | |||
return torch.mul(*inputs) | |||
class Cat(nn.Module): | |||
def forward(self, inputs, dim): | |||
        self.dim = dim
return torch.cat(inputs, dim) |
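# --- Runnable sketch (illustration only) ---
# These wrappers exist so that plain tensor ops become nn.Modules, letting the
# relevance-propagation hooks in relcam.relprop treat them like any layer.
if __name__ == '__main__':
    a, b = torch.ones(2, 3), torch.full((2, 3), 2.0)
    print(Add()((a, b)).sum().item())       # 18.0
    print(Multiply()((a, b)).sum().item())  # 12.0
    print(Cat()((a, b), 0).shape)           # torch.Size([4, 3])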
@@ -0,0 +1,339 @@ | |||
from contextlib import ExitStack, contextmanager | |||
from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar | |||
from uuid import UUID, uuid4 | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
import torchvision | |||
from . import modules as M | |||
TModule = TypeVar('TModule', bound=nn.Module) | |||
TType = TypeVar('TType', bound=type) | |||
def safe_divide(a, b): | |||
return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type()) | |||
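# Note (explanatory comment): safe_divide(a, b) returns a / b where b != 0 and
# exactly 0 where b == 0; the 1e-9 term only guards the division itself, and
# the trailing b.ne(0) mask zeroes out those guarded entries.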
class classdict(Dict[type, TType], Generic[TType]): | |||
def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]: | |||
return next((r for r in __k.mro() if dict.__contains__(self, r)), None) | |||
def __contains__(self, __k: Type[TModule]) -> bool: | |||
return self.__nearest_available_resolution(__k) is not None | |||
def __getitem__(self, __k: Type[TModule]) -> TType: | |||
r = self.__nearest_available_resolution(__k) | |||
if r is None: | |||
raise KeyError(f"{__k} not found!") | |||
return super().__getitem__(r) | |||
class RPProvider: | |||
__props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict() | |||
__instances: Dict[TModule, 'RelProp[TModule]'] = {} | |||
__target_results: Dict[nn.Module, torch.Tensor] = {} | |||
@classmethod | |||
def register(cls, *module_types: Type[TModule]): | |||
def decorator(prop_cls): | |||
cls.__props.update({ | |||
module_type: prop_cls | |||
for module_type in module_types | |||
}) | |||
return prop_cls | |||
return decorator | |||
@classmethod | |||
def create(cls, module: TModule): | |||
cls.__instances[module] = prop = cls.__props[type(module)](module) | |||
return prop | |||
@classmethod | |||
def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None: | |||
r_weight = torch.mean(R, dim=(2, 3), keepdim=True) | |||
r_cam = prop.X * r_weight | |||
r_cam = torch.sum(r_cam, dim=1, keepdim=True) | |||
cls.__target_results[prop.module] = r_cam | |||
@classmethod | |||
@contextmanager | |||
def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]: | |||
with ExitStack() as stack: | |||
cls.__target_results.clear() | |||
            for prop in cls.__instances.values():
                stack.enter_context(prop.hook_module())
            for target in target_layers:
                stack.enter_context(cls.get(target).register_hook(cls.__hook))
yield cls.__target_results | |||
@classmethod | |||
def propable(cls, module: TModule) -> bool: | |||
return type(module) in cls.__props | |||
@classmethod | |||
def get(cls, module: TModule) -> 'RelProp[TModule]': | |||
return cls.__instances[module] | |||
@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid) | |||
class RelProp(Generic[TModule]): | |||
Hook = Callable[['RelProp', torch.Tensor], None] | |||
def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor): | |||
if type(input[0]) in (list, tuple): | |||
self.X = [] | |||
for i in input[0]: | |||
x = i.detach() | |||
x.requires_grad = True | |||
self.X.append(x) | |||
else: | |||
self.X = input[0].detach() | |||
self.X.requires_grad = True | |||
self.Y = output | |||
def __init__(self, module: TModule) -> None: | |||
self.module = module | |||
self.__hooks: Dict[UUID, RelProp.Hook] = {} | |||
@contextmanager | |||
def hook_module(self): | |||
handle = self.module.register_forward_hook(self.__forward_hook) | |||
try: | |||
yield | |||
finally: | |||
handle.remove() | |||
@contextmanager | |||
def register_hook(self, hook: 'RelProp.Hook'): | |||
uuid = uuid4() | |||
self.__hooks[uuid] = hook | |||
try: | |||
yield | |||
finally: | |||
self.__hooks.pop(uuid) | |||
def grad(self, Z, X, S): | |||
C = torch.autograd.grad(Z, X, S, retain_graph=True) | |||
return C | |||
def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
R = self.rel(R, alpha=alpha) | |||
        for hook in self.__hooks.values():
            hook(self, R)
return R | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
return R | |||
@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize) | |||
class RelPropSimple(RelProp[TModule]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
Z = self.module.forward(self.X) | |||
S = safe_divide(R, Z) | |||
C = self.grad(Z, self.X, S) | |||
        if not torch.is_tensor(self.X):
            outputs = [self.X[0] * C[0], self.X[1] * C[1]]
        else:
            outputs = self.X * C[0]
return outputs | |||
@RPProvider.register(nn.Flatten) | |||
class RelPropFlatten(RelProp[TModule]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
return R.reshape_as(self.X) | |||
@RPProvider.register(nn.AdaptiveAvgPool2d) | |||
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
px = torch.clamp(self.X, min=0) | |||
def f(x1): | |||
Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size) | |||
S1 = safe_divide(R, Z1) | |||
C1 = x1 * self.grad(Z1, x1, S1)[0] | |||
return C1 | |||
activator_relevances = f(px) | |||
out = activator_relevances | |||
return out | |||
@RPProvider.register(nn.ZeroPad2d) | |||
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
Z = self.module.forward(self.X) | |||
S = safe_divide(R, Z) | |||
C = self.grad(Z, self.X, S) | |||
outputs = self.X * C[0] | |||
return outputs | |||
@RPProvider.register(M.Multiply) | |||
class MultiplyRelProp(RelProp[M.Multiply]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
x0 = torch.clamp(self.X[0], min=0) | |||
x1 = torch.clamp(self.X[1], min=0) | |||
x = [x0, x1] | |||
Z = self.module.forward(x) | |||
S = safe_divide(R, Z) | |||
C = self.grad(Z, x, S) | |||
outputs = [] | |||
outputs.append(x[0] * C[0]) | |||
outputs.append(x[1] * C[1]) | |||
return outputs | |||
@RPProvider.register(M.Cat) | |||
class CatRelProp(RelProp[M.Cat]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
Z = self.module.forward(self.X, self.module.dim) | |||
S = safe_divide(R, Z) | |||
C = self.grad(Z, self.X, S) | |||
outputs = [] | |||
for x, c in zip(self.X, C): | |||
outputs.append(x * c) | |||
return outputs | |||
@RPProvider.register(nn.Sequential) | |||
class SequentialRelProp(RelProp[nn.Sequential]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
for m in reversed(self.module): | |||
R = RPProvider.get(m)(R, alpha=alpha) | |||
return R | |||
@RPProvider.register(nn.BatchNorm2d) | |||
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
        weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.module.eps).pow(0.5))
        Z = self.X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        R = self.X * Ca
return R | |||
@RPProvider.register(nn.Linear) | |||
class LinearRelProp(RelProp[nn.Linear]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
beta = alpha - 1 | |||
pw = torch.clamp(self.module.weight, min=0) | |||
nw = torch.clamp(self.module.weight, max=0) | |||
px = torch.clamp(self.X, min=0) | |||
nx = torch.clamp(self.X, max=0) | |||
def f(w1, w2, x1, x2): | |||
Z1 = F.linear(x1, w1) | |||
Z2 = F.linear(x2, w2) | |||
Z = Z1 + Z2 | |||
S = safe_divide(R, Z) | |||
C1 = x1 * self.grad(Z1, x1, S)[0] | |||
C2 = x2 * self.grad(Z2, x2, S)[0] | |||
return C1 + C2 | |||
activator_relevances = f(pw, nw, px, nx) | |||
inhibitor_relevances = f(nw, pw, px, nx) | |||
out = alpha * activator_relevances - beta*inhibitor_relevances | |||
return out | |||
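# --- Note on the alpha-beta rule above (explanatory comment) ---
# With x^+/x^- and w^+/w^- the positive/negative parts of inputs and weights,
# f(pw, nw, px, nx) redistributes R through the same-sign (excitatory) products
# and f(nw, pw, px, nx) through the opposite-sign (inhibitory) ones, each
# normalized by its own pre-activation via safe_divide. Since beta = alpha - 1,
# the combination alpha * activator - beta * inhibitor is designed to conserve
# relevance; the default alpha = 1 reduces to the LRP z+ rule (no inhibitory term).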
@RPProvider.register(nn.Conv2d) | |||
class Conv2dRelProp(RelProp[nn.Conv2d]): | |||
def gradprop2(self, DY, weight): | |||
Z = self.module.forward(self.X) | |||
output_padding = self.X.size()[2] - ( | |||
(Z.size()[2] - 1) * self.module.stride[0] - 2 * self.module.padding[0] + self.module.kernel_size[0]) | |||
return F.conv_transpose2d(DY, weight, stride=self.module.stride, padding=self.module.padding, output_padding=output_padding) | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
if self.X.shape[1] == 3: | |||
pw = torch.clamp(self.module.weight, min=0) | |||
nw = torch.clamp(self.module.weight, max=0) | |||
X = self.X | |||
L = self.X * 0 + \ | |||
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, | |||
keepdim=True)[0] | |||
H = self.X * 0 + \ | |||
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, | |||
keepdim=True)[0] | |||
Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride, padding=self.module.padding) - \ | |||
torch.conv2d(L, pw, bias=None, stride=self.module.stride, padding=self.module.padding) - \ | |||
torch.conv2d(H, nw, bias=None, stride=self.module.stride, | |||
padding=self.module.padding) + 1e-9 | |||
S = R / Za | |||
C = X * self.gradprop2(S, self.module.weight) - L * \ | |||
self.gradprop2(S, pw) - H * self.gradprop2(S, nw) | |||
R = C | |||
else: | |||
beta = alpha - 1 | |||
pw = torch.clamp(self.module.weight, min=0) | |||
nw = torch.clamp(self.module.weight, max=0) | |||
px = torch.clamp(self.X, min=0) | |||
nx = torch.clamp(self.X, max=0) | |||
def f(w1, w2, x1, x2): | |||
Z1 = F.conv2d(x1, w1, bias=self.module.bias, stride=self.module.stride, | |||
padding=self.module.padding, groups=self.module.groups) | |||
Z2 = F.conv2d(x2, w2, bias=self.module.bias, stride=self.module.stride, | |||
padding=self.module.padding, groups=self.module.groups) | |||
Z = Z1 + Z2 | |||
S = safe_divide(R, Z) | |||
C1 = x1 * self.grad(Z1, x1, S)[0] | |||
C2 = x2 * self.grad(Z2, x2, S)[0] | |||
return C1 + C2 | |||
activator_relevances = f(pw, nw, px, nx) | |||
inhibitor_relevances = f(nw, pw, px, nx) | |||
R = alpha * activator_relevances - beta * inhibitor_relevances | |||
return R | |||
@RPProvider.register(nn.ConvTranspose2d) | |||
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
pw = torch.clamp(self.module.weight, min=0) | |||
px = torch.clamp(self.X, min=0) | |||
def f(w1, x1): | |||
Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride, | |||
padding=self.module.padding, output_padding=self.module.output_padding) | |||
S1 = safe_divide(R, Z1) | |||
C1 = x1 * self.grad(Z1, x1, S1)[0] | |||
return C1 | |||
activator_relevances = f(pw, px) | |||
R = activator_relevances | |||
return R |
@@ -0,0 +1,102 @@ | |||
from typing import TYPE_CHECKING, Optional | |||
import numpy as np | |||
import cv2 | |||
from skimage.morphology import dilation, erosion | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
from . import Interpreter | |||
def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray: | |||
""" | |||
image (np.ndarray): The image to overlay interpretations over. Of shape C W H. The numbers are assumed to be in range [0, 1] | |||
int_map (np.ndarray): The map containing interpretation scores. | |||
config ('Config'): An object containing required information about the configurations. | |||
Returns: | |||
(np.ndarray): The overlayed image | |||
""" | |||
int_map = process_interpretations(int_map, config) # C W H | |||
int_map = np.moveaxis(int_map, 0, -1) # W H C | |||
# if there are more than 3 channels, use the first 3! | |||
if int_map.shape[-1] > 3: | |||
int_map = int_map[..., :3] | |||
# norm by max to make sure! | |||
int_map = int_map * 1.0 / (1e-6 + np.amax(int_map)) | |||
alpha = 0.5 | |||
if np.amin(image) < 0: | |||
image += 0.5 | |||
image = np.moveaxis(image, 0, -1) # W H C | |||
image = image * 255 | |||
o_color_by_c = { | |||
1: [255, 0, 0], | |||
2: [255, 255, 0], | |||
3: [255, 255, 255], | |||
} | |||
    o_color = o_color_by_c.get(int_map.shape[-1])
    if o_color is None:
        raise Exception("Trying to overlay more than 3 maps on the image.")
    o_color = np.asarray(o_color)
# if int map has two channels, append one zeros to the end! | |||
if int_map.shape[-1] == 2: | |||
int_map = np.concatenate((int_map, np.zeros((int_map.shape[-3], int_map.shape[-2], 1))), axis=-1) | |||
overlayed = (1 - int_map) * image + \ | |||
int_map * (1 - alpha) * image + \ | |||
int_map * alpha * o_color[np.newaxis, np.newaxis, :] | |||
overlayed = np.round(overlayed).astype(np.uint8) | |||
return overlayed | |||
def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray: | |||
""" | |||
int_map (np.ndarray): The map of interpretation scores. Is assumed to have a shape of C ... | |||
config ('Config'): An object containing the information about the required configurations | |||
Returns: | |||
np.ndarray: Processed interpretation maps, with numbers in the range of [0, 1] | |||
""" | |||
if interpreter and config.dynamic_threshold: | |||
int_map = interpreter.dynamic_threshold(int_map) | |||
if not config.global_threshold: | |||
# Treating map as negative=reverse effect; discarding negatives | |||
int_map[int_map < 0] = 0 | |||
qval = np.amax(int_map) | |||
if qval > 0: | |||
int_map[int_map > qval] = qval | |||
int_map = int_map / qval | |||
# applying cut threshold! | |||
int_map[int_map <= config.cut_threshold] = config.cut_threshold | |||
int_map -= config.cut_threshold | |||
kw = int(np.round(0.1 * int_map.shape[-1])) | |||
mv = np.amax(int_map) | |||
int_map /= (mv + 1e-6) | |||
# dilation and erosion to make continuous objects and erase noises | |||
for c in range(int_map.shape[0]): | |||
int_map[c] = dilation(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3))) | |||
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw))) | |||
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))) | |||
int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0 | |||
return int_map # [0, 1] | |||
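# --- Runnable sketch (illustration only; `_DemoConf` is a minimal stand-in) ---
# Demonstrates the expected shapes: one interpretation channel overlaid in red
# on a 3-channel image in [0, 1]; only the config fields read above are stubbed.
if __name__ == '__main__':
    class _DemoConf:
        dynamic_threshold = False
        global_threshold = False
        cut_threshold = 0.1
    _image = np.random.rand(3, 64, 64)
    _scores = np.random.rand(1, 64, 64)
    print(overlay_interpretation(_image, _scores, _DemoConf()).shape)  # (64, 64, 3)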
@@ -0,0 +1,100 @@ | |||
from typing import List, TYPE_CHECKING, Dict | |||
import numpy as np | |||
import torch | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class BinaryEvaluator(Evaluator): | |||
def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||
super(BinaryEvaluator, self).__init__(model, data_loader, conf) | |||
# for summarized information | |||
self.tp: int = 0 | |||
self.tn: int = 0 | |||
self.fp: int = 0 | |||
self.fn: int = 0 | |||
self.avg_loss: float = 0 | |||
self.n_received_samples: int = 0 | |||
def reset(self): | |||
self.tp = 0 | |||
self.tn = 0 | |||
self.fp = 0 | |||
self.fn = 0 | |||
self.avg_loss = 0 | |||
self.n_received_samples = 0 | |||
@property | |||
def _prob_key(self) -> str: | |||
return 'positive_class_probability' | |||
@property | |||
def _loss_key(self) -> str: | |||
return 'loss' | |||
def _get_current_batch_gt(self) -> np.ndarray: | |||
return self.data_loader.get_current_batch_samples_labels() | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||
assert self._prob_key in model_output, \ | |||
f"model's output must contain {self._prob_key}" | |||
gt = self._get_current_batch_gt() | |||
prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int) | |||
ntp = int(np.sum(np.logical_and(gt == 1, prediction == 1))) | |||
ntn = int(np.sum(np.logical_and(gt == 0, prediction == 0))) | |||
nfp = int(np.sum(np.logical_and(gt == 0, prediction == 1))) | |||
nfn = int(np.sum(np.logical_and(gt == 1, prediction == 0))) | |||
self.tp += ntp | |||
self.tn += ntn | |||
self.fp += nfp | |||
self.fn += nfn | |||
new_n = self.n_received_samples + len(gt) | |||
self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \ | |||
model_output.get(self._loss_key, 0.0) * (float(len(gt)) / new_n) | |||
self.n_received_samples = new_n | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N'] | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
p = self.tp + self.fn | |||
n = self.tn + self.fp | |||
if p + n > 0: | |||
accuracy = (self.tp + self.tn) * 100.0 / (n + p) | |||
else: | |||
accuracy = -1 | |||
if p > 0: | |||
sensitivity = 100.0 * self.tp / p | |||
else: | |||
sensitivity = -1 | |||
if n > 0: | |||
specificity = 100.0 * self.tn / max(n, 1) | |||
else: | |||
specificity = -1 | |||
if sensitivity > -1 and specificity > -1: | |||
avg_ss = 0.5 * (sensitivity + specificity) | |||
elif sensitivity > -1: | |||
avg_ss = sensitivity | |||
else: | |||
avg_ss = specificity | |||
return ['%.4f' % self.avg_loss, '%.2f' % accuracy, '%.2f' % sensitivity, | |||
'%.2f' % specificity, '%.2f' % avg_ss, str(self.n_received_samples)] | |||
@@ -0,0 +1,109 @@ | |||
from functools import partial | |||
from typing import List, TYPE_CHECKING, Dict, Callable | |||
import numpy as np | |||
from torch import Tensor | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class BinFaithfulnessEvaluator(Evaluator): | |||
    def __init__(self, kw_prefix: str, gt_kw: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(BinFaithfulnessEvaluator, self).__init__(model, data_loader, conf) | |||
self._tp_by_kw: Dict[str, int] = dict() | |||
self._fp_by_kw: Dict[str, int] = dict() | |||
self._tn_by_kw: Dict[str, int] = dict() | |||
self._fn_by_kw: Dict[str, int] = dict() | |||
self._kw_prefix = kw_prefix | |||
self._gt_kw = gt_kw | |||
def reset(self): | |||
self._tp_by_kw: Dict[str, int] = dict() | |||
self._fp_by_kw: Dict[str, int] = dict() | |||
self._tn_by_kw: Dict[str, int] = dict() | |||
self._fn_by_kw: Dict[str, int] = dict() | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None: | |||
gt = (model_output[self._gt_kw].detach().cpu().numpy() >= 0.5).astype(int) | |||
# looking for prefixes | |||
for k in model_output.keys(): | |||
if k.startswith(self._kw_prefix): | |||
if k not in self._tp_by_kw: | |||
self._tp_by_kw[k] = 0 | |||
self._tn_by_kw[k] = 0 | |||
self._fp_by_kw[k] = 0 | |||
self._fn_by_kw[k] = 0 | |||
pred = (model_output[k].cpu().numpy() >= 0.5).astype(int) | |||
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1)) | |||
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1)) | |||
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0)) | |||
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0)) | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
return [f'Loyalty_{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']] | |||
def _get_values_of_evaluation_metrics(self, kw) -> List[str]: | |||
tp = self._tp_by_kw[kw] | |||
tn = self._tn_by_kw[kw] | |||
fp = self._fp_by_kw[kw] | |||
fn = self._fn_by_kw[kw] | |||
p = tp + fn | |||
n = tn + fp | |||
if p + n > 0: | |||
accuracy = (tp + tn) * 100.0 / (n + p) | |||
else: | |||
accuracy = -1 | |||
if p > 0: | |||
sensitivity = 100.0 * tp / p | |||
else: | |||
sensitivity = -1 | |||
if n > 0: | |||
specificity = 100.0 * tn / max(n, 1) | |||
else: | |||
specificity = -1 | |||
if sensitivity > -1 and specificity > -1: | |||
avg_ss = 0.5 * (sensitivity + specificity) | |||
elif sensitivity > -1: | |||
avg_ss = sensitivity | |||
else: | |||
avg_ss = specificity | |||
return ['%.2f' % accuracy, '%.2f' % sensitivity, | |||
'%.2f' % specificity, '%.2f' % avg_ss] | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
return [ | |||
v | |||
for k in self._tp_by_kw.keys() | |||
for v in self._get_values_of_evaluation_metrics(k)] | |||
@classmethod | |||
def standard_creator(cls, prefix_kw: str, pred_kw: str = 'positive_class_probability') -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinFaithfulnessEvaluator']: | |||
return partial(BinFaithfulnessEvaluator, prefix_kw, pred_kw) | |||
def print_evaluation_metrics(self, title: str) -> None: | |||
""" For more readable printing! """ | |||
print(f'{title}:') | |||
for k in self._tp_by_kw.keys(): | |||
print( | |||
f'\t{k}:: ' + | |||
', '.join([f'{m_name}: {m_val}' | |||
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))])) |
@@ -0,0 +1,108 @@ | |||
from typing import List, TYPE_CHECKING, Dict, Callable | |||
from functools import partial | |||
import numpy as np | |||
from torch import Tensor | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class BinForetellerEvaluator(Evaluator): | |||
def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||
super(BinForetellerEvaluator, self).__init__(model, data_loader, conf) | |||
self._tp_by_kw: Dict[str, int] = dict() | |||
self._fp_by_kw: Dict[str, int] = dict() | |||
self._tn_by_kw: Dict[str, int] = dict() | |||
self._fn_by_kw: Dict[str, int] = dict() | |||
self._kw_prefix = kw_prefix | |||
def reset(self): | |||
self._tp_by_kw: Dict[str, int] = dict() | |||
self._fp_by_kw: Dict[str, int] = dict() | |||
self._tn_by_kw: Dict[str, int] = dict() | |||
self._fn_by_kw: Dict[str, int] = dict() | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None: | |||
gt = self.data_loader.get_current_batch_samples_labels() | |||
# looking for prefixes | |||
for k in model_output.keys(): | |||
if k.startswith(self._kw_prefix): | |||
if k not in self._tp_by_kw: | |||
self._tp_by_kw[k] = 0 | |||
self._tn_by_kw[k] = 0 | |||
self._fp_by_kw[k] = 0 | |||
self._fn_by_kw[k] = 0 | |||
pred = (model_output[k].cpu().numpy() >= 0.5).astype(int) | |||
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1)) | |||
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1)) | |||
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0)) | |||
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0)) | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
return [f'{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']] | |||
def _get_values_of_evaluation_metrics(self, kw) -> List[str]: | |||
tp = self._tp_by_kw[kw] | |||
tn = self._tn_by_kw[kw] | |||
fp = self._fp_by_kw[kw] | |||
fn = self._fn_by_kw[kw] | |||
p = tp + fn | |||
n = tn + fp | |||
if p + n > 0: | |||
accuracy = (tp + tn) * 100.0 / (n + p) | |||
else: | |||
accuracy = -1 | |||
if p > 0: | |||
sensitivity = 100.0 * tp / p | |||
else: | |||
sensitivity = -1 | |||
if n > 0: | |||
specificity = 100.0 * tn / max(n, 1) | |||
else: | |||
specificity = -1 | |||
if sensitivity > -1 and specificity > -1: | |||
avg_ss = 0.5 * (sensitivity + specificity) | |||
elif sensitivity > -1: | |||
avg_ss = sensitivity | |||
else: | |||
avg_ss = specificity | |||
return ['%.2f' % accuracy, '%.2f' % sensitivity, | |||
'%.2f' % specificity, '%.2f' % avg_ss] | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
return [ | |||
v | |||
for k in self._tp_by_kw.keys() | |||
for v in self._get_values_of_evaluation_metrics(k)] | |||
@classmethod | |||
def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']: | |||
return partial(BinForetellerEvaluator, prefix_kw) | |||
def print_evaluation_metrics(self, title: str) -> None: | |||
""" For more readable printing! """ | |||
print(f'{title}:') | |||
for k in self._tp_by_kw.keys(): | |||
print( | |||
f'\t{k}:: ' + | |||
', '.join([f'{m_name}: {m_val}' | |||
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))])) |
@@ -0,0 +1,35 @@ | |||
from typing import TYPE_CHECKING, Type | |||
import numpy as np | |||
from ..data.data_loader import DataLoader | |||
from .binary_evaluator import BinaryEvaluator | |||
from ..models.model import Model | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
from ..data.content_loaders.content_loader import ContentLoader | |||
class TagBinaryEvaluator(BinaryEvaluator): | |||
    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', tag: str, cl_type: Type['ContentLoader']):
super().__init__(model, data_loader, conf) | |||
self._tag = tag | |||
self._content_loader = data_loader.get_content_loader_of_interest(cl_type) | |||
@property | |||
def _prob_key(self) -> str: | |||
return f'{self._tag}_positive_class_probability' | |||
@property | |||
def _gt_key(self) -> str: | |||
return self._tag | |||
@property | |||
def _loss_key(self) -> str: | |||
return f'{self._tag}_loss' | |||
def _get_current_batch_gt(self) -> np.ndarray: | |||
return self._content_loader.get_placeholder_name_to_fill_function_dict()[self._gt_key]( | |||
self.data_loader.get_current_batch_sample_indices(), None | |||
) |
@@ -0,0 +1,84 @@ | |||
""" The evaluator """ | |||
from abc import abstractmethod, ABC | |||
from typing import List, TYPE_CHECKING, Dict | |||
import torch | |||
from tqdm import tqdm | |||
from ..data.dataflow import DataFlow | |||
if TYPE_CHECKING: | |||
from ..models.model import Model | |||
from ..data.data_loader import DataLoader | |||
from ..configs.base_config import BaseConfig | |||
class Evaluator(ABC): | |||
""" The evaluator """ | |||
def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'): | |||
""" Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to | |||
iterate over them. """ | |||
self.model = model | |||
self.data_loader = data_loader | |||
self.conf = conf | |||
self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device) | |||
def evaluate(self, max_iters: int = None): | |||
""" CAUTION: Resets the data_loader, | |||
Iterates over the samples (as much as and how data_loader specifies), | |||
calculates the overall evaluation requirements and prints them. | |||
Title specified the title of the string for printing evaluation metrics. | |||
classes_to_use specifies the labels of the samples to do the function on them, | |||
None means all""" | |||
# setting in dataflow | |||
max_iters = float('inf') if max_iters is None else max_iters | |||
with torch.no_grad(): | |||
# Running the model in evaluation mode | |||
self.model.eval() | |||
# initiating variables for running evaluations | |||
self.reset() | |||
with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar: | |||
for iters, model_output in pbar: | |||
self.update_summaries_based_on_model_output(model_output) | |||
del model_output | |||
pbar.set_description(self.get_evaluation_metrics()) | |||
if iters + 1 >= max_iters: | |||
break | |||
@abstractmethod | |||
def update_summaries_based_on_model_output( | |||
self, model_output: Dict[str, torch.Tensor]) -> None: | |||
""" Updates the inner variables responsible of keeping a summary over data | |||
based on the new outputs of the model, so when needed evaluation metrics | |||
can be calculated based on these summaries. Mostly used in train and validation phase | |||
or eval phase in which we only need evaluation metrics.""" | |||
@abstractmethod | |||
def reset(self): | |||
""" Resets the held information for a new evaluation round!""" | |||
@abstractmethod | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
""" Returns a list of strings containing the titles of the evaluation metrics""" | |||
@abstractmethod | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
""" Returns a list of values for the calculated evaluation metrics, | |||
converted to string with the desired format!""" | |||
    def get_evaluation_metrics(self) -> str:
return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in | |||
zip( | |||
self.get_titles_of_evaluation_metrics(), | |||
self.get_values_of_evaluation_metrics())]) | |||
def print_evaluation_metrics(self, title: str) -> None: | |||
""" Prints the values of the evaluation metrics""" | |||
print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True) |
@@ -0,0 +1,57 @@ | |||
from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict | |||
from collections import OrderedDict | |||
import torch | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
class LossEvaluator(Evaluator): | |||
def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'): | |||
super(LossEvaluator, self).__init__(model, data_loader, conf) | |||
self.phase = conf.phase_type | |||
# for summarized information | |||
self.avg_loss: float = 0 | |||
self._avg_other_losses: OrdDict[str, float] = OrderedDict() | |||
self.n_received_samples: int = 0 | |||
self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict() | |||
def reset(self): | |||
self.avg_loss = 0 | |||
self._avg_other_losses = OrderedDict() | |||
self.n_received_samples = 0 | |||
self._n_received_samples_other_losses = OrderedDict() | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||
n_batch = self.data_loader.get_current_batch_size() | |||
new_n = self.n_received_samples + n_batch | |||
self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \ | |||
model_output.get('loss', 0.0) * (float(n_batch) / new_n) | |||
for kw in model_output.keys(): | |||
if 'loss' in kw and kw != 'loss': | |||
old_avg = self._avg_other_losses.get(kw, 0.0) | |||
old_n = self._n_received_samples_other_losses.get(kw, 0) | |||
new_n = old_n + n_batch | |||
self._avg_other_losses[kw] = \ | |||
old_avg * (float(old_n) / new_n) + \ | |||
model_output[kw].detach().cpu() * (float(n_batch) / new_n) | |||
self._n_received_samples_other_losses[kw] = new_n | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
return ['Loss'] + list(self._avg_other_losses.keys()) | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()] |
@@ -0,0 +1,122 @@ | |||
from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional | |||
from functools import partial | |||
import numpy as np | |||
import torch | |||
from sklearn.metrics import confusion_matrix | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
def _precision_score(cm: np.ndarray):
    numerator = np.diag(cm)
    denominator = cm.sum(axis=0)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
def _recall_score(cm: np.ndarray):
    numerator = np.diag(cm)
    denominator = cm.sum(axis=1)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
def _accuracy_score(cm: np.ndarray): | |||
return np.diag(cm).sum() / cm.sum() | |||
class MulticlassEvaluator(Evaluator): | |||
@classmethod | |||
def standard_creator(cls, class_probability_key: str = 'categorical_probability', | |||
include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']: | |||
return partial(MulticlassEvaluator, class_probability_key=class_probability_key, include_top5=include_top5) | |||
def __init__(self, model: Model, | |||
data_loader: DataLoader, | |||
conf: 'BaseConfig', | |||
class_probability_key: str = 'categorical_probability', | |||
include_top5: bool = False): | |||
super().__init__(model, data_loader, conf) | |||
# for summarized information | |||
self._avg_loss: float = 0 | |||
self._n_received_samples: int = 0 | |||
self._trues = {} | |||
self._falses = {} | |||
self._top5_trues: int = 0 | |||
self._top5_falses: int = 0 | |||
self._num_iters = 0 | |||
self._predictions: Dict[int, np.ndarray] = {} | |||
self._prediction_weights: Dict[int, np.ndarray] = {} | |||
self._sample_details: Dict[str, Dict[str, Any]] = {} | |||
self._cm: Optional[np.ndarray] = None | |||
self._c = None | |||
self._class_probability_key = class_probability_key | |||
self._include_top5 = include_top5 | |||
def reset(self): | |||
# for summarized information | |||
self._avg_loss: float = 0 | |||
self._n_received_samples: int = 0 | |||
self._trues = {} | |||
self._falses = {} | |||
self._top5_trues: int = 0 | |||
self._top5_falses: int = 0 | |||
self._num_iters = 0 | |||
self._cm: Optional[np.ndarray] = None | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||
assert self._class_probability_key in model_output, \ | |||
f"model's output dictionary must contain {self._class_probability_key}" | |||
prediction: np.ndarray = model_output[self._class_probability_key]\ | |||
.cpu().numpy() # B C | |||
c = prediction.shape[1] | |||
ground_truth = self.data_loader.get_current_batch_samples_labels() # B | |||
if self._include_top5: | |||
top5_p = prediction.argsort(axis=1)[:, -5:] | |||
trues = (top5_p == ground_truth[:, None]).any(axis=1).sum() | |||
self._top5_trues += trues | |||
self._top5_falses += len(ground_truth) - trues | |||
prediction = prediction.argmax(axis=1) # B | |||
for i in range(c): | |||
nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i))) | |||
nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i))) | |||
if i not in self._trues: | |||
self._trues[i] = 0 | |||
self._falses[i] = 0 | |||
self._trues[i] += nt | |||
self._falses[i] += nf | |||
self._cm = (0 if self._cm is None else self._cm)\ | |||
+ confusion_matrix(ground_truth, prediction, | |||
labels=np.arange(c)).astype(float) | |||
self._num_iters += 1 | |||
new_n = self._n_received_samples + len(ground_truth) | |||
self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \ | |||
model_output.get('loss', 0.0) * (float(len(ground_truth)) / new_n) | |||
self._n_received_samples = new_n | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \ | |||
(['Top5Acc'] if self._include_top5 else []) | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
accuracy = _accuracy_score(self._cm) * 100 | |||
precision = _precision_score(self._cm) * 100 | |||
recall = _recall_score(self._cm) * 100 | |||
top5_acc = self._top5_trues * 100.0 / (self._top5_trues + self._top5_falses)\ | |||
if self._include_top5 else None | |||
return [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}'] + \ | |||
([f'{top5_acc:8.4f}'] if self._include_top5 else []) |
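# --- Runnable sketch (illustration only) ---
# Sanity check of the confusion-matrix helpers on a toy 2-class matrix
# (rows: ground truth, columns: predictions, as sklearn builds them).
if __name__ == '__main__':
    _cm = np.array([[8., 2.],
                    [1., 9.]])
    print(_accuracy_score(_cm))   # 17 / 20 = 0.85
    print(_precision_score(_cm))  # mean(8/9, 9/11) ~= 0.8535
    print(_recall_score(_cm))     # mean(8/10, 9/10) = 0.85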
@@ -0,0 +1,63 @@ | |||
from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type | |||
from collections import OrderedDict as ODict | |||
import torch | |||
from ..data.data_loader import DataLoader | |||
from ..models.model import Model | |||
from ..model_evaluation.evaluator import Evaluator | |||
if TYPE_CHECKING: | |||
from ..configs.base_config import BaseConfig | |||
# If you want your model to be analyzed from different points of view by different evaluators, you can use this holder! | |||
class MultiEvaluatorEvaluator(Evaluator): | |||
def __init__( | |||
self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', | |||
evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]): | |||
""" | |||
evaluators_cls_by_name: The key is an arbitrary name to call the instance, | |||
cls is the constructor | |||
""" | |||
super(MultiEvaluatorEvaluator, self).__init__(model, data_loader, conf) | |||
# Making all evaluators! | |||
self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict() | |||
for eval_name, eval_cls in evaluators_cls_by_name.items(): | |||
self._evaluators_by_name[eval_name] = eval_cls(model, data_loader, conf) | |||
def reset(self): | |||
for evaluator in self._evaluators_by_name.values(): | |||
evaluator.reset() | |||
def get_titles_of_evaluation_metrics(self) -> List[str]: | |||
titles = [] | |||
for eval_name, evaluator in self._evaluators_by_name.items(): | |||
titles += [eval_name + '_' + t for t in evaluator.get_titles_of_evaluation_metrics()] | |||
return titles | |||
def get_values_of_evaluation_metrics(self) -> List[str]: | |||
metrics = [] | |||
for evaluator in self._evaluators_by_name.values(): | |||
metrics += evaluator.get_values_of_evaluation_metrics() | |||
return metrics | |||
def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None: | |||
for evaluator in self._evaluators_by_name.values(): | |||
evaluator.update_summaries_based_on_model_output(model_output) | |||
@staticmethod | |||
def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]): | |||
""" For making a constructor, consistent with the known Evaluator""" | |||
def typical_maker(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator: | |||
return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name) | |||
return typical_maker | |||
def print_evaluation_metrics(self, title: str) -> None: | |||
""" For more readable printing! """ | |||
print(f'{title}:') | |||
for e_name, e_obj in self._evaluators_by_name.items(): | |||
e_obj.print_evaluation_metrics('\t' + e_name) |
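# --- Usage sketch (illustrative; not part of the original file). The names
# below and the MulticlassEvaluator reference (defined in a sibling module)
# are assumed examples; any constructor with the (model, data_loader, conf)
# signature works as a value.
#
#   from collections import OrderedDict as ODict
#   maker = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(ODict([
#       ('mc', MulticlassEvaluator),  # metrics will be titled mc_Loss, mc_Accuracy, ...
#       ('mc5', MulticlassEvaluator.standard_creator('categorical_probability', include_top5=True)),
#   ]))
#   evaluator = maker(model, data_loader, conf)
#   evaluator.print_evaluation_metrics('validation')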
@@ -0,0 +1,61 @@ | |||
import typing | |||
from typing import Dict, Iterable | |||
from collections import OrderedDict | |||
import torch | |||
from torch.nn import functional as F | |||
import torchvision | |||
from ..lap_inception import LAPInception | |||
class CelebALAPInception(LAPInception): | |||
    def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory) | |||
self._tag = tag | |||
@property | |||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||
return OrderedDict({ | |||
f'{self._tag}': True, | |||
}) | |||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
# x: B 3 224 224 | |||
if self.training: | |||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||
out, aux = out.flatten(), aux.flatten() # B | |||
else: | |||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||
aux = None | |||
output = dict() | |||
output['positive_class_probability'] = out | |||
if f'{self._tag}' not in gts: | |||
return output | |||
gt = gts[f'{self._tag}'] | |||
r""" Class weighted loss """ | |||
loss = torch.mean(torch.stack(tuple( | |||
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique() | |||
))) | |||
output['loss'] = loss | |||
return output | |||
""" INTERPRETATION """ | |||
@property | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
""" | |||
:return: input module for interpretation | |||
""" | |||
return ['x'] | |||
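# --- Illustrative sketch (not part of the original file): the class-weighted
# loss above averages per-class BCE terms, so each class present in the batch
# contributes equally regardless of its frequency.
def _demo_class_weighted_bce():
    import torch
    from torch.nn import functional as F
    out = torch.tensor([0.9, 0.8, 0.2, 0.7])  # predicted positive-class probabilities
    gt = torch.tensor([1.0, 1.0, 0.0, 1.0])   # imbalanced batch: three positives, one negative
    loss = torch.mean(torch.stack(tuple(
        F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
    )))
    # equivalent: the unweighted mean of the negative-class and positive-class BCEs
    balanced = (F.binary_cross_entropy(out[gt == 0.], gt[gt == 0.]) +
                F.binary_cross_entropy(out[gt == 1.], gt[gt == 1.])) / 2
    assert torch.isclose(loss, balanced)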
@@ -0,0 +1,53 @@ | |||
from collections import OrderedDict | |||
from typing import Dict, List | |||
import typing | |||
import torch | |||
from ...modules import LAP, AdaptiveLAP | |||
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||
def pool_factory(channels, sigmoid_scale=1.0): | |||
return LAP(channels, | |||
sigmoid_scale=sigmoid_scale, | |||
n_attention_heads=3, | |||
discriminative_attention=True) | |||
def adaptive_pool_factory(channels, sigmoid_scale=1.0): | |||
return AdaptiveLAP(channels, | |||
sigmoid_scale=sigmoid_scale, | |||
n_attention_heads=3, | |||
discriminative_attention=True) | |||
class CelebALAPResNet18(LapResNet): | |||
def __init__(self, tag: str, | |||
pool_factory: PoolFactory = pool_factory, | |||
adaptive_factory: PoolFactory = adaptive_pool_factory, | |||
sigmoid_scale: float = 1.0): | |||
super().__init__(LapBasicBlock, [2, 2, 2, 2], | |||
pool_factory=pool_factory, | |||
adaptive_factory=adaptive_factory, | |||
sigmoid_scale=sigmoid_scale, | |||
binary=True) | |||
self._tag = tag | |||
@property | |||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||
return OrderedDict({ | |||
f'{self._tag}': True, | |||
}) | |||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
y = gts[f'{self._tag}'] | |||
return super().forward(x, y) | |||
@property | |||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||
res = super().attention_layers | |||
res.pop('4_overall') | |||
return res |
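# --- Usage sketch (illustrative; not part of the original file). 'Smiling' is
# an assumed example of a CelebA attribute tag.
#
#   model = CelebALAPResNet18(tag='Smiling')
#   output = model(x, Smiling=y)  # forward reads the label from the kwarg named by `tag`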
@@ -0,0 +1,25 @@ | |||
from collections import OrderedDict | |||
from typing import Dict | |||
import typing | |||
import torch | |||
from ..tv_resnet import BasicBlock, ResNet | |||
class CelebAORGResNet18(ResNet): | |||
def __init__(self, tag: str): | |||
super().__init__(BasicBlock, [2, 2, 2, 2], binary=True) | |||
self._tag = tag | |||
@property | |||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||
return OrderedDict({ | |||
f'{self._tag}': True, | |||
}) | |||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
y = gts[f'{self._tag}'] | |||
return super().forward(x, y) |
@@ -0,0 +1,67 @@ | |||
import typing | |||
from typing import Dict, List, Iterable | |||
from collections import OrderedDict | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
import torchvision | |||
from ..tv_inception import Inception3 | |||
class CelebAORGInception(Inception3): | |||
def __init__(self, tag: str, aux_weight: float): | |||
super().__init__(aux_weight, n_classes=1) | |||
self._tag = tag | |||
@property | |||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||
return OrderedDict({ | |||
f'{self._tag}': True, | |||
}) | |||
def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
if self.training: | |||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||
out, aux = out.flatten(), aux.flatten() # B | |||
else: | |||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||
aux = None | |||
output = dict() | |||
output['positive_class_probability'] = out | |||
if f'{self._tag}' not in gts: | |||
return output | |||
gt = gts[f'{self._tag}'] | |||
r""" Class weighted loss """ | |||
loss = torch.mean(torch.stack(tuple( | |||
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique() | |||
))) | |||
output['loss'] = loss | |||
return output | |||
@property | |||
def target_conv_layers(self) -> List[nn.Module]: | |||
return [ | |||
self.Mixed_7c.branch1x1, | |||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||
self.Mixed_7c.branch_pool | |||
] | |||
@property | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
return ['x'] | |||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||
return torch.stack([1 - p, p], dim=1) |
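# --- Illustrative sketch (not part of the original file): stacking turns the
# scalar positive-class probability into a two-column categorical distribution,
# as done by get_categorical_probabilities above.
def _demo_categorical_stack():
    import torch
    p = torch.tensor([0.1, 0.6, 0.9])
    categorical = torch.stack([1 - p, p], dim=1)  # B x 2: column 0 negative, column 1 positive
    assert torch.allclose(categorical.sum(dim=1), torch.ones(3))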
@@ -0,0 +1,34 @@ | |||
from typing import Dict, List | |||
import torch | |||
from torchvision import transforms | |||
from ...modules import LAP, ScaledSigmoid
from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory | |||
def pool_factory(channels, sigmoid_scale=1.0): | |||
return LAP(channels, | |||
hidden_channels=[1000], | |||
sigmoid_scale=sigmoid_scale, | |||
hidden_activation=ScaledSigmoid.get_factory(sigmoid_scale)) | |||
class ImagenetLAPResNet50(LapResNet): | |||
def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0): | |||
super().__init__(LapBottleneck, [3, 4, 6, 3], | |||
pool_factory=pool_factory, | |||
sigmoid_scale=sigmoid_scale, | |||
lap_positions=[4]) | |||
self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) | |||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]: | |||
x = self.normalize(x) | |||
return super().forward(x, y) | |||
@property | |||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||
return { | |||
'2_layer4': super().attention_layers['2_layer4'], | |||
} |
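# --- Usage note (illustrative; not part of the original file). The model
# normalizes internally with the standard ImageNet statistics, so the data
# pipeline should feed images in [0, 1] without an extra Normalize transform.
#
#   model = ImagenetLAPResNet50()
#   output = model(images, labels)  # images: B 3 H W, values in [0, 1]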
@@ -0,0 +1,556 @@ | |||
from typing import List, Optional, Callable, Any, Iterable, Dict | |||
import torch | |||
from torch import Tensor, nn | |||
import torchvision | |||
from ..modules.lap import LAP | |||
from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||
from ..interpreting.interpretable import CamInterpretableModel | |||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||
from ..interpreting.relcam import modules as M | |||
class InceptionB(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
pool_factory, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionB, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
        # The stride-2 convolutions of the original torchvision block are replaced
        # by stride-1 convolutions; downsampling is delegated to the LAP pool below.
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=1, padding=1)
        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=1, padding=1)
self.pool = pool_factory(384 + 96 + in_channels) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch3x3 = self.branch3x3(x) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||
outputs = [branch3x3, branch3x3dbl, x] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.pool(self.cat(outputs, 1)) | |||
@RPProvider.register(InceptionB) | |||
class InceptionBRelProp(RelProp[InceptionB]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
R = RPProvider.get(self.module.pool)(R, alpha=alpha) | |||
branch3x3, branch3x3dbl, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha) | |||
return x1 + x2 + x3 | |||
class InceptionD(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
pool_factory, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionD, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||
        # The stride-2 convolutions of the original torchvision block are replaced
        # by stride-1 convolutions; downsampling is delegated to the LAP pool below.
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=1, padding=1)
        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=1, padding=1)
self.pool = pool_factory(320 + 192 + in_channels) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch3x3 = self.branch3x3_1(x) | |||
branch3x3 = self.branch3x3_2(branch3x3) | |||
branch7x7x3 = self.branch7x7x3_1(x) | |||
branch7x7x3 = self.branch7x7x3_2(branch7x7x3) | |||
branch7x7x3 = self.branch7x7x3_3(branch7x7x3) | |||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3) | |||
outputs = [branch3x3, branch7x7x3, x] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.pool(self.cat(outputs, 1)) | |||
@RPProvider.register(InceptionD) | |||
class InceptionDRelProp(RelProp[InceptionD]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
R = RPProvider.get(self.module.pool)(R, alpha=alpha) | |||
branch3x3, branch7x7x3, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha) | |||
branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||
return x1 + x2 + x3 | |||
def inception_b_maker(pool_factory): | |||
return ( | |||
lambda in_channels, conv_block=None: | |||
InceptionB(in_channels, pool_factory, conv_block)) | |||
def inception_d_maker(pool_factory): | |||
return ( | |||
lambda in_channels, conv_block=None: | |||
InceptionD(in_channels, pool_factory, conv_block)) | |||
class InceptionA(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
pool_features: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None, | |||
) -> None: | |||
super(InceptionA, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) | |||
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) | |||
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) | |||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) | |||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) | |||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) | |||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch1x1 = self.branch1x1(x) | |||
branch5x5 = self.branch5x5_1(x) | |||
branch5x5 = self.branch5x5_2(branch5x5) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||
branch_pool = self.avg_pool(x) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.cat(outputs, 1) | |||
@RPProvider.register(InceptionA) | |||
class InceptionARelProp(RelProp[InceptionA]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
branch1x1, branch5x5, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||
branch5x5 = RPProvider.get(self.module.branch5x5_2)(branch5x5, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch5x5_1)(branch5x5, alpha=alpha) | |||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||
return x1 + x2 + x3 + x4 | |||
class InceptionC(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
channels_7x7: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionC, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) | |||
c7 = channels_7x7 | |||
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) | |||
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) | |||
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) | |||
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) | |||
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) | |||
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) | |||
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) | |||
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) | |||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch1x1 = self.branch1x1(x) | |||
branch7x7 = self.branch7x7_1(x) | |||
branch7x7 = self.branch7x7_2(branch7x7) | |||
branch7x7 = self.branch7x7_3(branch7x7) | |||
branch7x7dbl = self.branch7x7dbl_1(x) | |||
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) | |||
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) | |||
branch_pool = self.avg_pool(x) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.cat(outputs, 1) | |||
@RPProvider.register(InceptionC) | |||
class InceptionCRelProp(RelProp[InceptionC]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
branch1x1, branch7x7, branch7x7dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_5)(branch7x7dbl, alpha=alpha) | |||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_4)(branch7x7dbl, alpha=alpha) | |||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_3)(branch7x7dbl, alpha=alpha) | |||
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_2)(branch7x7dbl, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch7x7dbl_1)(branch7x7dbl, alpha=alpha) | |||
branch7x7 = RPProvider.get(self.module.branch7x7_3)(branch7x7, alpha=alpha) | |||
branch7x7 = RPProvider.get(self.module.branch7x7_2)(branch7x7, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch7x7_1)(branch7x7, alpha=alpha) | |||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||
return x1 + x2 + x3 + x4 | |||
class InceptionE(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionE, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) | |||
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) | |||
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) | |||
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) | |||
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) | |||
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) | |||
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) | |||
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) | |||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1) | |||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1) | |||
self.cat1 = M.Cat() | |||
self.cat2 = M.Cat() | |||
self.cat3 = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch1x1 = self.branch1x1(x) | |||
branch3x3 = self.branch3x3_1(x) | |||
branch3x3 = [ | |||
self.branch3x3_2a(branch3x3), | |||
self.branch3x3_2b(branch3x3), | |||
] | |||
branch3x3 = self.cat1(branch3x3, 1) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = [ | |||
self.branch3x3dbl_3a(branch3x3dbl), | |||
self.branch3x3dbl_3b(branch3x3dbl), | |||
] | |||
branch3x3dbl = self.cat2(branch3x3dbl, 1) | |||
branch_pool = self.avg_pool(x) | |||
branch_pool = self.branch_pool(branch_pool) | |||
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.cat3(outputs, 1) | |||
@RPProvider.register(InceptionE) | |||
class InceptionERelProp(RelProp[InceptionE]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
branch1x1, branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat3)(R, alpha=alpha) | |||
branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha) | |||
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha) | |||
branch3x3dbl_3a, branch3x3dbl_3b = RPProvider.get(self.module.cat2)(branch3x3dbl, alpha=alpha) | |||
branch3x3dbl_1 = RPProvider.get(self.module.branch3x3dbl_3a)(branch3x3dbl_3a, alpha=alpha) | |||
branch3x3dbl_2 = RPProvider.get(self.module.branch3x3dbl_3b)(branch3x3dbl_3b, alpha=alpha) | |||
branch3x3dbl = branch3x3dbl_1 + branch3x3dbl_2 | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||
branch3x3_2a, branch3x3_2b = RPProvider.get(self.module.cat1)(branch3x3, alpha=alpha) | |||
branch3x3_1 = RPProvider.get(self.module.branch3x3_2a)(branch3x3_2a, alpha=alpha) | |||
branch3x3_2 = RPProvider.get(self.module.branch3x3_2b)(branch3x3_2b, alpha=alpha) | |||
branch3x3 = branch3x3_1 + branch3x3_2 | |||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||
x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha) | |||
return x1 + x2 + x3 + x4 | |||
class InceptionAux(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
num_classes: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionAux, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3) | |||
self.conv0 = conv_block(in_channels, 128, kernel_size=1) | |||
self.conv1 = conv_block(128, 768, kernel_size=5, padding=2) | |||
self.conv1.stddev = 0.01 # type: ignore[assignment] | |||
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) | |||
self.flatten = nn.Flatten() | |||
self.fc = nn.Linear(768, num_classes) | |||
self.fc.stddev = 0.001 # type: ignore[assignment] | |||
    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = self.avgpool1(x)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 5 x 5 (kernel_size=5 with padding=2 preserves the spatial size)
        # Adaptive average pooling
        x = self.avgpool2(x)
        # N x 768 x 1 x 1
        x = self.flatten(x)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x
@RPProvider.register(InceptionAux) | |||
class InceptionAuxRelProp(RelProp[InceptionAux]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.flatten)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.avgpool2)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.conv1)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.conv0)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.avgpool1)(R, alpha=alpha) | |||
return R | |||
class BasicConv2d(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
**kwargs: Any | |||
) -> None: | |||
super(BasicConv2d, self).__init__() | |||
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) | |||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001) | |||
self.act = torch.nn.ReLU() | |||
def forward(self, x: Tensor) -> Tensor: | |||
x = self.conv(x) | |||
x = self.bn(x) | |||
return self.act(x) | |||
@RPProvider.register(BasicConv2d) | |||
class BasicConv2dRelProp(RelProp[BasicConv2d]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
R = RPProvider.get(self.module.act)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.bn)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.conv)(R, alpha=alpha) | |||
return R | |||
class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3): | |||
def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory): | |||
torchvision.models.Inception3.__init__( | |||
self, | |||
transform_input=False, init_weights=False, | |||
inception_blocks=[ | |||
BasicConv2d, InceptionA, inception_b_maker(pool_factory), | |||
InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux | |||
]) | |||
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1) | |||
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) | |||
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) | |||
self.maxpool1 = pool_factory(64) | |||
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1) | |||
self.maxpool2 = pool_factory(192) | |||
if adaptive_pool_factory is not None: | |||
self.avgpool = nn.Sequential( | |||
adaptive_pool_factory(2048)) | |||
self.fc = nn.Sequential( | |||
nn.Linear(2048, 1), | |||
nn.Sigmoid() | |||
) | |||
self.AuxLogits.fc = nn.Sequential( | |||
nn.Linear(768, 1), | |||
nn.Sigmoid() | |||
) | |||
self.aux_weight = aux_weight | |||
@property | |||
def target_conv_layers(self) -> List[nn.Module]: | |||
return [ | |||
self.Mixed_7c.branch1x1, | |||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||
self.Mixed_7c.branch_pool | |||
] | |||
@property | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
return ['x'] | |||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||
return torch.stack([1 - p, p], dim=1) | |||
@property | |||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||
attention_groups = { | |||
'0_layer2': [self.maxpool2], | |||
'1_layer6': [self.Mixed_6a.pool], | |||
'2_layer7': [self.Mixed_7a.pool], | |||
'3_avgpool': [self.avgpool[0]], | |||
'4_all': [ | |||
self.maxpool2, | |||
self.Mixed_6a.pool, | |||
self.Mixed_7a.pool, | |||
self.avgpool[0]], | |||
} | |||
return attention_groups | |||
@RPProvider.register(LAPInception) | |||
class Inception3Mo4RelProp(RelProp[LAPInception]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
if RPProvider.get(self.module.fc).Y.shape[1] == 1: | |||
R = R[:, -1:] | |||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048 | |||
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1 | |||
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1 | |||
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8 | |||
R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8 | |||
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8 | |||
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35 | |||
R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35 | |||
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35 | |||
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35 | |||
R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71 | |||
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73 | |||
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73 | |||
R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147 | |||
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147 | |||
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149 | |||
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299 | |||
return R |
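# --- Pattern sketch (illustrative; not part of the original file). MyBlock is a
# hypothetical module; registering a relevance propagator for it follows the same
# recipe as the classes above: walk the forward pass in reverse, reusing the
# registered propagators of the sub-modules.
#
#   class MyBlock(nn.Module):
#       def __init__(self) -> None:
#           super().__init__()
#           self.conv = BasicConv2d(3, 8, kernel_size=3, padding=1)
#           self.pool = nn.AvgPool2d(kernel_size=2)
#
#       def forward(self, x: Tensor) -> Tensor:
#           return self.pool(self.conv(x))
#
#   @RPProvider.register(MyBlock)
#   class MyBlockRelProp(RelProp[MyBlock]):
#       def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
#           R = RPProvider.get(self.module.pool)(R, alpha=alpha)
#           return RPProvider.get(self.module.conv)(R, alpha=alpha)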
@@ -0,0 +1,276 @@ | |||
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar | |||
from typing_extensions import Protocol | |||
import warnings | |||
import torch | |||
from torch import nn | |||
from ..modules import LAP, AdaptiveLAP | |||
from .tv_resnet import BasicBlock, Bottleneck, ResNet | |||
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple | |||
from ..interpreting.relcam import modules as M | |||
from ..interpreting.attention_interpreter import AttentionInterpretableModel | |||
class PoolFactory(Protocol): | |||
def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ... | |||
lap_factory: PoolFactory = \ | |||
lambda channels, sigmoid_scale=1.0: \ | |||
LAP(channels, sigmoid_scale=sigmoid_scale) | |||
adaptive_lap_factory: PoolFactory = \ | |||
lambda channels, sigmoid_scale=1.0: \ | |||
AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale) | |||
class LapBasicBlock(BasicBlock): | |||
def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1, | |||
base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0): | |||
super().__init__(inplanes, planes, stride=stride, downsample=downsample, | |||
groups=groups, base_width=base_width, dilation=dilation, | |||
norm_layer=norm_layer) | |||
self.pool = None | |||
if stride != 1: | |||
assert downsample is not None | |||
self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False) | |||
self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False) | |||
self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale) | |||
self.relu3 = nn.ReLU() | |||
self.cat = M.Cat() | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
out = self.conv1(x) # B P H W | |||
out = self.bn1(out) # B P H W | |||
out = self.relu(out) # B P H W | |||
if self.downsample is not None: | |||
x = self.downsample(x) # B P H W | |||
x = self.relu2(x) | |||
if self.pool is not None: | |||
poolin = self.cat([out, x], 1) # B 2P H W | |||
poolout: torch.Tensor = self.pool(poolin) # B 2P H/S W/S | |||
out, x = poolout.chunk(2, dim=1) # B P H/S W/S | |||
out = self.conv2(out) # B P H/S W/S | |||
out = self.bn2(out) # B P H/S W/S | |||
out = self.add([out, x]) # B P H/S W/S | |||
out = self.relu3(out) # B P H/S W/S | |||
return out | |||
@RPProvider.register(LapBasicBlock) | |||
class BlockRelProp(RelProp[LapBasicBlock]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
out = RPProvider.get(self.module.relu3)(R, alpha=alpha) | |||
out, x = RPProvider.get(self.module.add)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.bn2)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.conv2)(out, alpha=alpha) | |||
if self.module.pool is not None: | |||
poolout = torch.cat([out, x], dim=1) | |||
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha) | |||
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha) | |||
if self.module.downsample is not None: | |||
x = RPProvider.get(self.module.relu2)(x, alpha=alpha) | |||
x = RPProvider.get(self.module.downsample)(x, alpha=alpha) | |||
out = RPProvider.get(self.module.relu)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.bn1)(out, alpha=alpha) | |||
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha) | |||
return x + x1 | |||
class LapBottleneck(Bottleneck): | |||
def __init__(self, pool_factory: PoolFactory, | |||
inplanes: int, | |||
planes: int, | |||
stride: int = 1, | |||
downsample: Optional[nn.Module] = None, | |||
groups: int = 1, | |||
base_width: int = 64, | |||
dilation: int = 1, | |||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
sigmoid_scale: float = 1.0) -> None: | |||
super().__init__(inplanes, planes, stride=stride, downsample=downsample, | |||
groups=groups, base_width=base_width, dilation=dilation, | |||
norm_layer=norm_layer) | |||
self.pool = None | |||
if stride != 1: | |||
assert downsample is not None | |||
width = int(planes * (base_width / 64.)) * groups | |||
self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False) | |||
self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False) | |||
self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale) | |||
self.cat = M.Cat() | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
# x # B I H W | |||
out = self.conv1(x) # B P H W | |||
out = self.bn1(out) # B P H W | |||
out = self.relu(out) # B P H W | |||
out = self.conv2(out) # B P H W | |||
out = self.bn2(out) # B P H W | |||
out = self.relu2(out) # B P H W | |||
if self.downsample is not None: | |||
x = self.downsample(x) # B 4P H W | |||
if self.pool is not None: | |||
poolin = self.cat([out, x], 1) # B 5P H W | |||
poolout: torch.Tensor = self.pool(poolin) # B 5P H/S W/S | |||
out, x = poolout.split( # B P H/S W/S | |||
[out.shape[1], x.shape[1]], # B 4P H/S W/S | |||
dim=1) | |||
        out = self.conv3(out)     # B 4P H/S W/S
        out = self.bn3(out)       # B 4P H/S W/S
        out = self.add([out, x])  # B 4P H/S W/S
        out = self.relu3(out)     # B 4P H/S W/S
return out | |||
@RPProvider.register(LapBottleneck) | |||
class BlockRP(RelProp[LapBottleneck]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
out = RPProvider.get(self.module.relu3)(R, alpha=alpha) | |||
out, x = RPProvider.get(self.module.add)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.bn3)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.conv3)(out, alpha=alpha) | |||
if self.module.pool is not None: | |||
poolout = torch.cat([out, x], dim=1) | |||
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha) | |||
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha) | |||
if self.module.downsample is not None: | |||
x = RPProvider.get(self.module.downsample)(x, alpha=alpha) | |||
out = RPProvider.get(self.module.relu2)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.bn2)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.conv2)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.relu)(out, alpha=alpha) | |||
out = RPProvider.get(self.module.bn1)(out, alpha=alpha) | |||
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha) | |||
return x + x1 | |||
TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck) | |||
class BlockFactory(Generic[TLapBlock]): | |||
def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None: | |||
self.expansion = block.expansion | |||
self._block = block | |||
self._pool_factory = pool_factory | |||
self._sigmoid_scale = sigmoid_scale | |||
def __call__(self, *args, **kwargs) -> TLapBlock: | |||
return self._block(self._pool_factory, *args, **kwargs, sigmoid_scale=self._sigmoid_scale) | |||
class Stride(nn.Module): | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return x[:, :, ::2, ::2] | |||
@RPProvider.register(Stride) | |||
class StrideRelProp(RelPropSimple[Stride]): | |||
pass | |||
class LapResNet(ResNet, AttentionInterpretableModel): | |||
def __init__(self, | |||
block: Type[TLapBlock], | |||
layers: List[int], | |||
pool_factory: PoolFactory = lap_factory, | |||
lap_positions: List[int] = [2, 3, 4], | |||
                 adaptive_factory: Optional[PoolFactory] = None,
sigmoid_scale: float = 1.0, | |||
binary: bool = False): | |||
super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary) | |||
if adaptive_factory is not None: | |||
self.avgpool = nn.Sequential( | |||
adaptive_factory(512, sigmoid_scale=sigmoid_scale), | |||
) | |||
for i in range(2, 5): | |||
if i not in lap_positions: | |||
                warnings.warn(f'Layer {i} is not in lap_positions; using a plain stride instead of LAP')
getattr(self, 'layer{}'.format(i))[0].pool = Stride() | |||
@property | |||
def attention_layers(self) -> Dict[str, List[LAP]]: | |||
""" | |||
List of attention groups | |||
""" | |||
attention_groups = { | |||
'0_layer2': [self.layer2[0].pool], | |||
'1_layer3': [self.layer3[0].pool], | |||
'2_layer4': [self.layer4[0].pool], | |||
'3_avgpool': [self.avgpool[0]], | |||
'4_overall': [ | |||
self.layer2[0].pool, | |||
self.layer3[0].pool, | |||
self.layer4[0].pool, | |||
self.avgpool[0], | |||
] | |||
} | |||
assert all( | |||
all( | |||
isinstance(layer, LAP) | |||
for layer in attention_group) | |||
for attention_group in attention_groups.values()), \ | |||
"Only LAP is supported for this interpretation method" | |||
return attention_groups | |||
def lap_resnet18(pool_factory: PoolFactory = lap_factory, | |||
                 adaptive_factory: Optional[PoolFactory] = None,
sigmoid_scale: float = 1.0, | |||
lap_positions: List[int] = [2, 3, 4], | |||
binary: bool = False) -> LapResNet: | |||
"""Constructs a LAP-ResNet-18 model. | |||
""" | |||
return LapResNet(LapBasicBlock, [2, 2, 2, 2], pool_factory=pool_factory, | |||
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale, | |||
lap_positions=lap_positions, binary=binary) | |||
def lap_resnet50(pool_factory: PoolFactory = lap_factory, | |||
                 adaptive_factory: Optional[PoolFactory] = None,
sigmoid_scale: float = 1.0, | |||
lap_positions: List[int] = [2, 3, 4], | |||
binary: bool = False) -> LapResNet: | |||
"""Constructs a LAP-ResNet-50 model. | |||
""" | |||
return LapResNet(LapBottleneck, [3, 4, 6, 3], pool_factory=pool_factory, | |||
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale, | |||
lap_positions=lap_positions, binary=binary) |
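# --- Usage sketch (illustrative; not part of the original file).
#
#   model = lap_resnet18(adaptive_factory=adaptive_lap_factory, binary=True)
#   groups = model.attention_layers          # '0_layer2' ... '4_overall'
#   model = lap_resnet18(lap_positions=[4])  # layers 2 and 3 warn and fall back to a plain Stride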
@@ -0,0 +1,147 @@ | |||
""" | |||
The Model base class | |||
""" | |||
from collections import OrderedDict | |||
import inspect | |||
import typing | |||
from typing import Any, Dict, List, Tuple | |||
import re | |||
from abc import ABC | |||
import torch | |||
from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d | |||
ModelIO = Dict[str, torch.Tensor] | |||
class Model(torch.nn.Module, ABC): | |||
""" | |||
The Model base class | |||
""" | |||
_frozen_bns_list: List[torch.nn.Module] = [] | |||
def init_weights_from_other_model(self, pretrained_model_dir=None): | |||
""" Initializes the weights from another model with the given address, | |||
Default is to set all the parameters with the same name and shape, | |||
but can be rewritten in subclasses. | |||
The model to load is received via the function get_other_model_to_load_from. | |||
The default value is the same model!""" | |||
if pretrained_model_dir is None: | |||
print('The model was not preinitialized.', flush=True) | |||
return | |||
other_model_dict = torch.load(pretrained_model_dir) | |||
own_state = self.state_dict() | |||
cnt = 0 | |||
skip_cnt = 0 | |||
for name, param in other_model_dict.items(): | |||
if isinstance(param, torch.nn.Parameter): | |||
# backwards compatibility for serialized parameters | |||
param = param.data | |||
if name in own_state and own_state[name].data.shape == param.shape: | |||
own_state[name].copy_(param) | |||
cnt += 1 | |||
else: | |||
skip_cnt += 1 | |||
        print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, {skip_cnt} of the given set were skipped.', flush=True)
def freeze_parameters(self, regex_list: List[str] = None) -> None: | |||
""" | |||
An auxiliary function for freezing the parameters of the model | |||
using an inclusion keywords list and exclusion keywords list. | |||
:param regex_list: All the parameters matching at least one of the given regexes would be freezed. | |||
None means no parameter would be frozen. | |||
:return: None | |||
""" | |||
# One solution for batchnorms is registering a preforward hook | |||
# to call eval everytime before running them | |||
# But that causes a problem, we can't dynamically change the frozen batchnorms during training | |||
# e.g. in layerwise unfreezing. | |||
# So it's better to set a dynamic attribute that keeps their list, rather than having an init | |||
# Init causes other problems in multi-inheritance | |||
if regex_list is None: | |||
print('No parameter was freezeed!', flush=True) | |||
return | |||
regex_list = [re.compile(the_pattern) for the_pattern in regex_list] | |||
def check_freeze_conditions(param_name): | |||
for the_regex in regex_list: | |||
if the_regex.fullmatch(param_name) is not None: | |||
return True | |||
return False | |||
frozen_cnt = 0 | |||
total_cnt = 0 | |||
frozen_modules_names = dict() | |||
for n, p in self.named_parameters(): | |||
total_cnt += 1 | |||
if check_freeze_conditions(n): | |||
p.requires_grad = False | |||
frozen_cnt += 1 | |||
frozen_modules_names[n[:n.rfind('.')]] = True | |||
self._frozen_bns_list = [] | |||
for module_name, module in self.named_modules(): | |||
if (module_name in frozen_modules_names) and \ | |||
(isinstance(module, BatchNorm2d) or | |||
isinstance(module, BatchNorm1d) or | |||
isinstance(module, BatchNorm3d)): | |||
self._frozen_bns_list.append(module) | |||
print('********************* NOTICE *********************') | |||
print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt)) | |||
print( | |||
        '%d BatchNorm layers will be frozen by being kept in eval mode, IF YOU HAVE NOT OVERRIDDEN IT IN YOUR MAIN MODEL!' % len(
self._frozen_bns_list)) | |||
print('***************************************************', flush=True) | |||
def train(self, mode: bool = True): | |||
super(Model, self).train(mode) | |||
for bn in self._frozen_bns_list: | |||
bn.train(False) | |||
return self | |||
@property | |||
def additional_kwargs(self) -> typing.OrderedDict[str, bool]: | |||
r""" Returns a dictionary from additional `kwargs` names to their optionality """ | |||
return OrderedDict({}) | |||
def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]: | |||
"""Returns the names of the keyword arguments required in the forward function and the list of the default values for the ones having them. | |||
The second list might be shorter than the first and it represents the default values from the end of the arguments. | |||
Returns: | |||
Tuple[List[str], List[Any], Dict[str, Any]]: The list of args names and default values for the ones having them (from some point to the end) + the dictionary of kwargs | |||
""" | |||
model_argspec = inspect.getfullargspec(self.forward) | |||
model_forward_args = list(model_argspec.args) | |||
# skipping self arg | |||
model_forward_args = model_forward_args[1:] | |||
args_default_values = model_argspec.defaults | |||
if args_default_values is None: | |||
args_default_values = [] | |||
else: | |||
args_default_values = list(args_default_values) | |||
additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {} | |||
return model_forward_args, args_default_values, additional_kwargs | |||
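# --- Illustrative sketch (not part of the original file): how the forward
# signature is introspected by get_forward_required_kws_kwargs_and_defaults.
def _demo_forward_introspection():
    import inspect
    def forward(self, x, y=None, **gts):
        ...
    spec = inspect.getfullargspec(forward)
    assert spec.args[1:] == ['x', 'y']  # the `self` argument is skipped
    assert spec.defaults == (None,)     # defaults align with the trailing arguments
    assert spec.varkw == 'gts'          # presence of **kwargs enables additional_kwargs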
@@ -0,0 +1,40 @@ | |||
from typing import Dict | |||
import torch | |||
from torch.nn import functional as F | |||
import torchvision | |||
from ..lap_inception import LAPInception | |||
class RSNALAPInception(LAPInception): | |||
def __init__(self, aux_weight: float, pool_factory, adaptive_pool_factory): | |||
super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory) | |||
def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]: | |||
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||
if self.training: | |||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||
out, aux = out.flatten(), aux.flatten() # B | |||
else: | |||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||
aux = None | |||
if y is not None: | |||
main_loss = F.binary_cross_entropy(out, y) | |||
if aux is not None: | |||
aux_loss = F.binary_cross_entropy(aux, y) | |||
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight) | |||
else: | |||
loss = main_loss | |||
return { | |||
'positive_class_probability': out, | |||
'loss': loss | |||
} | |||
return { | |||
'positive_class_probability': out | |||
} |
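# --- Illustrative sketch (not part of the original file): the auxiliary loss is
# blended with the main loss as a convex combination weighted by aux_weight.
def _demo_aux_weighted_loss():
    main_loss, aux_loss, aux_weight = 0.5, 0.9, 0.4
    loss = (main_loss + aux_weight * aux_loss) / (1 + aux_weight)
    assert min(main_loss, aux_loss) <= loss <= max(main_loss, aux_loss)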
@@ -0,0 +1,39 @@ | |||
from typing import Dict | |||
import torch | |||
from ...modules import LAP, AdaptiveLAP | |||
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory | |||
def get_pool_factory(discriminative_attention=True): | |||
def pool_factory(channels, sigmoid_scale=1.0): | |||
return LAP(channels, | |||
hidden_channels=[8], | |||
sigmoid_scale=sigmoid_scale, | |||
discriminative_attention=discriminative_attention) | |||
return pool_factory | |||
def get_adaptive_pool_factory(discriminative_attention=True): | |||
def adaptive_pool_factory(channels, sigmoid_scale=1.0): | |||
return AdaptiveLAP(channels, | |||
sigmoid_scale=sigmoid_scale, | |||
discriminative_attention=discriminative_attention) | |||
return adaptive_pool_factory | |||
class RSNALAPResNet18(LapResNet): | |||
def __init__(self, pool_factory: PoolFactory = get_pool_factory(), | |||
adaptive_factory: PoolFactory = get_adaptive_pool_factory(), | |||
sigmoid_scale: float = 1.0): | |||
super().__init__(LapBasicBlock, [2, 2, 2, 2], | |||
pool_factory=pool_factory, | |||
adaptive_factory=adaptive_factory, | |||
sigmoid_scale=sigmoid_scale, | |||
binary=True) | |||
def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]: | |||
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||
return super().forward(x, y) |
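# --- Illustrative sketch (not part of the original file): repeat_interleave
# replicates the single grayscale channel so the 3-channel backbone can consume it.
def _demo_gray_to_rgb():
    import torch
    x = torch.randn(2, 1, 224, 224)
    x3 = x.repeat_interleave(3, dim=1)
    assert x3.shape == (2, 3, 224, 224)
    assert torch.equal(x3[:, 0], x3[:, 1]) and torch.equal(x3[:, 1], x3[:, 2])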
@@ -0,0 +1,39 @@ | |||
from typing import Dict | |||
import torch | |||
from torch.nn import functional as F | |||
import torchvision | |||
from ..tv_inception import Inception3 | |||
class RSNAORGInception(Inception3): | |||
def __init__(self, aux_weight: float): | |||
super().__init__(aux_weight, n_classes=1) | |||
def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]: | |||
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||
if self.training: | |||
out, aux = torchvision.models.Inception3.forward(self, x) # B 1 | |||
out, aux = out.flatten(), aux.flatten() # B | |||
else: | |||
out = torchvision.models.Inception3.forward(self, x).flatten() # B | |||
aux = None | |||
if y is not None: | |||
main_loss = F.binary_cross_entropy(out, y) | |||
if aux is not None: | |||
aux_loss = F.binary_cross_entropy(aux, y) | |||
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight) | |||
else: | |||
loss = main_loss | |||
return { | |||
'positive_class_probability': out, | |||
'loss': loss | |||
} | |||
return { | |||
'positive_class_probability': out | |||
} |
@@ -0,0 +1,15 @@ | |||
from typing import Dict | |||
import torch | |||
from ..tv_resnet import BasicBlock, ResNet | |||
class RSNAORGResNet18(ResNet): | |||
def __init__(self): | |||
super().__init__(BasicBlock, [2, 2, 2, 2], binary=True) | |||
def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]: | |||
x = x.repeat_interleave(3, dim=1) # B 3 224 224 | |||
return super().forward(x, y) |
@@ -0,0 +1,218 @@ | |||
from typing import Iterable, Optional, Callable, List | |||
import torch | |||
from torch import nn, Tensor | |||
import torchvision | |||
from ..interpreting.interpretable import CamInterpretableModel | |||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||
from ..interpreting.relcam import modules as M | |||
from .lap_inception import ( | |||
InceptionA, | |||
InceptionC, | |||
InceptionE, | |||
BasicConv2d, | |||
InceptionAux as BaseInceptionAux, | |||
) | |||
class InceptionB(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionB, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) | |||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) | |||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) | |||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch3x3 = self.branch3x3(x) | |||
branch3x3dbl = self.branch3x3dbl_1(x) | |||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||
branch_pool = self.maxpool(x) | |||
outputs = [branch3x3, branch3x3dbl, branch_pool] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.cat(outputs, 1) | |||
@RPProvider.register(InceptionB) | |||
class InceptionBRelProp(RelProp[InceptionB]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha) | |||
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha) | |||
return x1 + x2 + x3 | |||
class InceptionD(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super(InceptionD, self).__init__() | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) | |||
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) | |||
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) | |||
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) | |||
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) | |||
self.cat = M.Cat() | |||
def _forward(self, x: Tensor) -> List[Tensor]: | |||
branch3x3 = self.branch3x3_1(x) | |||
branch3x3 = self.branch3x3_2(branch3x3) | |||
branch7x7x3 = self.branch7x7x3_1(x) | |||
branch7x7x3 = self.branch7x7x3_2(branch7x7x3) | |||
branch7x7x3 = self.branch7x7x3_3(branch7x7x3) | |||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3) | |||
branch_pool = self.maxpool(x) | |||
outputs = [branch3x3, branch7x7x3, branch_pool] | |||
return outputs | |||
def forward(self, x: Tensor) -> Tensor: | |||
outputs = self._forward(x) | |||
return self.cat(outputs, 1) | |||
@RPProvider.register(InceptionD) | |||
class InceptionDRelProp(RelProp[InceptionD]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
branch3x3, branch7x7x3, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha) | |||
x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha) | |||
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha) | |||
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha) | |||
branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha) | |||
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha) | |||
return x1 + x2 + x3 | |||
class InceptionAux(BaseInceptionAux): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
num_classes: int, | |||
conv_block: Optional[Callable[..., nn.Module]] = None | |||
) -> None: | |||
super().__init__(in_channels, num_classes, conv_block=conv_block) | |||
if conv_block is None: | |||
conv_block = BasicConv2d | |||
self.conv1 = conv_block(128, 768, kernel_size=5) | |||
self.conv1.stddev = 0.01 # type: ignore[assignment] | |||
class Inception3(CamInterpretableModel, torchvision.models.Inception3): | |||
def __init__(self, aux_weight: float, n_classes=1): | |||
torchvision.models.Inception3.__init__(self, | |||
transform_input=False, init_weights=False, | |||
            inception_blocks=[
BasicConv2d, InceptionA, InceptionB, InceptionC, | |||
InceptionD, InceptionE, InceptionAux | |||
]) | |||
self.fc = nn.Sequential( | |||
nn.Linear(2048, n_classes), | |||
nn.Sigmoid() | |||
) | |||
self.AuxLogits.fc = nn.Sequential( | |||
nn.Linear(768, n_classes), | |||
nn.Sigmoid() | |||
) | |||
self.aux_weight = aux_weight | |||
@property | |||
def target_conv_layers(self) -> List[nn.Module]: | |||
return [ | |||
self.Mixed_7c.branch1x1, | |||
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b, | |||
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b, | |||
self.Mixed_7c.branch_pool | |||
] | |||
@property | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
return ['x'] | |||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||
return torch.stack([1 - p, p], dim=1) | |||
@RPProvider.register(Inception3) | |||
class Inception3Mo4RelProp(RelProp[Inception3]): | |||
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor: | |||
if RPProvider.get(self.module.fc).Y.shape[1] == 1: | |||
R = R[:, -1:] | |||
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048 | |||
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1 | |||
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1 | |||
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8 | |||
R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8 | |||
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8 | |||
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17 | |||
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35 | |||
R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35 | |||
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35 | |||
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35 | |||
R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71 | |||
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73 | |||
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73 | |||
R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147 | |||
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147 | |||
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149 | |||
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299 | |||
return R |
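
# --- Hypothetical usage sketch (not part of the original code) ---
# A minimal sketch of driving the registered propagators end to end, assuming
# RPProvider records each module's forward output (the `.Y` used above) via
# hooks during a regular forward pass. The helper name and the one-hot seed
# are illustrative assumptions, not project API.
def _input_relevance_sketch(model: Inception3, *inputs, alpha: float = 1) -> torch.Tensor:
    """Propagate a one-hot relevance seed from the classifier back to the input."""
    probs = model.get_categorical_probabilities(*inputs)    # (B, 2)
    seed = torch.zeros_like(probs)
    seed[torch.arange(probs.shape[0]), probs.argmax(dim=1)] = 1.0
    return RPProvider.get(model)(seed, alpha=alpha)         # (B, 3, 299, 299)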
@@ -0,0 +1,512 @@ | |||
from typing import Dict, Iterable, Type, Any, Callable, Union, List, Optional | |||
import torch | |||
import torch.nn as nn | |||
from torch import Tensor | |||
from torch.nn import functional as F | |||
from ..interpreting.interpretable import CamInterpretableModel | |||
from ..interpreting.relcam.relprop import RPProvider, RelProp | |||
from ..interpreting.relcam import modules as M | |||
__all__ = [ | |||
"ResNet", | |||
"resnet18", | |||
"resnet34", | |||
"resnet50", | |||
"resnet101", | |||
"resnet152", | |||
"resnext50_32x4d", | |||
"resnext101_32x8d", | |||
"wide_resnet50_2", | |||
"wide_resnet101_2", | |||
] | |||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: | |||
"""3x3 convolution with padding""" | |||
return nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=dilation, | |||
groups=groups, | |||
bias=False, | |||
dilation=dilation, | |||
) | |||
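# For stride=1, padding=dilation keeps the spatial size unchanged:
#   H_out = floor((H + 2*dilation - dilation*(3 - 1) - 1) / 1) + 1 = H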
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: | |||
"""1x1 convolution""" | |||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
class BasicBlock(nn.Module): | |||
expansion: int = 1 | |||
def __init__( | |||
self, | |||
inplanes: int, | |||
planes: int, | |||
stride: int = 1, | |||
downsample: Optional[nn.Module] = None, | |||
groups: int = 1, | |||
base_width: int = 64, | |||
dilation: int = 1, | |||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
) -> None: | |||
super().__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
if groups != 1 or base_width != 64: | |||
raise ValueError("BasicBlock only supports groups=1 and base_width=64") | |||
if dilation > 1: | |||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||
self.conv1 = conv3x3(inplanes, planes, stride) | |||
self.bn1 = norm_layer(planes) | |||
self.relu = nn.ReLU() | |||
self.relu2 = nn.ReLU() | |||
self.conv2 = conv3x3(planes, planes) | |||
self.bn2 = norm_layer(planes) | |||
self.downsample = downsample | |||
self.stride = stride | |||
self.add = M.Add() | |||
def forward(self, x: Tensor) -> Tensor: | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
if self.downsample is not None: | |||
x = self.downsample(x) | |||
out = self.add([out, x]) | |||
out = self.relu2(out) | |||
return out | |||
@RPProvider.register(BasicBlock) | |||
class BasicBlockRelProp(RelProp[BasicBlock]): | |||
    def rel(self, R, alpha):
        # Propagate relevance through the block in reverse order of forward().
        out = RPProvider.get(self.module.relu2)(R, alpha)
        # The Add relprop splits R between the residual branch and the identity path.
        out, x = RPProvider.get(self.module.add)(out, alpha)
        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha)
        out = RPProvider.get(self.module.relu)(out, alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha)
        # Recombine the relevance of both paths at the block input.
        return x + x1
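
# A possible sanity check (hypothetical, not part of the original code):
# alpha-beta LRP rules are designed to approximately conserve relevance, so
# the total relevance a block's `rel` returns should stay close to what
# entered it.
def _conservation_check_sketch(r_in: torch.Tensor, r_out: torch.Tensor, rtol: float = 1e-2) -> bool:
    """Compare total relevance before and after propagating through a block."""
    return torch.allclose(r_in.sum(), r_out.sum(), rtol=rtol)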
class Bottleneck(nn.Module): | |||
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation from "Deep Residual Learning for Image Recognition"
    # <https://arxiv.org/abs/1512.03385> places the stride at the first 1x1 convolution (self.conv1).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4 | |||
def __init__( | |||
self, | |||
inplanes: int, | |||
planes: int, | |||
stride: int = 1, | |||
downsample: Optional[nn.Module] = None, | |||
groups: int = 1, | |||
base_width: int = 64, | |||
dilation: int = 1, | |||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
) -> None: | |||
super().__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
width = int(planes * (base_width / 64.0)) * groups | |||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
self.conv1 = conv1x1(inplanes, width) | |||
self.bn1 = norm_layer(width) | |||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
self.bn2 = norm_layer(width) | |||
self.conv3 = conv1x1(width, planes * self.expansion) | |||
self.bn3 = norm_layer(planes * self.expansion) | |||
self.relu = nn.ReLU() | |||
self.relu2 = nn.ReLU() | |||
self.relu3 = nn.ReLU() | |||
self.downsample = downsample | |||
self.stride = stride | |||
self.add = M.Add() | |||
def forward(self, x: Tensor) -> Tensor: | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu2(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
x = self.downsample(x) | |||
out = self.add([out, x]) | |||
out = self.relu3(out) | |||
return out | |||
@RPProvider.register(Bottleneck) | |||
class BottleneckRelProp(RelProp[Bottleneck]): | |||
    def rel(self, R, alpha):
        out = RPProvider.get(self.module.relu3)(R, alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha)
        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha)
        out = RPProvider.get(self.module.bn3)(out, alpha)
        out = RPProvider.get(self.module.conv3)(out, alpha)
        out = RPProvider.get(self.module.relu2)(out, alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha)
        out = RPProvider.get(self.module.relu)(out, alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha)
        return x + x1
class ResNet(CamInterpretableModel): | |||
def __init__( | |||
self, | |||
block: Type[Union[BasicBlock, Bottleneck]], | |||
layers: List[int], | |||
num_classes: int = 1000, | |||
zero_init_residual: bool = False, | |||
groups: int = 1, | |||
width_per_group: int = 64, | |||
replace_stride_with_dilation: Optional[List[bool]] = None, | |||
norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
binary: bool = False, | |||
) -> None: | |||
super().__init__() | |||
if norm_layer is None: | |||
norm_layer = nn.BatchNorm2d | |||
self._norm_layer = norm_layer | |||
self.inplanes = 64 | |||
self.dilation = 1 | |||
if replace_stride_with_dilation is None: | |||
# each element in the tuple indicates if we should replace | |||
# the 2x2 stride with a dilated convolution instead | |||
replace_stride_with_dilation = [False, False, False] | |||
if len(replace_stride_with_dilation) != 3: | |||
raise ValueError( | |||
"replace_stride_with_dilation should be None " | |||
f"or a 3-element tuple, got {replace_stride_with_dilation}" | |||
) | |||
self.groups = groups | |||
self.base_width = width_per_group | |||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||
self.bn1 = norm_layer(self.inplanes) | |||
self.relu = nn.ReLU() | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) | |||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) | |||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) | |||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||
self._binary = binary | |||
if binary: | |||
self.fc = nn.Sequential( | |||
nn.Linear(512 * block.expansion, 1), | |||
nn.Sigmoid(), | |||
nn.Flatten(0) | |||
) | |||
else: | |||
self.fc = nn.Linear(512 * block.expansion, num_classes) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | |||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
# Zero-initialize the last BN in each residual branch, | |||
# so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||
if zero_init_residual: | |||
for m in self.modules(): | |||
if isinstance(m, Bottleneck): | |||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] | |||
elif isinstance(m, BasicBlock): | |||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] | |||
def _make_layer( | |||
self, | |||
block: Type[Union[BasicBlock, Bottleneck]], | |||
planes: int, | |||
blocks: int, | |||
stride: int = 1, | |||
dilate: bool = False, | |||
) -> nn.Sequential: | |||
norm_layer = self._norm_layer | |||
downsample = None | |||
previous_dilation = self.dilation | |||
if dilate: | |||
self.dilation *= stride | |||
stride = 1 | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||
norm_layer(planes * block.expansion), | |||
) | |||
layers = [] | |||
layers.append( | |||
block( | |||
self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer | |||
) | |||
) | |||
self.inplanes = planes * block.expansion | |||
for _ in range(1, blocks): | |||
layers.append( | |||
block( | |||
self.inplanes, | |||
planes, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm_layer=norm_layer, | |||
) | |||
) | |||
return nn.Sequential(*layers) | |||
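    # Worked example (illustrative): for ResNet-18 (BasicBlock, layers=[2, 2, 2, 2]),
    # _make_layer(BasicBlock, 128, 2, stride=2) yields
    #     Sequential(
    #         BasicBlock(64 -> 128, stride=2, downsample=conv1x1 + BatchNorm2d),
    #         BasicBlock(128 -> 128),
    #     )
    # The 1x1 downsample matches the identity path's stride and channel count,
    # so the residual addition in `forward` is shape-compatible.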
def _forward_impl(self, x: Tensor) -> Tensor: | |||
# See note [TorchScript super()] | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.avgpool(x) | |||
x = torch.flatten(x, 1) | |||
x = self.fc(x) | |||
return x | |||
def forward(self, x: Tensor, y: Tensor) -> Dict[str, Tensor]: | |||
p = self._forward_impl(x) | |||
if self._binary: | |||
return dict( | |||
positive_class_probability=p, | |||
loss=F.binary_cross_entropy(p, y), | |||
) | |||
return dict( | |||
categorical_probability=p.softmax(dim=1), | |||
loss=F.cross_entropy(p, y.long()), | |||
) | |||
@property | |||
def target_conv_layers(self) -> List[nn.Module]: | |||
""" | |||
Returns: | |||
The convolutional layers to be interpreted. The result of | |||
the interpretation will be sum of the grad-cam of these layers | |||
""" | |||
return [self.layer4] | |||
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor: | |||
""" | |||
A method to get probabilities assigned to all the classes in the model's forward, | |||
with shape (B, C), in which B is the batch size and C is number of classes | |||
Args: | |||
*inputs: Inputs to the model | |||
**kwargs: Additional arguments | |||
Returns: | |||
Tensor of categorical probabilities | |||
""" | |||
if self._binary: | |||
p = self.forward(*inputs, **kwargs)['positive_class_probability'] | |||
return torch.stack([1 - p, p], dim=1) | |||
return self.forward(*inputs, **kwargs)['categorical_probability'] | |||
@property | |||
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]: | |||
""" | |||
Returns: | |||
Input module for interpretation | |||
""" | |||
return ['x'] | |||
@RPProvider.register(ResNet) | |||
class ResNetRelProp(RelProp[ResNet]): | |||
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # A 1-D classifier output means the binary head (ending in Flatten(0))
        # was used; keep only the positive-class relevance column.
        if RPProvider.get(self.module.fc).Y.ndim == 1:
            R = R[:, -1]
        R = RPProvider.get(self.module.fc)(R, alpha=alpha)
        # Undo the flatten between avgpool and fc.
        R = R.reshape_as(RPProvider.get(self.module.avgpool).Y)
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.layer4)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.layer3)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.layer2)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.layer1)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.maxpool)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.relu)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.bn1)(R, alpha=alpha) | |||
R = RPProvider.get(self.module.conv1)(R, alpha=alpha) | |||
return R | |||
def _resnet( | |||
arch: str, | |||
block: Type[Union[BasicBlock, Bottleneck]], | |||
layers: List[int], | |||
pretrained: bool, | |||
progress: bool, | |||
**kwargs: Any, | |||
) -> ResNet: | |||
model = ResNet(block, layers, **kwargs) | |||
    # Loading pretrained weights is currently disabled:
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    #     model.load_state_dict(state_dict)
return model | |||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNet-18 model from | |||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) | |||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNet-34 model from | |||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNet-50 model from | |||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNet-101 model from | |||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) | |||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNet-152 model from | |||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) | |||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNeXt-50 32x4d model from | |||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
kwargs["groups"] = 32 | |||
kwargs["width_per_group"] = 4 | |||
return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""ResNeXt-101 32x8d model from | |||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_. | |||
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
kwargs["groups"] = 32 | |||
kwargs["width_per_group"] = 8 | |||
return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) | |||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""Wide ResNet-50-2 model from | |||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. | |||
    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
kwargs["width_per_group"] = 64 * 2 | |||
return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) | |||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: | |||
r"""Wide ResNet-101-2 model from | |||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_. | |||
    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args: | |||
pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
progress (bool): If True, displays a progress bar of the download to stderr | |||
""" | |||
kwargs["width_per_group"] = 64 * 2 | |||
return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) |
@@ -0,0 +1,5 @@ | |||
from .gaussianmax import GaussianMax2d | |||
from .activated_conv import ActivatedConv2d | |||
from .scaled_sigmoid import ScaledSigmoid | |||
from .lap import LAP | |||
from .adaptive_lap import AdaptiveLAP |