Add codes

Branch: master · Ahmad Salimi, 1 year ago · commit 4241a27ec8
100 changed files with 8466 additions and 0 deletions
1. .gitignore (+313, -0)
2. README.md (+218, -0)
3. data_preparation/imagenet/extract_train.py (+32, -0)
4. data_preparation/imagenet/extract_val.py (+59, -0)
5. data_preparation/prepare_rsna_data.py (+41, -0)
6. evaluate.py (+33, -0)
7. evaluate_interpretations.py (+21, -0)
8. interpret.py (+21, -0)
9. requirements.txt (+8, -0)
10. torchlap/__init__.py (+0, -0)
11. torchlap/__version__.py (+1, -0)
12. torchlap/configs/base_config.py (+304, -0)
13. torchlap/configs/celeba_configs.py (+35, -0)
14. torchlap/configs/imagenet_configs.py (+17, -0)
15. torchlap/configs/rsna_configs.py (+47, -0)
16. torchlap/criteria/__init__.py (+0, -0)
17. torchlap/criteria/aux_loss.py (+86, -0)
18. torchlap/criteria/bb_supervised.py (+123, -0)
19. torchlap/criteria/cw_concordance_loss.py (+152, -0)
20. torchlap/criteria/weakly_supervised.py (+221, -0)
21. torchlap/data/batch_choosing/batch_chooser.py (+61, -0)
22. torchlap/data/batch_choosing/class_balanced_shuffled_sequential.py (+86, -0)
23. torchlap/data/batch_choosing/sequential.py (+38, -0)
24. torchlap/data/content_loaders/celeba_loader.py (+154, -0)
25. torchlap/data/content_loaders/content_loader.py (+94, -0)
26. torchlap/data/content_loaders/imagenet_loader.py (+146, -0)
27. torchlap/data/content_loaders/rsna_loader.py (+329, -0)
28. torchlap/data/data_loader.py (+246, -0)
29. torchlap/data/dataflow.py (+98, -0)
30. torchlap/data/dataloader_context.py (+28, -0)
31. torchlap/experiments/__init__.py (+0, -0)
32. torchlap/experiments/_entry_loader.py (+10, -0)
33. torchlap/experiments/_model_loading.py (+68, -0)
34. torchlap/experiments/celeba/__init__.py (+0, -0)
35. torchlap/experiments/celeba/org_inception.py (+20, -0)
36. torchlap/experiments/celeba/org_resnet.py (+20, -0)
37. torchlap/experiments/celeba/ws_lap_inception.py (+143, -0)
38. torchlap/experiments/celeba/ws_lap_resnet.py (+131, -0)
39. torchlap/experiments/entrypoint.py (+113, -0)
40. torchlap/experiments/imagenet/__init__.py (+0, -0)
41. torchlap/experiments/imagenet/lap_resnet50_ft.py (+17, -0)
42. torchlap/experiments/imagenet/lap_resnet50_nft.py (+17, -0)
43. torchlap/experiments/rsna/__init__.py (+0, -0)
44. torchlap/experiments/rsna/bb_lap_inception.py (+77, -0)
45. torchlap/experiments/rsna/bb_lap_resnet.py (+69, -0)
46. torchlap/experiments/rsna/org_inception.py (+14, -0)
47. torchlap/experiments/rsna/org_resnet.py (+14, -0)
48. torchlap/experiments/rsna/ws_lap_inception.py (+117, -0)
49. torchlap/experiments/rsna/ws_lap_resnet.py (+105, -0)
50. torchlap/interpreting/__init__.py (+7, -0)
51. torchlap/interpreting/attention_interpreter.py (+148, -0)
52. torchlap/interpreting/attention_interpreter_smooth_integrator.py (+48, -0)
53. torchlap/interpreting/attention_sum_interpreter.py (+31, -0)
54. torchlap/interpreting/binary_interpretation_evaluator_2d.py (+166, -0)
55. torchlap/interpreting/dataflow.py (+43, -0)
56. torchlap/interpreting/deep_lift.py (+52, -0)
57. torchlap/interpreting/gradcam.py (+38, -0)
58. torchlap/interpreting/guided_backprop.py (+35, -0)
59. torchlap/interpreting/guided_gradcam.py (+45, -0)
60. torchlap/interpreting/imagenet_attention_interpreter.py (+55, -0)
61. torchlap/interpreting/interpretable.py (+53, -0)
62. torchlap/interpreting/interpretable_wrapper.py (+69, -0)
63. torchlap/interpreting/interpretation_dataflow.py (+42, -0)
64. torchlap/interpreting/interpretation_evaluation_runner.py (+123, -0)
65. torchlap/interpreting/interpreter.py (+67, -0)
66. torchlap/interpreting/interpreter_maker.py (+63, -0)
67. torchlap/interpreting/interpreting_runner.py (+162, -0)
68. torchlap/interpreting/relcam/__init__.py (+1, -0)
69. torchlap/interpreting/relcam/interpreter.py (+44, -0)
70. torchlap/interpreting/relcam/modules.py (+21, -0)
71. torchlap/interpreting/relcam/relprop.py (+339, -0)
72. torchlap/interpreting/utils.py (+102, -0)
73. torchlap/model_evaluation/binary_evaluator.py (+100, -0)
74. torchlap/model_evaluation/binary_faithfulness.py (+109, -0)
75. torchlap/model_evaluation/binary_fortelling.py (+108, -0)
76. torchlap/model_evaluation/binary_tag_evaluator.py (+35, -0)
77. torchlap/model_evaluation/evaluator.py (+84, -0)
78. torchlap/model_evaluation/loss_evaluator.py (+57, -0)
79. torchlap/model_evaluation/multiclass_evaluator.py (+122, -0)
80. torchlap/model_evaluation/multieval_evaluator.py (+63, -0)
81. torchlap/models/__init__.py (+0, -0)
82. torchlap/models/celeba/__init__.py (+0, -0)
83. torchlap/models/celeba/lap_inception.py (+61, -0)
84. torchlap/models/celeba/lap_resnet.py (+53, -0)
85. torchlap/models/celeba/org_resnet.py (+25, -0)
86. torchlap/models/celeba/single_tag_org_inception.py (+67, -0)
87. torchlap/models/imagenet/__init__.py (+0, -0)
88. torchlap/models/imagenet/lap_resnet.py (+34, -0)
89. torchlap/models/lap_inception.py (+556, -0)
90. torchlap/models/lap_resnet.py (+276, -0)
91. torchlap/models/model.py (+147, -0)
92. torchlap/models/rsna/__init__.py (+0, -0)
93. torchlap/models/rsna/lap_inception.py (+40, -0)
94. torchlap/models/rsna/lap_resnet.py (+39, -0)
95. torchlap/models/rsna/org_inception.py (+39, -0)
96. torchlap/models/rsna/org_resnet.py (+15, -0)
97. torchlap/models/tv_inception.py (+218, -0)
98. torchlap/models/tv_resnet.py (+512, -0)
99. torchlap/modules/__init__.py (+5, -0)
100. torchlap/modules/_common_types.py (+0, -0)

.gitignore (+313, -0)

@@ -0,0 +1,313 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm+all,macos

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.

.idea/*

!.idea/codeStyles
!.idea/runConfigurations

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm+all,macos

README.md (+218, -0)

@@ -0,0 +1,218 @@
# LAP
An Attention-Based Module for Faithful Interpretation and Knowledge Injection in Convolutional Neural Networks

This supplementary material provides a PDF with additional details on some sections of the paper, along with the code for our work.

# Data Preparation

## Metadata

First, download the metadata from [here](https://mega.nz/file/MqhXiLgC#NK1vd9ZLGU-3a182x-BXfvGCSuqHgq9hflI6tddh4Wc)

## RSNA
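
The steps below are a sketch inferred from `data_preparation/prepare_rsna_data.py` and the RSNA configs; the exact metadata layout may differ slightly.

1. Download the RSNA Pneumonia Detection Challenge data from [here](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data) and extract it.

2. Resize the images to the resolution your model expects (224 for ResNet, 299 for vanilla Inception, 256 for the modified Inception):

```bash
python data_preparation/prepare_rsna_data.py <kaggle_dataset_dir> <resolution> <cores>
```

3. Extract the RSNA metadata (the `DataSeparation_R<resolution>` folder and `infection_bounding_boxs.tsv`) from [this](#metadata) metadata file into `dataset_metadata/RSNA` and `data/RSNA` respectively.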

## CelebA

1. Download the dataset from [here](https://www.kaggle.com/jessicali9530/celeba-dataset)

2. Extract the data

```bash
unzip -q -j celeba-dataset.zip '**/*.jpg' -d data/celeba
```

3. Extract `celeba.tsv` from [this](#metadata) metadata file to `dataset_metadata/celeba.tsv`

## Imagenet

1. Download the ImageNet ILSVRC2012 dataset from [here](https://www.image-net.org/challenges/LSVRC/index.php).

1. Extract `ILSVRC2012_img_train_t3.tar` to `data/imagenet/tars`.

1. Run the following command to extract the train images per class:

```bash
python data_preparation/imagenet/extract_train.py -s data/imagenet/tars/ -n <num-threads>
```

1. Run the following command to extract the validation images per class (replace `<directory>` with the directory containing `ILSVRC2012_img_val.tar`, and extract `imagenet_val_maps.pklz` from [this](#metadata) metadata file):

```bash
python data_preparation/imagenet/extract_val.py -s <directory> -m <imagenet_val_maps.pklz>
```

1. Download the localization bounding boxes from [here](https://www.kaggle.com/c/imagenet-object-localization-challenge/data). Put `LOC_train_solution.csv` and `LOC_val_solution.csv` in `data/imagenet/extracted`.

# Model checkpoints

Download all the checkpoints from [here](https://mega.nz/file/16pWWA5L#b1SwLEv-YO5lL9wLVTcgykn2LmUBHm1exPFNB0JBoZo).

# Run

## Evaluate the checkpoints

### RSNA

#### Vanilla ResNet 18

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_resnet.pth
```

#### Vanilla Inception V3

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.org_inception.pth
```

#### WS ResNet 18

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_resnet_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py rsna.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_resnet_interpretations
```

#### WS Inception V3

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/ws_inception_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py rsna.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/ws_inception_interpretations
```

#### BB ResNet 18

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_resnet_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py rsna.bb_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_resnet_interpretations
```

#### BB Inception V3

##### Evaluate the model performance on the test set

```bash
python evaluate.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir rsna/bb_inception_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py rsna.bb_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/rsna.bb_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir rsna/bb_inception_interpretations
```

### CelebA

#### Vanilla ResNet 18

##### Evaluate the model performance on the test set

```bash
python evaluate.py celeba.org_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_resnet.pth
```

#### Vanilla Inception V3

##### Evaluate the model performance on the test set

```bash
python evaluate.py celeba.org_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.org_inception.pth
```

#### WS ResNet 18

##### Evaluate the model performance on the test set

```bash
python evaluate.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_resnet_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py celeba.ws_lap_resnet --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_resnet.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_resnet_interpretations
```

#### WS Inception V3

##### Evaluate the model performance on the test set

```bash
python evaluate.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth
```

##### Generate the LAP interpretations for positive samples

```bash
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 1 --save-by-file-name --report-dir celeba/ws_inception_interpretations
```

##### Generate the LAP interpretations for negative samples

```bash
python interpret.py celeba.ws_lap_inception --device cuda:0 --samples-dir test --final-model-dir checkpoints/celeba.ws_inception.pth --skip-raw --interpretation-method Attention --global-threshold --cut-threshold 0 --mapped-labels-to-use 0 --save-by-file-name --report-dir celeba/ws_inception_interpretations
```

### ImageNet

#### WS ResNet 50 (Fine-tuned)

##### Evaluate the model performance on the validation set

```bash
python evaluate.py imagenet.lap_resnet50_ft --device cuda:0 --samples-dir val --final-model-dir checkpoints/imagenet.ws_resnet_ft.pth
```

data_preparation/imagenet/extract_train.py (+32, -0)

@@ -0,0 +1,32 @@
import argparse
import glob
import os
import tarfile
from multiprocessing import Pool


parser = argparse.ArgumentParser()
parser.add_argument('-s', dest='source', help='Class tars directory', required=True)
parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted')
parser.add_argument('-n', dest='num_threads', help='number of threads', default=1, type=int)
args = parser.parse_args()

class_tars = glob.glob(os.path.join(args.source, '*.tar'))
assert len(class_tars) == 1000, f"class_tars length: {len(class_tars)}"


def extract(class_tar):
filename = os.path.basename(class_tar)
class_name = filename.replace('.tar', '')
print('Extract ' + os.path.basename(class_tar))
class_fname = os.path.join(args.target, class_name)
os.makedirs(class_fname, exist_ok=True)
with tarfile.open(class_tar) as f:
f.extractall(class_fname)



# extract all class tar files in parallel using a worker pool
pool = Pool(args.num_threads)
pool.map(extract, class_tars)
pool.close()

data_preparation/imagenet/extract_val.py (+59, -0)

@@ -0,0 +1,59 @@
"""Prepare the ImageNet dataset"""
# import torch
import os
import argparse
import tarfile
import pickle
import gzip
# import subprocess
# from tqdm import tqdm
# from mxnet.gluon.utils import check_sha1
# from gluoncv.utils import download, makedirs

_VAL_TAR = 'ILSVRC2012_img_val.tar'

def parse_args():
parser = argparse.ArgumentParser(
description='Setup the ImageNet dataset.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-s', dest='download_dir', required=True,
help="The directory that contains downloaded tar files")
parser.add_argument('-m', dest='mapping', required=True,
help="The mapping file for validation set")
parser.add_argument('--target-dir', default='data/imagenet/extracted',
help="The directory to store extracted images")
args = parser.parse_args()
return args

def check_file(filename):
if not os.path.exists(filename):
raise ValueError('File not found: '+filename)

def extract_val(tar_fname, target_dir, val_maps_file):
os.makedirs(target_dir)
print('Extracting ' + tar_fname)
with tarfile.open(tar_fname) as tar:
tar.extractall(target_dir)
# move the extracted images into their class subfolders
with gzip.open(val_maps_file, 'rb') as f:
dirs, mappings = pickle.load(f)
for d in dirs:
os.makedirs(os.path.join(target_dir, d))
for m in mappings:
os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))

def main():
args = parse_args()

target_dir = os.path.expanduser(args.target_dir)
if os.path.exists(target_dir):
raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')

download_dir = os.path.expanduser(args.download_dir)
val_tar_fname = os.path.join(download_dir, _VAL_TAR)
check_file(val_tar_fname)
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), args.mapping)

if __name__ == '__main__':
main()

data_preparation/prepare_rsna_data.py (+41, -0)

@@ -0,0 +1,41 @@
import argparse
from os import makedirs, path, listdir
import numpy as np
import cv2
from multiprocessing import Pool


def resize_image(img_dir: str, img_save_dir: str, res: int) -> None:
img = cv2.imread(img_dir)
img = cv2.resize(img, dsize=(res, res))
cv2.imwrite(img_save_dir, img)


if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('kaggle_dataset_dir', type=str, help='download the dataset from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, extract it, and pass its path as this argument')
parser.add_argument('resolution', type=int, help='The resolution required for your model: 224 for ResNet, 299 for vanilla Inception, 256 for the modified Inception')
parser.add_argument('cores', type=int, help='The number of cores for multiprocessing.')
args = parser.parse_args()

save_dir = f'data/RSNA-Kaggle_R{args.resolution}'
makedirs(save_dir, exist_ok=True)

kaggle_dataset_dir = args.kaggle_dataset_dir
assert path.exists(kaggle_dataset_dir), f'{kaggle_dataset_dir} does not exist!'

# reading rsna images names
rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images')
assert path.exists(rsna_imgs_path), 'Make sure there is a folder named stage_2_train_images in the passed kaggle_directory!'

imgs_names = np.asarray(listdir(rsna_imgs_path))

imgs_src = np.vectorize(lambda x: path.join(rsna_imgs_path, x))(imgs_names)
imgs_dst = np.vectorize(lambda x: path.join(save_dir, x))(imgs_names)

pool = Pool(args.cores)
pool.starmap(resize_image, zip(imgs_src, imgs_dst, np.full((len(imgs_src),), args.resolution)))
pool.close()


evaluate.py (+33, -0)

@@ -0,0 +1,33 @@
import traceback
from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.data.data_loader import DataLoader, RunType
from torchlap.experiments._model_loading import load_model
from torchlap.utils.hooker import Hooker


def main() -> None:

ep = get_entry(PhaseType.EVAL)
conf = ep.conf
model = ep.model
model = load_model(model, conf)
model.eval()

for test_group_info in conf.samples_dir.split(','):
try:
print('')
print('>> Running evaluations for %s' % test_group_info, flush=True)

test_data_loader = DataLoader(conf, test_group_info, RunType.TEST)
evaluator = conf.evaluator_cls(model, test_data_loader, conf)
with Hooker(*conf.hooks_by_phase[conf.phase_type]):
evaluator.evaluate()
except Exception:
print('Problem in %s' % test_group_info, flush=True)
track = traceback.format_exc()
print(track, flush=True)


if __name__ == '__main__':
main()

evaluate_interpretations.py (+21, -0)

@@ -0,0 +1,21 @@
from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.experiments._model_loading import load_model
from torchlap.interpreting.interpretation_evaluation_runner import InterpretingEvalRunner
from torchlap.utils.hooker import Hooker


def main() -> None:

ep = get_entry(PhaseType.EVALINTERPRET)
conf = ep.conf
model = ep.model
model = load_model(model, conf)
model.eval()

runner = InterpretingEvalRunner(conf, model)
with Hooker(*conf.hooks_by_phase[conf.phase_type]):
runner.evaluate()


if __name__ == '__main__':
main()

interpret.py (+21, -0)

@@ -0,0 +1,21 @@
from torchlap.experiments._entry_loader import get_entry, PhaseType
from torchlap.experiments._model_loading import load_model
from torchlap.interpreting.interpreting_runner import InterpretingRunner
from torchlap.utils.hooker import Hooker


def main() -> None:

ep = get_entry(PhaseType.INTERPRET)
conf = ep.conf
model = ep.model
model = load_model(model, conf)
model.eval()

runner = InterpretingRunner(conf, model)
with Hooker(*conf.hooks_by_phase[conf.phase_type]):
runner.interpret()


if __name__ == '__main__':
main()

requirements.txt (+8, -0)

@@ -0,0 +1,8 @@
numpy
pandas
torch>=1.7.0
torchvision>=0.8.1
scikit-image>=0.14.3
scikit-learn>=0.19.1
captum>=0.3.0
opencv-python>=4.4.0.46

torchlap/__init__.py (+0, -0)


torchlap/__version__.py (+1, -0)

@@ -0,0 +1 @@
__version__ = 'master'

torchlap/configs/base_config.py (+304, -0)

@@ -0,0 +1,304 @@
from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional
import os
import random
import numpy as np
from sys import stderr
from enum import Enum

import torch

from ..interpreting.interpreter_maker import InterpretationType
from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser
from ..data.batch_choosing.sequential import SequentialBatchChooser
from ..utils.hooker import HookPair
if TYPE_CHECKING:
from ..model_evaluation.evaluator import Evaluator
from ..data.batch_choosing.batch_chooser import BatchChooser
from ..data.content_loaders.content_loader import ContentLoader


class PhaseType(Enum):
TRAIN = 'train'
EVAL = 'eval'
INTERPRET = 'interpret'
EVALINTERPRET = 'evalinterpret'


class BaseConfig:

def __init__(self,
try_name: str, try_num: int, data_separation: str,
input_size: int,
phase_type: PhaseType,
content_loader_cls: Type['ContentLoader'],
evaluator_cls: Type['Evaluator']) -> None:

# -------------------- ESSENTIAL CONFIG --------------------

self.data_separation = data_separation
self.try_name = try_name
self.try_num = try_num
self.phase_type = phase_type

# classes

self.evaluator_cls = evaluator_cls
self.content_loader_cls = content_loader_cls
self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser
# Setups

self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
self.batch_size = 64
self.random_seed: int = 17
self.samples_dir: Union[str, None] = None

# Saving and loading directories!

self.save_dir: Union[str, None] = None
self.report_dir: Union[str, None] = None
self.final_model_dir: Union[str, None] = None
self.epoch: Union[str, None] = None

# Device setting

self.dev_name = None
self.device = None

# -------------------- ESSENTIAL CONFIG --------------------

# -------------------- TRAINING CONFIG --------------------

self.pretrained_model_file = None

self.big_batch_size: int = 1
self.max_epochs: int = 10
self.iters_per_epoch: Union[int, None] = None
self.val_iters_per_epoch: Union[int, None] = None

# cleaning log while training
self.keep_best_and_last_epochs_only = False

# auto-stopping the train
self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100

self.different_augmentation_per_batch = True
self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None

self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

# for dataloader
self.mapped_labels_to_use: Union[List[int], None] = None # None means all labels; a list of ints means only those labels are considered

# hooking
self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
phasetype: [] for phasetype in PhaseType
}

# -------------------- TRAINING CONFIG --------------------
# ------------------ OPTIMIZATION CONFIG ------------------

self.init_lr: float = 1e-4
self.lr_decay: float = 1e-6
self.freezing_regexes: Union[None, List[str]] = None
self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = get_adam_creator(self.init_lr, self.lr_decay)

# ------------------ OPTIMIZATION CONFIG ------------------
# ----------------- INTERPRETATION CONFIG -----------------
self.save_by_file_name = False
self.interpretation_method = InterpretationType.GuidedBackprop
self.class_label_for_interpretation: Union[int, None] = None
self.interpret_predictions_vs_gt: bool = True

# saving interpretations:

self.skip_overlay = False
self.skip_raw = False

# FOR INTERPRETATION EVALUATION

# A threshold to cut all interpretations lower than
self.cut_threshold: float = 0

# whether the threshold is applied to un-normalized interpretation map
self.global_threshold: bool = False

# whether to use dynamic threshold
self.dynamic_threshold: bool = False

self.prediction_key_in_model_output_dict = 'positive_class_probability'

# for when interpretation masks as continuous objects are compared with GT
self.acceptable_min_intersection_threshold: float = 0.5

# the key for interpretation to be evaluated, None = the first one available
self.interpretation_tag_to_evaluate = None

self.n_interpretation_samples: Optional[int] = None

# ----------------- INTERPRETATION CONFIG -----------------

def __repr__(self):
"""
Represent config as string
"""
return str(self.__dict__)

def _update_based_on_args(self, args) -> None:
""" Updates the values based on the received arguments from the input! """
args_dict = args.__dict__
for arg in args_dict:
if arg in self.__dict__ and args_dict[arg] is not None:
setattr(self, arg, args_dict[arg])

def update(self, args):
"""
updates the config from args
"""

# applied twice intentionally: a second pass lets arguments that depend on other updated arguments take effect
self._update_based_on_args(args)
self._update_based_on_args(args)

self.dev_name, self.device = _get_device(args)
self._set_random_seeds()
self._set_save_dir()

self._update_report_dir()

def _set_random_seeds(self):
# Set the seed for hash based operations in python
os.environ['PYTHONHASHSEED'] = '0'

torch.manual_seed(self.random_seed)
torch.cuda.manual_seed_all(self.random_seed)
np.random.seed(self.random_seed)
random.seed(self.random_seed)

if self.phase_type != PhaseType.TRAIN:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


def _set_save_dir(self):
if self.save_dir is None:
self.save_dir = f"../Results/{self.try_num}_{self.try_name}"
os.makedirs(self.save_dir, exist_ok=True)

def _update_report_dir(self):
if self.report_dir is None:
self.report_dir = self.save_dir + '/' + self.phase_type.value

def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str:

if extra_subdir is None:
extra_subdir = ''
else:
extra_subdir = '/' + extra_subdir

# Making sure of keeping version information!

if self.final_model_dir is not None:
report_dir = '%s/%s%s' % (
self.report_dir,
self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')],
extra_subdir)

else:
epoch = self.epoch
if epoch is None:
epoch = 'best'
report_dir = '%s/epoch,%s%s' % (
self.report_dir,
epoch, extra_subdir)

# adding sample info

if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification):
if ':' in samples_specification:
base_info, _ = tuple(samples_specification.split(':'))
else:
base_info = samples_specification

if not os.path.isdir(base_info):
report_dir = report_dir + '/' + \
samples_specification.replace(':', '_').replace('../', '')
else:
report_dir = report_dir + '/' + samples_specification

return self.get_unique_save_dir(report_dir)

@staticmethod
def get_unique_save_dir(sd):
if os.path.exists(sd):

report_num = 2
base_sd = sd

while os.path.exists(sd):
sd = base_sd + '_%d' % report_num
report_num += 1

return sd

def get_save_dir_for_sample(self, save_dir, sample_name):
if self.save_by_file_name and '/' in sample_name:
return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:]
else:
return save_dir + '/' + sample_name


def _get_device(args):
""" cuda num, -1 means parallel, -2 means cpu, no cuda means cpu"""
dev_name = args.device

if 'cuda' in dev_name:
gpu_num = 0
if ':' in dev_name:
gpu_num = int(dev_name.split(':')[1])

if gpu_num >= torch.cuda.device_count():
print('No %s, using CPU instead!' % dev_name, file=stderr)
dev_name = 'cpu'
print('@ Running on CPU @', flush=True)

if dev_name == 'cuda':
dev_name = None
the_device = torch.device('cuda')
print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True)

elif 'cuda:' in dev_name:
the_device = torch.device(dev_name)
print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True)

else:
dev_name = 'cpu'
the_device = torch.device(dev_name)
print('@ Running on CPU @', flush=True)

return dev_name, the_device


def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\
-> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
return torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay)
return create_adam


def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6)\
-> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
return torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay)
return create_sgd
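
The optimizer factories above return closures over the learning-rate settings; `BaseConfig.optimizer_creator` is intended to be called with the model's parameters. A minimal usage sketch (the stand-in model is illustrative only, not part of this commit):

```python
# minimal usage sketch; the Linear model is a stand-in for illustration
import torch
from torchlap.configs.base_config import get_adam_creator

model = torch.nn.Linear(10, 2)                 # stand-in model
create_optim = get_adam_creator(init_lr=1e-4, lr_decay=1e-6)
optimizer = create_optim(model.parameters())   # Adam with lr=1e-4, weight_decay=1e-6
```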


torchlap/configs/celeba_configs.py (+35, -0)

@@ -0,0 +1,35 @@
from .base_config import BaseConfig, PhaseType
from typing import Union
from torch import nn
from torchvision import transforms
from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag
from ..model_evaluation.binary_evaluator import BinaryEvaluator

class CelebAConfigs(BaseConfig):

def __init__(self,
try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
super().__init__(try_name, try_num, 'default', input_size, phase_type, CelebALoader, BinaryEvaluator)
# replaced configs!
self.batch_size = 64
self.iters_per_epoch = 200
self.tags = [tag.name for tag in CelebATag if tag.name.endswith('Tag')]
self.main_tag: Union[CelebATag, None] = None
self.data_root: str = 'data/celeba'
self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
self.keep_best_and_last_epochs_only = True

# augmentation
self.augmentations_dict = {
'x': nn.Sequential(
transforms.RandomRotation(45),
transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5),
transforms.RandomPerspective())
}


torchlap/configs/imagenet_configs.py (+17, -0)

@@ -0,0 +1,17 @@
from .base_config import BaseConfig, PhaseType
from ..data.content_loaders.imagenet_loader import ImagenetLoader
from ..model_evaluation.multiclass_evaluator import MulticlassEvaluator

class ImagenetConfigs(BaseConfig):
def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
super().__init__(try_name, try_num, 'data/imagenet/extracted', input_size, phase_type, ImagenetLoader, MulticlassEvaluator.standard_creator('categorical_probability', include_top5=True))

# replaced configs!
self.batch_size = 64
self.big_batch_size = 4
self.max_epochs = 10

self.title_of_reference_metric_to_choose_best_epoch = 'Accuracy'
self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
self.keep_best_and_last_epochs_only = True

torchlap/configs/rsna_configs.py (+47, -0)

@@ -0,0 +1,47 @@
from .base_config import BaseConfig, PhaseType
from typing import Dict, List, Optional
from torch import nn
from torchvision import transforms
from ..utils.random_brightness_augmentation import ClippedBrightnessAugment
from ..data.content_loaders.rsna_loader import RSNALoader
from ..model_evaluation.binary_evaluator import BinaryEvaluator

class RSNAConfigs(BaseConfig):

def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
super().__init__(try_name, try_num, f'dataset_metadata/RSNA/DataSeparation_R{input_size}', input_size, phase_type, RSNALoader, BinaryEvaluator)

# replaced configs!
self.batch_size = 64
self.max_epochs = 200

self.label_map_dict: Dict[str, int] = {
'healthy': 0, 'pneumonia': 1
}

self.augmentations_dict = {
'x': nn.Sequential(
transforms.RandomRotation(45),
transforms.RandomAffine(0, shear=0.4),
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)),
ClippedBrightnessAugment(0.5, 1.5, 0, 1)),

'infection': nn.Sequential(
transforms.RandomRotation(45),
transforms.RandomAffine(0, shear=0.4),
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),

'interpretations': nn.Sequential(
transforms.RandomRotation(45),
transforms.RandomAffine(0, shear=0.4),
transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),
}

self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
self.keep_best_and_last_epochs_only = True

self.infection_map_size = self.input_size
self.receptive_field_radii: List[int] = []

self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'

torchlap/criteria/__init__.py (+0, -0)


torchlap/criteria/aux_loss.py (+86, -0)

@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Tuple

import torch
from torch import nn

from ..configs.base_config import BaseConfig, PhaseType


class AuxLoss(ABC):

def __init__(self, title: str, model: nn.Module,
layer_by_name: Dict[str, nn.Module],
loss_weight: float):
"""This class is used for calculating loss based on the middle layers of the network.

Args:
title (str): A unique title; the required inputs are kept under this title in the output dictionary, and the loss term is added as {title}_loss.
model (nn.Module): The base model.
layer_by_name (Dict[str, nn.Module]): A mapping from layer names to layers. The names used for one instance of this class must be unique.
loss_weight (float): The weight of this loss in the final loss. The term is automatically added to the model's 'loss' entry in the output dictionary with this factor.
"""

self._title = title
self._model = model
self._loss_weight = loss_weight
self._layers_names = layer_by_name.keys()

self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()

self.hook_pairs = [
(module, partial(self._save_layer_val_hook, name))
for name, module in layer_by_name.items()
] + [(model, self._add_loss)]

def _save_layer_val_hook(self, layer_name, _, __, layer_out):
self._layers_outputs_by_name[layer_name] = layer_out

def configure(self, config: BaseConfig) -> None:
"""Must be called to add the necessary configs and hooks so the loss will be calculated!

Args:
config (BaseConfig): The base config!
"""

for phase in PhaseType:
config.hooks_by_phase[phase] += self.hook_pairs

def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
""" A hook for calculating and adding loss the the model's output dictionary

Args:
_ ([type]): Model, useless! (just the hook format)
__ ([type]): Model input, also useless! (just the hook format)
model_out (Dict[str, torch.Tensor]): The output dictionary of the model
"""

# popping values of tensors
layers_values = []
for layer_name in self._layers_names:
assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.'
layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))

loss = self._calculate_loss(layers_values, model_out)

# adding loss to output dictionary
assert f'{self._title}_loss' not in model_out, f'Trying to add {self._title}_loss to model\'s output multiple times!'
model_out[f'{self._title}_loss'] = loss.clone()

# adding loss to the main loss
model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss

return model_out

@abstractmethod
def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Calculates loss based on the output of the layers given in the constructor.

Args:
layers_values (List[Tuple[str, torch.Tensor]]): List of layers names and their output value
model_output (Dict[str, torch.Tensor]): The model's output dictionary, in case anything else is needed for the loss calculation; it can also be used to add extra entries to the output.

Returns:
torch.Tensor: The loss
"""

torchlap/criteria/bb_supervised.py (+123, -0)

@@ -0,0 +1,123 @@
from typing import List, Tuple, Dict, Union
import math
import torch
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .aux_loss import AuxLoss


class BBLoss(AuxLoss):

def __init__(
self, title: str, model: torch.nn.Module,
layer_by_name: Dict[str, torch.nn.Module],
loss_weight: float,
inf_mask_kw: str,
layers_loss_weights: Union[float, List[float]],
bb_fill_ratio: float):
super().__init__(title, model, layer_by_name, loss_weight)

self._inf_mask_kw = inf_mask_kw
self._bb_fill_ratio = bb_fill_ratio

if isinstance(layers_loss_weights, list):
assert len(layers_loss_weights) == len(layer_by_name), f'There should be as many weights as layers, one per layer. Expected {len(layer_by_name)} found {len(layers_loss_weights)}'
self._layers_loss_weights = layers_loss_weights
else:
self._layers_loss_weights = [layers_loss_weights for _ in range(len(layer_by_name))]

def _calculate_loss(
self,
layers_values: List[Tuple[str, torch.Tensor]],
model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
# gathering extra information
dl: DataLoader = DataloaderContext.instance.dataloader
xray_y = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)
infection_mask = dl.get_current_batch_data(keyword=self._inf_mask_kw).to(dl._device)

xray_y = xray_y.bool()
infection_mask = infection_mask[xray_y]

PB, _, H, W = infection_mask.shape

loss = torch.zeros([], dtype=torch.float32,
device=xray_y.device, requires_grad=True)

for li, (ln, lo) in enumerate(layers_values):
out = F.interpolate(lo,
size=(H, W),
mode='bilinear',
align_corners=True)
out = out.flatten(start_dim=1)

neg_out = out[~xray_y]
pos_out = out[xray_y]
NB = neg_out.shape[0]

infection_mask = infection_mask.flatten(start_dim=1)

pos_out_infection = torch.where(infection_mask < 1,
torch.zeros_like(pos_out, requires_grad=True),
pos_out)

neg_losses = []
pos_losses = []

if PB > 0:
bb_area = infection_mask.sum(dim=-1) # B
k = (self._bb_fill_ratio * bb_area.quantile(q=0.5))\
.ceil()\
.long()

# Top positive in bb pixels must be positive
pos_infection_topk, pos_infection_indices = pos_out_infection\
.topk(k, dim=-1, sorted=False)
pos_infection_batch_index = torch.arange(PB)\
.to(pos_infection_indices.device)\
.unsqueeze(1)\
.repeat_interleave(k, dim=1)
pos_weight = infection_mask[pos_infection_batch_index, pos_infection_indices].floor() # floor() zeroes the weight of any pixel not fully inside the box
pos_losses.append(F.binary_cross_entropy(
pos_infection_topk,
torch.ones_like(pos_infection_topk),
pos_weight
))

if (infection_mask == 0).any():
# All positive out bb pixels must be negative
pos_out_non_infection = pos_out[infection_mask == 0]
neg_losses.append(F.binary_cross_entropy(
pos_out_non_infection,
torch.zeros_like(pos_out_non_infection)
))

if NB > 0:
if PB > 0:
# Top negative pixels must be negative
neg_k = int(math.ceil(PB * k * 1.0 / NB))
neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
neg_losses.append(F.binary_cross_entropy(
neg_out_topk,
torch.zeros_like(neg_out_topk)
))
else:
# All negative pixels must be negative
neg_losses.append(F.binary_cross_entropy(
neg_out,
torch.zeros_like(neg_out)
))

losses = []
if len(neg_losses) > 0:
losses.append(torch.stack(neg_losses).mean())
if len(pos_losses) > 0:
losses.append(torch.stack(pos_losses).mean())
l_loss = torch.stack(losses).mean()
model_output[f'{self._title}_{ln}_loss'] = l_loss
loss = loss + self._layers_loss_weights[li] * l_loss
return loss
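
In words, the per-layer loss above pushes the strongest in-box activations of positive samples toward 1 and everything else toward 0. A compact summary of what the code computes, where s is the interpolated score map, m the infection mask, A_i the box area of positive sample i, and P and N the positive and negative batch sizes:

```latex
k = \left\lceil \rho_{\text{fill}} \cdot \operatorname{median}_i A_i \right\rceil, \qquad
k_{\text{neg}} = \left\lceil P\,k / N \right\rceil \\
\mathcal{L}_{\text{pos}} = \operatorname{BCE}\big(\text{top-}k \text{ of } s \text{ inside } m,\; 1\big), \qquad
\mathcal{L}_{\text{out}} = \operatorname{BCE}\big(s[m = 0],\; 0\big), \qquad
\mathcal{L}_{\text{neg}} = \operatorname{BCE}\big(\text{top-}k_{\text{neg}} \text{ of } s^{-},\; 0\big)
```

The layer loss is the mean of the averaged negative terms (L_out, L_neg) and the averaged positive term (L_pos), and the layer losses are combined with `layers_loss_weights`.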

torchlap/criteria/cw_concordance_loss.py (+152, -0)

@@ -0,0 +1,152 @@
from typing import Dict, Union, List, Tuple

import numpy as np
import torch
from torch.nn import functional as F

from .aux_loss import AuxLoss
from ..data.dataloader_context import DataloaderContext
from ..data.data_loader import DataLoader



class PoolConcordanceLossCalculator(AuxLoss):

def __init__(self,
title: str, model: torch.nn.Module,
layer_by_name: Dict[str, torch.nn.Module],
loss_weight: float, weights: Union[float, List[float]],
diff_thresholds: Union[float, List[float]],
labels_by_channel: Dict[int, List[int]]):
super().__init__(title, model, layer_by_name, loss_weight)

# at least two pools are needed
assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'

self._title = title
self._model = model
self._loss_prefix = f'{title}_JS_loss_'

self._labels_by_channel = labels_by_channel

self._pool_names = list(layer_by_name.keys())
self._weights = weights if isinstance(weights, list) else \
[weights for _ in range(len(layer_by_name) - 1)]
if isinstance(weights, list):
assert len(weights) == len(layer_by_name) - 1, 'Weights must have a length of pool_layers -1'

self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \
[diff_thresholds for _ in range(len(layer_by_name) - 1)]
if isinstance(diff_thresholds, list):
assert len(diff_thresholds) == len(layer_by_name) - 1, 'Diff thresholds must have a length of pool_layers -1'

def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:

# reading pools from the output and deleting them
pool_vals = [x[1] for x in layers_values]

# getting samples labels
dl: DataLoader = DataloaderContext.instance.dataloader
labels = dl.get_current_batch_samples_labels()
labels = torch.from_numpy(labels).to(dl._device)
# sorting by shape from biggest to smallest
p_shapes = np.asarray([p.shape[-1] for p in pool_vals])
sorted_inds = np.argsort(-1 * p_shapes)
pool_vals = [pool_vals[i] for i in sorted_inds]
pool_names = [self._pool_names[i] for i in sorted_inds]
loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

# for each pair of pools, calculate loss!
for i in range(len(pool_names) - 1):
p_loss = self._cal_pool_pair_loss(pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)

assert (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') not in model_output, 'Trying to add ' + (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') + ' to model output multiple times'

model_output[f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'] = p_loss.clone()
loss = loss + self._weights[i] * p_loss

return loss

def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:

# down-sampling by max-pool till reaching the same shape as p2!
if p1.shape[-1] > p2.shape[-1]:
p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

# Jensen-Shannon loss, computed per channel over the samples belonging to that channel's labels
loss = torch.tensor(0.0, requires_grad=True, device=p1.device)
for channel, r_labels in self._labels_by_channel.items():
on_mask = self._get_inclusion_mask(labels, r_labels)

ip1 = p1[on_mask, channel, ...]
ip2 = p2[on_mask, channel, ...]

if torch.numel(ip1) > 0:

if ip1.shape != ip2.shape:
print(f'Problem in shape of concordance loss in {self._title}')

loss = loss + jensen_shannon_divergence(
ip1[:, None, ...], ip2[:, None, ...], diff_threshold)

return loss

def _get_inclusion_mask(
self,
samples_labels: torch.Tensor,
desired_labels: List[int]) -> torch.Tensor:
with torch.no_grad():
inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
aggregation = torch.sum(inclusion_mask.float(), dim=0)
return torch.greater(aggregation, 0)


def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
"""
Calculates the Jensen-Shannon-style (symmetric KL) loss between two distributions p1 and p2

Args:
p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1] that is the probabilities for each neuron in the 1st distribution.
p2 (torch.Tensor): A tensor of the same shape as src, in range [0, 1] that is the probabilities for each neuron in the 2nd distribution.
diff_threshold (float): Threshold between p1 and p2 to decide whether to consider one pixel in loss
Returns:
torch.Tensor: The calculated loss.
"""

assert p1.shape == p2.shape, 'The tensors must have the same shape'
assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'

# Reshaping tensors
p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)

# if binary, append class 0!
if p1.shape[1] == 1:
p1 = torch.cat([1 - p1, p1], dim=1)
p2 = torch.cat([1 - p2, p2], dim=1)

with torch.no_grad():
mask = torch.abs(p1 - p2).detach() >= diff_threshold
mask = mask.max(dim=-1)[0]

# clamp so numerical error cannot produce the log of a non-positive number
lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1)))
lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2)))

loss = 0.5 * (
_smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) +
_smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
)

return loss


def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
if torch.any(mask.bool()):
return torch.mean(val[mask])
else:
return torch.tensor(0.0, requires_grad=True, device=mask.device)
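
For reference, with the clamp inside the logarithms ignored, the quantity the function above computes is the symmetric KL form, averaged over the set M of pixels whose per-class probabilities differ by at least the threshold τ:

```latex
\mathcal{L}(p_1, p_2) = \frac{1}{2}\,\mathbb{E}_{x \in M}\!\left[\sum_c p_1^c(x)\big(\log p_1^c(x) - \log p_2^c(x)\big) + \sum_c p_2^c(x)\big(\log p_2^c(x) - \log p_1^c(x)\big)\right],
\qquad M = \{x : \max_c |p_1^c(x) - p_2^c(x)| \ge \tau\}
```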

torchlap/criteria/weakly_supervised.py (+221, -0)

@@ -0,0 +1,221 @@
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .aux_loss import AuxLoss


class DiscriminativeWeaklySupervisedLoss(AuxLoss):

def __init__(
self, title: str, model: nn.Module,
att_score_layer: nn.Module,
loss_weight: float,
neg_ratio: float, pos_ratio_range: Tuple[float, float],
on_labels_by_channel: Dict[int, List[int]],
discr_score_layer: nn.Module = None,
w_attention_in_ordering: float = 1,
w_discr_in_ordering: float = 1):
"""
Calculates binary weakly supervised score for an attention layer
with extra discrimination head.

Args:
title (str): The title of the loss, must be unique!
model (nn.Module): The base model, so its output dictionary can be modified
att_score_layer (torch.nn.Module): A layer that gives attention score (B C ...)
loss_weight (float): The weight of the loss
neg_ratio (float): The top ratio of pixels to apply the loss to in negative samples (e.g. 0.1).
pos_ratio_range (Tuple[float, float]): The low and high ratios of pixels to apply the loss to in positive samples (e.g. (0.033, 0.278), calculated from the distribution of positive bounding boxes).
on_labels_by_channel (Dict[int, List[int]]): The dictionary that specifies the samples related to which labels should be on in each channel.
w_attention_in_ordering (float): The weight of the attention score used in ordering the pixels.
w_discr_in_ordering (float): The weight of the reference score used in ordering the pixels.
discr_score_layer (torch.nn.Module): A layer that gives discriminative score (B C ...)
"""
layers = dict(
att=att_score_layer,
)
if discr_score_layer is not None:
layers['discr'] = discr_score_layer
super().__init__(title, model, layers, loss_weight)
self._has_discr = discr_score_layer is not None

self._neg_ratio = neg_ratio
self._pos_ratio_range = pos_ratio_range

self._on_labels_by_channel = on_labels_by_channel

self._w_attention_in_ordering = w_attention_in_ordering
self._w_discr_in_ordering = w_discr_in_ordering

def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor:

discriminative_scores = None
probabilities = None

for ln, lv in layers_values:
if ln == 'att':
probabilities = lv
else:
discriminative_scores = lv

dl: DataLoader = DataloaderContext.instance.dataloader
labels = dl.get_current_batch_samples_labels()

if discriminative_scores is not None:
discrimination_loss = self._calculate_discrimination_loss(
labels, discriminative_scores)
assert (self._title + '_discr_loss') not in model_out, 'Trying to add ' + (self._title + '_discr_loss') + ' to model output multiple times'
model_out[self._title + '_discr_loss'] = discrimination_loss.clone()
else:
discrimination_loss = torch.zeros([], requires_grad=True, device=labels.device)

attention_loss = self._calculate_attention_loss(
labels, probabilities, (discriminative_scores if discriminative_scores is not None else probabilities))

assert (self._title + '_ws_loss') not in model_out, 'Trying to add ' + (self._title + '_ws_loss') + ' to model output multiple times'
model_out[self._title + '_ws_loss'] = attention_loss.clone()

loss = self._loss_weight * (discrimination_loss + attention_loss)

return loss

def _calculate_discrimination_loss(
self,
samples_labels: torch.Tensor,
discrimination_scores: torch.Tensor) -> torch.Tensor:
losses = []

for channel, labels in self._on_labels_by_channel.items():
on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
on_ps = discrimination_scores[on_mask, channel, ...]
off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]

if torch.numel(on_ps) > 0:
losses.append(self._cal_loss(1, True, on_ps))
if torch.numel(off_ps) > 0:
losses.append(self._cal_loss(1, False, off_ps))

return torch.mean(torch.stack(losses))

def _calculate_attention_loss(
self,
samples_labels: torch.Tensor,
attention_scores: torch.Tensor,
discrimination_scores: torch.Tensor) -> torch.Tensor:
losses = []

for channel, labels in self._on_labels_by_channel.items():

on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
on_atts = attention_scores[on_mask, channel, ...]
on_discr = discrimination_scores[on_mask, channel, ...].detach()

off_atts = attention_scores[torch.logical_not(on_mask), channel, ...]
off_discr = discrimination_scores[torch.logical_not(on_mask), channel, ...].detach()

neg_losses = []
pos_losses = []

            # apply the ranking-based losses to the negative and positive samples

if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
neg_losses.append(self._cal_loss(
self._neg_ratio, False, off_atts, off_discr, largest=True
))
if torch.numel(on_atts) > 0:

# Calculate positive top k to be positive

if self._pos_ratio_range[0] > 0:
pos_losses.append(self._cal_loss(
self._pos_ratio_range[0], True, on_atts, on_discr, True
))

# Calculate positive bottom k to be negative
if self._pos_ratio_range[1] < 1:

neg_losses.append(self._cal_loss(
1 - self._pos_ratio_range[1], False, on_atts, on_discr, False
))

if len(neg_losses) > 0:
losses.append(torch.stack(neg_losses).mean())

if len(pos_losses) > 0:
losses.append(torch.stack(pos_losses).mean())

return torch.stack(losses).mean()

def _get_inclusion_mask(
self,
samples_labels: np.ndarray,
desired_labels: List[int], device: torch.device) -> torch.Tensor:
with torch.no_grad():
samples_labels = torch.from_numpy(samples_labels).to(device)
inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
aggregation = torch.sum(inclusion_mask.float(), dim=0)
return torch.greater(aggregation, 0)

def _cal_loss(
self,
ratio: float, positive_label: bool,
att_scores: torch.Tensor,
discr_scores: Optional[torch.Tensor] = None,
largest: bool = True):

if ratio == 1:
ps = att_scores

else:

k = np.ceil(
ratio * att_scores.shape[-1] * att_scores.shape[-2]).astype(int)
ps = self._get_topk(att_scores, discr_scores, k, largest=largest)

ps = ps.flatten()

if positive_label:
gt = torch.ones_like(ps)
else:
gt = torch.zeros_like(ps)

return F.binary_cross_entropy(ps, gt)

def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor:

scores = self._pixels_scores(att_scores, discr_scores)
b = att_scores.shape[0]
top_inds = (scores.flatten(1)).topk(k, dim=dim, largest=largest, sorted=False).indices
# B K

ret_val = att_scores.flatten(1)[
torch.repeat_interleave(
torch.arange(b, device=att_scores.device), k).reshape(b, k),
top_inds] # B K

if not return_indices:
return ret_val
else:
return ret_val, top_inds

def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: torch.Tensor) -> torch.Tensor:
return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores
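Below is a minimal, self-contained sketch (illustrative only, with hypothetical helper names and dummy tensors) of the ranking-based pixel selection performed above: pixels are ordered by a weighted sum of the attention and the detached discrimination scores, the top-k attention values are picked, and a binary cross-entropy toward 0 or 1 is applied to them.

import numpy as np
import torch
from torch.nn import functional as F

def topk_bce(att, discr, ratio, positive, w_att=0.2, w_discr=1.0, largest=True):
    # att, discr: (B, H, W) scores in [0, 1]; ratio: fraction of pixels to select
    k = int(np.ceil(ratio * att.shape[-1] * att.shape[-2]))
    order = (w_att * att + w_discr * discr).flatten(1)                    # (B, H*W) ordering scores
    inds = order.topk(k, dim=-1, largest=largest, sorted=False).indices   # (B, k) selected positions
    picked = att.flatten(1).gather(1, inds).flatten()                     # attention values at those positions
    target = torch.ones_like(picked) if positive else torch.zeros_like(picked)
    return F.binary_cross_entropy(picked, target)

att, discr = torch.rand(4, 7, 7), torch.rand(4, 7, 7)
loss_pos = topk_bce(att, discr, ratio=0.033, positive=True)    # top pixels of positive samples toward 1
loss_neg = topk_bce(att, discr, ratio=0.1, positive=False)     # top pixels of negative samples toward 0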

+ 61
- 0
torchlap/data/batch_choosing/batch_chooser.py View File

@@ -0,0 +1,61 @@
"""
The BatchChooser
"""
from abc import abstractmethod
from typing import List, TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from ..data_loader import DataLoader



class BatchChooser:
"""
The BatchChooser
"""

def __init__(self, data_loader: 'DataLoader', batch_size: int):
""" Receives as input a data_loader which contains the information about the samples and the desired batch size. """

self.data_loader = data_loader
self.batch_size = batch_size
self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int)
self.completed_iteration = False

self.class_samples_indices = self.data_loader.get_class_sample_indices()
self._classes_labels: List[int] = list(self.class_samples_indices.keys())

def prepare_next_batch(self):
"""
Prepares the next batch based on its strategy
"""

if self.finished_iteration():
self.reset()

self.current_batch_sample_indices = \
self.get_next_batch_sample_indices().astype(int)

if len(self.current_batch_sample_indices) == 0:
self.completed_iteration = True
return

@abstractmethod
def get_next_batch_sample_indices(self) -> np.ndarray:
pass

def get_current_batch_sample_indices(self) -> np.ndarray:
""" Returns a list of indices of the samples chosen for the current batch. """
return self.current_batch_sample_indices

def finished_iteration(self):
""" Returns True if iteration is finished over all the slices of all the samples, False otherwise"""
return self.completed_iteration

def reset(self):
""" Resets the sample iterator. """
self.completed_iteration = False
self.current_batch_sample_indices = np.asarray([], dtype=int)

def get_current_batch_size(self) -> int:
return len(self.current_batch_sample_indices)
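As a usage sketch (illustrative only; the toy loader and chooser below are hypothetical and not part of the package), this is the protocol a concrete BatchChooser follows: prepare_next_batch() is called repeatedly, finished_iteration() signals the end of an epoch, and get_current_batch_sample_indices() exposes the chosen indices.

class _ToyLoader:
    # stand-in exposing the only method BatchChooser.__init__ relies on
    def get_class_sample_indices(self):
        return {0: np.arange(0, 6), 1: np.arange(6, 10)}

class _AllAtOnceChooser(BatchChooser):
    # returns every sample once, then an empty batch to signal completion
    def get_next_batch_sample_indices(self) -> np.ndarray:
        if self.get_current_batch_size() > 0:
            return np.asarray([], dtype=int)
        return np.concatenate(list(self.class_samples_indices.values()))

chooser = _AllAtOnceChooser(_ToyLoader(), batch_size=10)
while True:
    chooser.prepare_next_batch()
    if chooser.finished_iteration():
        break
    print(chooser.get_current_batch_sample_indices())   # prints all ten sample indices once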

+ 86
- 0
torchlap/data/batch_choosing/class_balanced_shuffled_sequential.py View File

@@ -0,0 +1,86 @@
from typing import TYPE_CHECKING
from .batch_chooser import BatchChooser
import numpy as np
if TYPE_CHECKING:
from ..data_loader import DataLoader


class ClassBalancedShuffledSequentialBatchChooser(BatchChooser):

def __init__(self, data_loader: 'DataLoader', batch_size: int):
super().__init__(data_loader, batch_size)
self._cursor = 0

# just copying to prevent any dependency
for k, v in self.class_samples_indices.items():
self.class_samples_indices[k] = np.copy(v)
np.random.shuffle(self.class_samples_indices[k])
self.classes_cursors = np.zeros(len(self.class_samples_indices), dtype=int)

def get_next_batch_sample_indices(self):
""" Returns a list containing the indices of the samples chosen for the next batch."""

n_samples_per_class = self._get_n_samples_per_class()

def sample_for_each_class(class_index):

class_samples = self.class_samples_indices[
self._classes_labels[class_index]]

if self.classes_cursors[class_index] + \
n_samples_per_class[class_index] <= \
len(class_samples):

ret_val = np.copy(class_samples[
self.classes_cursors[class_index]:
self.classes_cursors[class_index] + n_samples_per_class[class_index]
])
self.classes_cursors[class_index] += n_samples_per_class[class_index]

else:
ret_val = np.copy(class_samples)[
self.classes_cursors[class_index]:]
np.random.shuffle(class_samples)
self.classes_cursors[class_index] = \
n_samples_per_class[class_index] - len(ret_val)
ret_val = np.concatenate((
ret_val,
np.copy(class_samples
[: n_samples_per_class[class_index] - len(ret_val)]
)), axis=0)

if self.classes_cursors[class_index] == len(class_samples):
np.random.shuffle(class_samples)
self.classes_cursors[class_index] = 0

return ret_val

chosen_sample_indices = np.concatenate(tuple([
sample_for_each_class(ci) for ci in range(len(self._classes_labels))
]), axis=0)

return chosen_sample_indices

def _get_n_samples_per_class(self):

# determining the number of samples per class

n_samples_per_class = np.full(len(self.class_samples_indices),
int(self.batch_size // len(self.class_samples_indices)))

        remaining_samples_cnt = self.batch_size - np.sum(n_samples_per_class)

        if remaining_samples_cnt:

            if self._cursor + remaining_samples_cnt >= len(self._classes_labels):
                n_samples_per_class[self._cursor: len(self._classes_labels)] += 1
                remaining_samples_cnt -= (len(self._classes_labels) - self._cursor)
                self._cursor = 0

            if remaining_samples_cnt > 0:
                n_samples_per_class[self._cursor: self._cursor + remaining_samples_cnt] += 1
                self._cursor += remaining_samples_cnt

return n_samples_per_class
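For intuition, a small runnable sketch (simplified: the wrap-around handling of the real method is omitted) of how _get_n_samples_per_class rotates the remainder of the batch size across classes so that consecutive batches stay balanced in the long run.

def _rotate_remainder(batch_size, n_classes, cursor):
    counts = np.full(n_classes, batch_size // n_classes)
    remaining = batch_size - counts.sum()
    counts[cursor: cursor + remaining] += 1
    return counts, (cursor + remaining) % n_classes

cursor = 0
for _ in range(3):
    counts, cursor = _rotate_remainder(10, 3, cursor)
    print(counts)   # [4 3 3], then [3 4 3], then [3 3 4]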

+ 38
- 0
torchlap/data/batch_choosing/sequential.py View File

@@ -0,0 +1,38 @@
from typing import TYPE_CHECKING
from .batch_chooser import BatchChooser
import numpy as np
if TYPE_CHECKING:
from ..data_loader import DataLoader


class SequentialBatchChooser(BatchChooser):

def __init__(self, data_loader: 'DataLoader', batch_size: int):
super().__init__(data_loader, batch_size)

self.cursor = 0
self.desired_samples_inds = None

def reset(self):
super(SequentialBatchChooser, self).reset()
self.cursor = 0
self.desired_samples_inds = None

def get_next_batch_sample_indices(self):
""" Returns a list containing the indices of the samples chosen for the next batch."""

if self.desired_samples_inds is None:
self.desired_samples_inds = \
np.arange(self.data_loader.get_number_of_samples())

next_cursor = min(
len(self.desired_samples_inds),
self.cursor + self.batch_size)

next_sample_inds = self.desired_samples_inds[
self.cursor:next_cursor]
self.cursor = next_cursor

return next_sample_inds

+ 154
- 0
torchlap/data/content_loaders/celeba_loader.py View File

@@ -0,0 +1,154 @@
from enum import Enum
import os
import glob
from typing import Callable, TYPE_CHECKING, Union
import imageio

import cv2
import numpy as np
import pandas as pd
from .content_loader import ContentLoader

if TYPE_CHECKING:
from ...configs.celeba_configs import CelebAConfigs


class CelebATag(Enum):
FiveOClockShadowTag = '5_o_Clock_Shadow'
ArchedEyebrowsTag = 'Arched_Eyebrows'
AttractiveTag = 'Attractive'
BagsUnderEyesTag = 'Bags_Under_Eyes'
BaldTag = 'Bald'
BangsTag = 'Bangs'
BigLipsTag = 'Big_Lips'
BigNoseTag = 'Big_Nose'
BlackHairTag = 'Black_Hair'
BlondHairTag = 'Blond_Hair'
BlurryTag = 'Blurry'
BrownHairTag = 'Brown_Hair'
BushyEyebrowsTag = 'Bushy_Eyebrows'
ChubbyTag = 'Chubby'
DoubleChinTag = 'Double_Chin'
EyeglassesTag = 'Eyeglasses'
GoateeTag = 'Goatee'
GrayHairTag = 'Gray_Hair'
HighCheekbonesTag = 'High_Cheekbones'
MaleTag = 'Male'
MouthSlightlyOpenTag = 'Mouth_Slightly_Open'
MustacheTag = 'Mustache'
NarrowEyesTag = 'Narrow_Eyes'
NoBeardTag = 'No_Beard'
OvalFaceTag = 'Oval_Face'
PaleSkinTag = 'Pale_Skin'
PointyNoseTag = 'Pointy_Nose'
RecedingHairlineTag = 'Receding_Hairline'
RosyCheeksTag = 'Rosy_Cheeks'
SideburnsTag = 'Sideburns'
SmilingTag = 'Smiling'
StraightHairTag = 'Straight_Hair'
WavyHairTag = 'Wavy_Hair'
WearingEarringsTag = 'Wearing_Earrings'
WearingHatTag = 'Wearing_Hat'
WearingLipstickTag = 'Wearing_Lipstick'
WearingNecklaceTag = 'Wearing_Necklace'
WearingNecktieTag = 'Wearing_Necktie'
YoungTag = 'Young'

BoundingBoxX = 'x_1'
BoundingBoxY = 'y_1'
BoundingBoxW = 'width'
BoundingBoxH = 'height'

Partition = 'partition'

LeftEyeX = 'lefteye_x'
LeftEyeY = 'lefteye_y'
RightEyeX = 'righteye_x'
RightEyeY = 'righteye_y'

NoseX = 'nose_x'
NoseY = 'nose_y'

LeftMouthX = 'leftmouth_x'
LeftMouthY = 'leftmouth_y'
RightMouthX = 'rightmouth_x'
RightMouthY = 'rightmouth_y'


class SampleField(Enum):
ImageId = 'image_id'
Specification = 'specification'


class CelebALoader(ContentLoader):
def __init__(self, conf: 'CelebAConfigs', data_specification: str):
""" read all directories for scans and annotations, which are split by DataSplitter.py
And then, keeps only those samples which are used for 'usage' """

super().__init__(conf, data_specification)
self.conf = conf
self._datasep = True
self._samples_metadata = self._load_metadata(data_specification)
self._data_root = conf.data_root
self._warning_count = 0

def _load_metadata(self, data_specification: str) -> pd.DataFrame:
if '/' in data_specification:
self._datasep = False
return pd.DataFrame({
SampleField.ImageId.value: glob.glob(os.path.join(data_specification, '*.jpg'))
})
metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t')
metadata = metadata[metadata[SampleField.Specification.value] == data_specification]
metadata = metadata.drop(SampleField.Specification.value, axis=1)
metadata = metadata.reset_index().drop('index', axis=1)
return metadata

def get_samples_names(self):
return self._samples_metadata[SampleField.ImageId.value].values

def get_samples_labels(self):
r""" Dummy. Because we have multiple labels """
return np.ones((len(self._samples_metadata),))\
if self.conf.main_tag is None\
else self._samples_metadata[self.conf.main_tag.value].values

def drop_samples(self, drop_mask: np.ndarray) -> None:
self._samples_metadata = self._samples_metadata[np.logical_not(drop_mask)]

def get_placeholder_name_to_fill_function_dict(self):
""" Returns a dictionary of the placeholders' names (the ones this content loader supports)
to the functions used for filling them. The functions must receive as input data_loader,
which is an object of class data_loader that contains information about the current batch
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen
elements) and return an array per placeholder name according to the receives batch information.
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader
they belong to! Some sort of having a mark :))!"""
return {
'x': self._get_x,
**{
tag.name: self._generate_tag_getter(tag) for tag in CelebATag
}
}

def _get_x(self, samples_inds: np.ndarray)\
-> np.ndarray:
images = tuple(self._read_image(index) for index in samples_inds)
images = np.stack(images, axis=0) # B I I 3
images = images.transpose((0, 3, 1, 2)) # B 3 I I
images = images.astype(float) / 255.
return images

def _read_image(self, sample_index: int) -> np.ndarray:
filepath = self._samples_metadata.iloc[sample_index][SampleField.ImageId.value]
if self._datasep:
filepath = os.path.join(self._data_root, filepath)

image = imageio.imread(filepath) # H W 3
image = cv2.resize(image, (self.conf.input_size, self.conf.input_size)) # I I 3
return image

def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]:
def get_tag(samples_inds: np.ndarray) -> np.ndarray:
return self._samples_metadata.iloc[samples_inds][tag.value].values if self._datasep else np.zeros(len(samples_inds))
return get_tag
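As a usage sketch of the mapping returned above (hypothetical configuration; building a real CelebALoader requires a valid CelebAConfigs), each getter receives the indices of the samples in the current batch and returns one array per placeholder:

#     loader = CelebALoader(conf, 'train')
#     getters = loader.get_placeholder_name_to_fill_function_dict()
#     batch_inds = np.array([0, 1, 2, 3])
#     x = getters['x'](batch_inds)                  # (4, 3, input_size, input_size), floats in [0, 1]
#     smiling = getters['SmilingTag'](batch_inds)   # (4,) binary values of the Smiling tag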

+ 94
- 0
torchlap/data/content_loaders/content_loader.py View File

@@ -0,0 +1,94 @@
"""
A class for loading one content type needed for the run.
"""
from abc import abstractmethod, ABC
from typing import TYPE_CHECKING, Dict, Callable, Union, List

import numpy as np

if TYPE_CHECKING:
from .. import BaseConfig

LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]


class ContentLoader(ABC):
""" A class for loading one content type needed for the run. """

def __init__(self, conf: 'BaseConfig', data_specification: str):
""" Receives as input conf which is a dictionary containing configurations.
This dictionary can also be used to pass fixed addresses for easing the usage!
prefix_name is the str which all the variables that must be filled with this class have
this prefix in their names so it would be clear that this class should fill it.
data_specification is the string that specifies data and where it is! e.g. train, test, val"""

self.conf = conf
self.data_specification = data_specification
self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None

@abstractmethod
def get_samples_names(self):
""" Returns a list containing names of all the samples of the content loader,
each sample must owns a unique ID, and this function returns all this IDs.
The order of the list must always be the same during one run.
For example, this function can return an ID column of a table for TableLoader
or the dir of images as ID for ImageLoader"""

@abstractmethod
def get_samples_labels(self):
""" Returns list of labels of the whole samples.
The order of the list must always be the same during one run."""

@abstractmethod
def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]:
""" Returns a dictionary of the placeholders' names (the ones this content loader supports)
to the functions used for filling them. The functions must receive as input
batch_samples_inds and batch_samples_elements_inds which defines the current batch,
and return an array per placeholder name according to the receives batch information.
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader
they belong to! Some sort of having a mark :))!"""

def fill_placeholders(self,
keys: List[str],
samples_inds: np.ndarray)\
-> Dict[str, Union[None, np.ndarray]]:
""" Receives as input placeholders which is a dictionary of the placeholders'
names to a torch tensor for filling it to feed the model with and
samples_inds and samples_elements_inds, which contain
information about the current batch. Fills placeholders based on the
function dictionary received in get_placeholder_name_to_fill_function_dict."""

if self._fill_func_by_var_name is None:
self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict()

# Filling all the placeholders in the received dictionary!
placeholders = dict()

for placeholder_name in keys:
if placeholder_name in self._fill_func_by_var_name:
placeholder_v = self._fill_func_by_var_name[placeholder_name](
samples_inds)
if placeholder_v is None:
raise Exception('None value for key %s' % placeholder_name)

placeholders[placeholder_name] = placeholder_v
else:
raise Exception(f'Unknown key for content loader: {placeholder_name}')

return placeholders

def get_batch_true_interpretations(self, samples_inds: np.ndarray,
samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\
-> np.ndarray:
"""
        Receives the indices of the samples (and of their elements, if the samples have
        multiple elements) in the current batch, and returns the ground-truth
        interpretations (a bounding-box map or a boolean mask).
"""
raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.')

def drop_samples(self, drop_mask: np.ndarray) -> None:
"""
Receives a boolean drop mask and drops the samples whose mask are true,
From now on the indices of samples are based on the new samples set (after elimination)
"""
raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.')
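To make the contract concrete, here is a minimal runnable sketch (illustrative only, with a synthetic in-memory dataset) of a ContentLoader subclass and of how fill_placeholders dispatches to the registered getter functions.

class _ToyContentLoader(ContentLoader):
    def __init__(self, conf, data_specification):
        super().__init__(conf, data_specification)
        self._xs = np.random.rand(10, 3, 8, 8)    # ten synthetic samples
        self._ys = np.arange(10) % 2

    def get_samples_names(self):
        return np.array(['sample_%d' % i for i in range(10)])

    def get_samples_labels(self):
        return self._ys

    def get_placeholder_name_to_fill_function_dict(self):
        return {'x': lambda inds: self._xs[inds], 'y': lambda inds: self._ys[inds]}

loader = _ToyContentLoader(conf=None, data_specification='train')
batch = loader.fill_placeholders(['x', 'y'], np.array([0, 2, 4]))
print(batch['x'].shape, batch['y'])   # (3, 3, 8, 8) [0 0 0]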

+ 146
- 0
torchlap/data/content_loaders/imagenet_loader.py View File

@@ -0,0 +1,146 @@
import os
from typing import TYPE_CHECKING, Dict, List, Tuple
import pickle
from PIL import Image
from enum import Enum

import pandas as pd
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
from ...configs.imagenet_configs import ImagenetConfigs


def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose:
if data_specification == 'train':
return transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
if data_specification in ['test', 'val']:
return transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
])
raise ValueError('Unknown data specification: {}'.format(data_specification))


def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]:
cache_path = os.path.join('.cache/', cache_name)
print(cache_path)
if os.path.isfile(cache_path):
print('Loading cached samples from {}'.format(cache_path))
with open(cache_path, 'rb') as f:
return pickle.load(f)
print('Creating cache for {}'.format(cache_name))
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
_, class_to_idx = find_classes(imagenet_dir)
samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS)
with open(cache_path, 'wb') as f:
pickle.dump(samples, f)
return samples


class BBoxField(Enum):
ImageId = 'ImageId'
PredictionString = 'PredictionString'


def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame:
return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv'))


class ImagenetLoader(ContentLoader):

def __init__(self, conf: 'ImagenetConfigs', data_specification: str):
super().__init__(conf, data_specification)
imagenet_dir = os.path.join(conf.data_separation, data_specification)
self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir)
self.__transform = _get_image_transform(data_specification, conf.input_size)
self.__bboxes = load_bbox_df(conf.data_separation, data_specification)
self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {}

def get_samples_names(self):
        ''' Sample names must be unique; here the image paths are used as the names. '''
return np.array([path for path, _ in self.__samples])

def get_samples_labels(self):
return np.array([label for _, label in self.__samples])

def drop_samples(self, drop_mask: np.ndarray) -> None:
self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop]

def get_placeholder_name_to_fill_function_dict(self):
""" Returns a dictionary of the placeholders' names (the ones this content loader supports)
to the functions used for filling them. The functions must receive as input data_loader,
which is an object of class data_loader that contains information about the current batch
(e.g. the indices of the samples, or if the sample has many elements the indices of the chosen
elements) and return an array per placeholder name according to the receives batch information.
IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader
they belong to! Some sort of having a mark :))!"""
return {
'x': self.__get_x,
'y': self.__get_y,
'bbox': self.__get_bbox,
'interpretations': self.__get_bbox,
}

def __get_x(self, samples_inds: np.ndarray)\
-> np.ndarray:
def get_image(index) -> torch.Tensor:
path, _ = self.__samples[index]
sample = default_loader(path)
self.__handle_random_state(index, 'x')
return self.__transform(sample)

return torch.stack(
tuple([get_image(index) for index in samples_inds]),
dim=0)

def __get_y(self, samples_inds: np.ndarray)\
-> np.ndarray:
return self.get_samples_labels()[samples_inds]

def __handle_random_state(self, idx: int, label: str) -> None:
if idx not in self.__rng_states or self.__rng_states[idx][0] == label:
self.__rng_states[idx] = (label, torch.get_rng_state())
else:
torch.set_rng_state(self.__rng_states[idx][1])

def __get_bbox(self, sample_inds: np.ndarray)\
-> np.ndarray:
def extract_bb(prediction_string: str) -> np.ndarray:
splitted = prediction_string.split()
n = len(splitted) // 5
return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\
.reshape((-1, 2, 2))

def make_bb(index) -> torch.Tensor:
path, _ = self.__samples[index]
image_id = os.path.basename(path).split('.')[0]
image_size = np.array(Image.open(path).size)[::-1] # 2
bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value]
if len(bboxes) == 0:
                return torch.zeros(1, self.conf.input_size, self.conf.input_size) * np.nan
bboxes = bboxes.values[0]
bboxes = extract_bb(bboxes)[..., ::-1] / image_size # N 2 2
start_points = bboxes[:, 0] # N 2
end_points = bboxes[:, 1] # N 2
bb_map = generate_bb_map(start_points, end_points, tuple(image_size))
bb_map = Image.fromarray(bb_map)
self.__handle_random_state(index, 'bb')
return self.__transform(bb_map)

return torch.stack([make_bb(i) for i in sample_inds], dim=0)
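For reference, a small runnable sketch of how a PredictionString from the LOC_*_solution.csv files is parsed by extract_bb above (synthetic values): each group of five tokens is a class id followed by x_min, y_min, x_max, y_max, and only the four coordinates are kept, reshaped into (N, 2, 2) start/end points.

prediction_string = 'n01751748 46 35 239 194 n01751748 10 12 90 100'   # two boxes (made-up numbers)
tokens = prediction_string.split()
n = len(tokens) // 5
boxes = np.array([[float(b) for b in tokens[5 * i + 1: 5 * i + 5]] for i in range(n)]).reshape((-1, 2, 2))
print(boxes.shape)   # (2, 2, 2): two boxes, each a (start, end) pair of (x, y) coordinates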

+ 329
- 0
torchlap/data/content_loaders/rsna_loader.py View File

@@ -0,0 +1,329 @@
from typing import Dict, List, TYPE_CHECKING, Optional, Tuple
from os import path, listdir
from sys import stderr
from functools import partial

import cv2
import numpy as np
import pandas as pd
from .content_loader import ContentLoader

from ...utils.bb_generator import generate_bb_map
if TYPE_CHECKING:
from ...configs.rsna_configs import RSNAConfigs


class RSNALoader(ContentLoader):
def __init__(self, conf: 'RSNAConfigs', data_specification: str):
""" read all directories for scans and annotations, which are split by DataSplitter.py
And then, keeps only those samples which are used for 'usage' """

super().__init__(conf, data_specification)
self.conf = conf
self.img_size = conf.input_size
self.infection_map_size = conf.infection_map_size
self.samples_info = None

self.load_samples(data_specification)
self.loaded_images = None

        self._infections_bbs: Optional[Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]] = None
self._load_infections()

def load_samples(self, data_specification):

if ':' in data_specification:
data_specification, filter_str = data_specification.split(':')
else:
filter_str = None

if data_specification in ['train', 'val', 'test']:
self.samples_info = \
pd.read_csv('%s/%s.txt' %
(
self.conf.data_separation, data_specification), sep='\t', header=0)
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):])
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):])

elif path.isfile(data_specification) and path.exists(data_specification):
self.samples_info = pd.read_csv(data_specification, sep='\t', header=0)
self.samples_info['sample'] = self.samples_info['sample'].apply(lambda s: s[len('../'):])
self.samples_info['Interpretation_dir'] = self.samples_info['Interpretation_dir'].apply(lambda s: s[len('../'):])

elif path.isdir(data_specification) and path.exists(data_specification):

samples_names = np.asarray(
[data_specification + '/' + x for x in listdir(data_specification)])
print('%d samples discovered in %s' % (len(samples_names), data_specification))

if filter_str is not None:
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x)
samples_names = samples_names[pass_filter_str(samples_names)]
print('%d samples remained after filtering' % len(samples_names))

self.samples_info = pd.DataFrame({
'label': np.full(len(samples_names), -1, dtype=int),
'sample': samples_names,
                'view': np.full(len(samples_names), 'Unknown', dtype=object),
                'dataset': np.full(len(samples_names), 'unknown', dtype=object),
                'Interpretation_dir': np.full(len(samples_names), '', dtype=object),

})
return

else:
print('Please implement this part!', flush=True)

org_cnt = len(self.samples_info)

# filtering by filter string!
if filter_str is not None:
pass_filter_str = np.vectorize(lambda x: '/%s/' % filter_str in x)
self.samples_info = self.samples_info[pass_filter_str(self.samples_info['sample'].values)]

print(
'%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification),
flush=True)
org_cnt = len(self.samples_info)

# label mapping/ordering!
label_mapping_dict = self.conf.label_map_dict

v_map_label = np.vectorize(lambda l: label_mapping_dict.get(l.lower(), -1))
self.samples_info['label'] = v_map_label(self.samples_info['label'].values)

# filtering unmapped labels
self.samples_info = self.samples_info[self.samples_info['label'].values != -1]
print('%d out of %d samples remained after label mapping!' %
(len(self.samples_info), org_cnt), flush=True)

# counting the labels!
flattened_labels = self.samples_info['label'].values
u_labels = np.unique(flattened_labels)
labels_cnt = np.zeros(len(u_labels))
np.add.at(labels_cnt, np.searchsorted(u_labels, flattened_labels, side='left'), 1)
print('[%s]' % ', '.join(['class-%s: %d' % (str(u_labels[i]), labels_cnt[i]) for i in range(len(labels_cnt))]),
flush=True)

def _load_infections(self) -> None:
"""
If a file address is given as infections_bb_dir in configs
(containing imgs and their bounding boxes as a str)
the information related to bounding boxes will be taken,
otherwise it is expected from the data separation to have a column
indicating the address of the infection mask
"""
if self.conf.infections_bb_dir is not None:
bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0)

imgs_dirs = np.vectorize(self._get_id_by_path)(bbs_info['img_dir'].values)
imgs_bbs = [
[] if '-' not in str(im_bbs) else [
tuple([np.clip(float(x), 0, 1) for x in im_bb.split('-')])
for im_bb in im_bbs.split(',')
]
for im_bbs in bbs_info['img_bb'].values]

# reformatting to tuple of list of starts and list of ends
imgs_bbs_starts = [[
(r1, c1) for r1, c1, _, _ in img_bbs]
for img_bbs in imgs_bbs]

imgs_bbs_ends = [[
(r2, c2) for _, _, r2, c2 in img_bbs]
for img_bbs in imgs_bbs]

imgs_bbs = [
(imgs_bbs_starts[i], imgs_bbs_ends[i])
for i in range(len(imgs_bbs_starts))]

self._infections_bbs = dict(zip(imgs_dirs, imgs_bbs))

def get_samples_names(self):
return self.samples_info['sample'].values

def get_samples_labels(self):
return self.samples_info['label'].values

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """ Drops the samples whose entries in the boolean drop_mask are True. """
self.samples_info = self.samples_info[np.logical_not(drop_mask)]

def get_placeholder_name_to_fill_function_dict(self):
ret_val = {
'x': self.get_batch_scaled_images,

'y': self.get_batch_label,

'infection': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if
self._infections_bbs is not None else
self.get_batch_bbs_map_from_mask_file,

'interpretations': partial(self.get_batch_extended_bbs_map_from_bbs_file, 0) if
self._infections_bbs is not None else
self.get_batch_bbs_map_from_mask_file,

'has_bb': self.get_batch_has_bb_mask_from_bb_file if
self._infections_bbs is not None else
self.get_batch_has_bb_mask_from_mask_file
}

        # if receptive field radii are given, bounding boxes must have been provided as a file
if len(self.conf.receptive_field_radii) > 0:
assert self._infections_bbs is not None, \
"When having receptive field radii, " \
"you should provide bounding boxes as a file in config.infections_bb_dir"

ret_val.update({
f'ex_infection_{r}': partial(self.get_batch_extended_bbs_map_from_bbs_file, r)
for r in self.conf.receptive_field_radii
})

return ret_val

def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \
-> np.ndarray:

assert 'Interpretation_dir' in self.samples_info.columns, \
            'If you have not specified infections_bb_dir in your config, ' \
            'you need to have an Interpretation_dir column in your data separation'

def has_inf(si):
if self.samples_info.iloc[si]['label'] == 0:
return True
int_dir = self.samples_info.iloc[si]['Interpretation_dir']
return str(int_dir) != 'nan'

return np.vectorize(has_inf)(samples_inds)

def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \
-> np.ndarray:

def has_inf(si):
if self.samples_info.iloc[si]['label'] == 0:
return True
sample_key = self.samples_info.iloc[si]['sample']
sample_key = self._get_id_by_path(sample_key)
return sample_key in self._infections_bbs and \
len(self._infections_bbs[sample_key][0]) > 0

return np.vectorize(has_inf)(samples_inds)

def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray)\
-> np.ndarray:

def read_interpretation(im_ind):

# for healthy, return full 0
if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)

int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir']
# if it does not exist, return full -1
if str(int_dir) == 'nan':
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)

if 'npy' in int_dir:
interpretation = np.load(int_dir)
else:
interpretation = np.load(int_dir)['arr_0']

interpretation = interpretation.astype(float)

if interpretation.shape != (self.infection_map_size, self.infection_map_size):
interpretation = (cv2.resize(np.round(255 * interpretation, 0),
dsize=(self.infection_map_size, self.infection_map_size)) >= 128).astype(float)
return interpretation[np.newaxis, :, :]

batch_interpretations = np.stack(tuple([read_interpretation(si)
for si in samples_inds]), axis=0)

return batch_interpretations

@staticmethod
def _get_id_by_path(path: str) -> str:
return path[path.index('png_images/'):]

def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray:

def make_map(im_ind):

if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.infection_map_size, self.infection_map_size), 0, dtype=float)

sample_key = self.samples_info.iloc[im_ind]['sample']
sample_key = self._get_id_by_path(sample_key)

if sample_key in self._infections_bbs and \
len(self._infections_bbs[sample_key][0]) > 0:
bbs_info = self._infections_bbs[sample_key]
start_points, end_points = bbs_info
mask = generate_bb_map(start_points, end_points,
(self.infection_map_size, self.infection_map_size), radius)
return mask[np.newaxis, :, :]
else:
                return np.full((1, self.infection_map_size, self.infection_map_size), -1, dtype=float)

batch_interpretations = np.stack(tuple([make_map(si)
for si in samples_inds]), axis=0)

return batch_interpretations

def get_batch_scaled_images(self,
samples_inds: np.ndarray) \
-> np.ndarray:

def read_img(im_ind):

im = cv2.imread(self.samples_info.iloc[im_ind]['sample'])
if im is None:
print(self.samples_info.iloc[im_ind]['sample'] + ' is missing!', file=stderr)
raise Exception('Missing image')

if len(list(im.shape)) == 3:
ret_val = im[:, :, 0]
else:
ret_val = im

if ret_val.shape != (self.img_size, self.img_size):
ret_val = cv2.resize(ret_val, dsize=(self.img_size, self.img_size))
return ret_val[np.newaxis, :, :]

batch_imgs = np.stack(tuple([read_img(si)
for si in samples_inds]), axis=0)
batch_imgs = batch_imgs.astype(np.float32) / 255
return batch_imgs

def get_batch_label(self,
samples_inds: np.ndarray,
) \
-> np.ndarray:
return self.samples_info['label'].iloc[samples_inds].values.astype(np.float32)

def get_batch_true_interpretations(self, samples_inds: np.ndarray) \
-> np.ndarray:

def read_interpretation(im_ind):

# for healthy, return full 0
if self.samples_info.iloc[im_ind]['label'] == 0:
                return np.full((1, self.img_size, self.img_size), 0, dtype=float)

int_dir = self.samples_info.iloc[im_ind]['Interpretation_dir']
if str(int_dir) == 'nan':
                return np.full((1, self.img_size, self.img_size), np.nan, dtype=float)

if 'npy' in int_dir:
interpretation = np.load(int_dir)
else:
interpretation = np.load(int_dir)['arr_0'].astype(int)

if interpretation.shape != (self.img_size, self.img_size):
interpretation = (cv2.resize(np.round(255 * interpretation, 0).astype(np.uint8),
dsize=(self.img_size, self.img_size)) >= 128).astype(float)
return interpretation[np.newaxis, :, :]

batch_interpretations = np.stack(tuple([read_interpretation(si)
for si in samples_inds]), axis=0)

return batch_interpretations
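For reference, a runnable sketch (synthetic values) of the bounding-box string format consumed by _load_infections above: every image has a comma-separated list of boxes, each box being four dash-separated normalized coordinates r1-c1-r2-c2, which are split into a list of start points and a list of end points before being passed to generate_bb_map.

img_bb = '0.10-0.20-0.40-0.55,0.50-0.60-0.70-0.90'                              # two boxes for one image
boxes = [tuple(float(x) for x in bb.split('-')) for bb in img_bb.split(',')]    # the real code also clips to [0, 1]
starts = [(r1, c1) for r1, c1, _, _ in boxes]
ends = [(r2, c2) for _, _, r2, c2 in boxes]
print(starts, ends)   # [(0.1, 0.2), (0.5, 0.6)] [(0.4, 0.55), (0.7, 0.9)]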

+ 246
- 0
torchlap/data/data_loader.py View File

@@ -0,0 +1,246 @@
""" A class for loading the whole data needed for the run! """
from typing import Dict, List, Union, TYPE_CHECKING
import random
import numpy as np
import torch

from .batch_choosing.batch_chooser import BatchChooser
from .content_loaders.content_loader import ContentLoader
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class RunType:
TRAIN = 'train'
VAL = 'val'
TEST = 'test'


class DataLoader():
""" A class for loading the whole data needed for the run! """

def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: RunType):
""" Conf is a dictionary containing configurations,
sample specification is a string specifying the samples e.g. address of the CTs
run type is one of the strings train/val/test specifying the mode that the data_loader
will be used."""

self.conf = conf
self._device = conf.device
self.data_specification = data_specification
self.run_type = run_type

self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification)

self.samples_names: np.ndarray = self._content_loader.get_samples_names()
self.samples_labels = self._content_loader.get_samples_labels()

keep_mask = self.get_samples_keep_mask()
if keep_mask is not None:
drop_mask = np.logical_not(keep_mask)
self._content_loader.drop_samples(drop_mask)
self.samples_names = self.samples_names[keep_mask]
self.samples_labels = self.samples_labels[keep_mask]

self.class_samples_indices = dict()
print('%d samples' % len(self.samples_names), flush=True)
for i in range(len(self.samples_names)):
if self.samples_labels[i] not in self.class_samples_indices:
self.class_samples_indices[self.samples_labels[i]] = []
self.class_samples_indices[self.samples_labels[i]].append(i)

for c in self.class_samples_indices.keys():
self.class_samples_indices[c] = np.asarray(self.class_samples_indices[c])

        # for augmentations
# Only do augmentations in the training phase
self._different_augmentation_per_batch = conf.different_augmentation_per_batch
self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None
if run_type == RunType.TRAIN:
self._augmentations_dict = conf.augmentations_dict

# for preventing reprocessing
self._processed_data_dictionary: Dict[str, torch.Tensor] = dict()

# for preserving same augmentations in one batch
self._transformations_seeds_dict: Dict[torch.nn.Module, Union[int, np.ndarray]] = dict()

if run_type == RunType.TRAIN:
self._batch_chooser: BatchChooser = self.conf.train_batch_chooser_cls(self, conf.batch_size)
else:
self._batch_chooser: BatchChooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size)

def get_samples_names(self) -> np.ndarray:
""" Returns the names of the samples based on the first content loader.
IMPORTANT: These all must be the same for all of the content loaders."""
return self.samples_names

def get_samples_labels(self):
""" Returns the labels of the samples in a numpy array. """
return self.samples_labels

def get_number_of_samples(self):
""" Returns the number of samples loaded. """
return len(self.samples_names)

def get_class_sample_indices(self):
""" Returns a dictionary, containing lists of samples indices belonging to each class label."""
return self.class_samples_indices

def prepare_next_batch(self):
# Resetting information
self._processed_data_dictionary = dict()
self._transformations_seeds_dict = dict()
self._batch_chooser.prepare_next_batch()

def finished_iteration(self) -> bool:
return self._batch_chooser.finished_iteration()

def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None:
""" Receives as input a dictionary of placeholders and fills them using all the content loaders."""

missed_placeholders_names: List[str]
if len(self._processed_data_dictionary) == 0:
missed_placeholders_names = list(placeholders_dict.keys())
else:
missed_placeholders_names = []
for k, v in placeholders_dict.items():
if k in self._processed_data_dictionary:
placeholders_dict[k] = self._processed_data_dictionary[k]
else:
missed_placeholders_names.append(k)

if len(missed_placeholders_names) == 0:
return

new_batch_info = self._content_loader.fill_placeholders(
missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices())

# filling the unfilled ones
if len(missed_placeholders_names) > 0:

# filling all missed keys!
with torch.no_grad():
for k in missed_placeholders_names:

self._fill_placeholder(placeholders_dict[k], new_batch_info[k])

if self._augmentations_dict is not None and k in self._augmentations_dict:
placeholders_dict[k] = self._apply_augmentation(k, placeholders_dict[k])

# Keeping a copy of data
self._processed_data_dictionary[k] = placeholders_dict[k]

def get_current_batch_data(self, keyword: str) -> torch.Tensor:
""" Builds placeholder and retrieves the requirement from loaders and returns it!

Args:
keyword (str): The keyword to load for the current batch.

Returns:
torch.Tensor: The information related to the keyword for the current batch.
"""
if keyword in self._processed_data_dictionary:
return self._processed_data_dictionary[keyword]
else:
placeholders_dict = {keyword: create_empty_placeholder(self._device)}
self.fill_placeholders(placeholders_dict)
return placeholders_dict[keyword]

def get_current_batch_sample_indices(self) -> np.ndarray:
""" Returns a list of indices of the samples chosen for the current batch. """
return self._batch_chooser.get_current_batch_sample_indices()

def get_max_class_samples_num(self):
return max([len(x) for x in self.class_samples_indices.values()])

def get_classes_num(self):
return len(self.class_samples_indices)

def get_samples_keep_mask(self) -> Union[None, np.ndarray]:
if self.conf.mapped_labels_to_use is None:
return None

existing_dict = dict(zip(self.conf.mapped_labels_to_use, self.conf.mapped_labels_to_use))
keep_mask = np.vectorize(lambda x: x in existing_dict)(self.samples_labels)

print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.')

return keep_mask

def get_current_batch_size(self) -> int:
return self._batch_chooser.get_current_batch_size()

def get_current_batch_samples_names(self) -> np.ndarray:
return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()]

def get_current_batch_samples_labels(self) -> np.ndarray:
return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()]

def get_current_batch_samples_interpretations(self) -> np.ndarray:
return self.get_current_batch_data('interpretations')

@staticmethod
def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray):
""" Fills the torch placeholder with the given numpy value,
If shape mismatches, resizes the placeholder so the data would fit. """

# Resize if the shape mismatches
if list(placeholder.shape) != list(val.shape):
placeholder.resize_(*tuple(val.shape))

# feeding the value
placeholder.copy_(torch.Tensor(val))

def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor:
""" This function would be called for the variables we are sure they need augmentation
(they are presented in augmentation dictionary), applies the specified augmentation,
makes sure all augmentations be the same on the same batch elements,
and returns the augmented data"""

def run_single_aug(aug, inp):
if type(aug) in self._transformations_seeds_dict:
seed = self._transformations_seeds_dict[type(aug)]
else:
# setting a seed
seed = np.random.randint(2147483647) # make a seed with numpy generator
self._transformations_seeds_dict[type(aug)] = seed

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if self._different_augmentation_per_batch:
ret_val = torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0)
else:
ret_val = aug(inp)
return ret_val

field_transformation = self._augmentations_dict[var_name]

if not isinstance(field_transformation, torch.nn.Sequential):
return run_single_aug(field_transformation, var_val)

else:
aug_val = var_val
for single_transform in field_transformation.children():
aug_val = run_single_aug(
single_transform,
aug_val)

return aug_val

def create_empty_placeholder(device: torch.device) -> torch.Tensor:
""" Create empty placeholder with shape (1) in the given device.

Args:
device (torch.device): Device to create the placeholder on.

Returns:
torch.Tensor: Empty placeholder.
"""
return torch.zeros(1, dtype=torch.float32,
device=device,
requires_grad=False)
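A runnable sketch (CPU tensors, synthetic batch) of the in-place trick used by _fill_placeholder above: the (1,)-shaped tensor created by create_empty_placeholder is resized to the shape of whatever batch array arrives and then overwritten, so the same placeholder object can be reused across batches.

placeholder = torch.zeros(1, dtype=torch.float32)          # as created by create_empty_placeholder (CPU)
batch = np.random.rand(8, 3, 224, 224).astype(np.float32)
with torch.no_grad():
    if list(placeholder.shape) != list(batch.shape):
        placeholder.resize_(*tuple(batch.shape))
    placeholder.copy_(torch.Tensor(batch))
print(placeholder.shape)                                   # torch.Size([8, 3, 224, 224])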

+ 98
- 0
torchlap/data/dataflow.py View File

@@ -0,0 +1,98 @@
"""
Runs flow of data, from loading to forwarding the model.
"""
from typing import Iterator, Union, TypeVar, Generic, Dict

import torch

from ..data.data_loader import DataLoader, create_empty_placeholder
from ..data.dataloader_context import DataloaderContext
from ..models.model import Model


TReturn = TypeVar('TReturn')


class DataFlow(Generic[TReturn]):
"""
Runs flow of data, from loading to forwarding the model.
"""

def __init__(self, model: 'Model', dataloader: 'DataLoader',
device: torch.device, print_debug_info:bool = False):
self._model = model
self._device = device
self._dataloader = dataloader
self._placeholders: Union[None, Dict[str, torch.Tensor]] = None
self._print_debug_info = print_debug_info

def iterate(self) -> Iterator[TReturn]:
"""
Iterates on data and forwards the model
"""

if self._placeholders is None:
raise Exception("Please use `with` statement before calling iterate method:\n" +
"with dataflow:\n" +
" for model_output in dataflow.iterate():\n" +
" pass\n")

DataloaderContext.instance.dataloader = self._dataloader

while True:
self._dataloader.prepare_next_batch()

if self._print_debug_info:
print('> Next Batch:')
print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names()))

# check if the iteration is done
if self._dataloader.finished_iteration():
break

# Filling the placeholders for running the model
self._dataloader.fill_placeholders(self._placeholders)

# Running the model
yield self._final_stage(self._placeholders)

def __enter__(self):
# self._dataloader.reset()
self._placeholders = self._build_placeholders()

def __exit__(self, exc_type, exc_value, traceback):
for name in self._placeholders:
self._placeholders[name] = None
self._placeholders = None

def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn:
return self._model(**model_input)

def _build_placeholders(self) -> Dict[str, torch.Tensor]:
""" Creates placeholders for feeding the model.
Returns a dictionary containing placeholders.
The placeholders are created based on the names
of the input variables of the model's forward method.
The variables initialized as None are assumed to be
labels used for training phase only, and won't be
added in phases other than train."""

model_forward_args, args_default_values, additional_kwargs = self._model.get_forward_required_kws_kwargs_and_defaults()

# Checking which args have None as default
args_default_values = ['Mandatory' for _ in
range(len(model_forward_args) - len(args_default_values))] + \
args_default_values

optional_args = [x is None for x in args_default_values]

placeholders = dict()

for argname in model_forward_args:
placeholders[argname] = create_empty_placeholder(self._device)

for argname in additional_kwargs:
placeholders[argname] = create_empty_placeholder(self._device)

return placeholders
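Typical usage, as also hinted by the error message in iterate() (model, dataloader and config construction omitted):

#     dataflow = DataFlow(model, dataloader, conf.device)
#     with dataflow:
#         for model_output in dataflow.iterate():
#             ...  # consume the output of the current batch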

+ 28
- 0
torchlap/data/dataloader_context.py View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .data_loader import DataLoader


class classproperty(property):
def __get__(self, obj, objtype=None):
return super(classproperty, self).__get__(objtype)

def __set__(self, obj, value):
super(classproperty, self).__set__(type(obj), value)

def __delete__(self, obj):
super(classproperty, self).__delete__(type(obj))


class DataloaderContext:
_instance: 'DataloaderContext' = None

def __init__(self) -> None:
self.dataloader: 'DataLoader' = None

@classproperty
def instance(cls) -> 'DataloaderContext':
if cls._instance is None:
cls._instance = DataloaderContext()
return cls._instance
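A short runnable sketch of the singleton contract: the component that sets DataloaderContext.instance.dataloader (DataFlow does this before iterating) makes it visible to any other module, e.g. the weakly supervised losses, without passing the loader around explicitly.

ctx = DataloaderContext.instance
ctx.dataloader = object()                    # stand-in for a DataLoader instance
assert DataloaderContext.instance is ctx     # the same singleton object is returned everywhere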

+ 0
- 0
torchlap/experiments/__init__.py View File


+ 10
- 0
torchlap/experiments/_entry_loader.py View File

@@ -0,0 +1,10 @@
from typing import Type
from .entrypoint import BaseEntrypoint
from ..configs.base_config import PhaseType
from importlib import import_module
import sys


def get_entry(phase_type: PhaseType):
EntryPoint: Type[BaseEntrypoint] = import_module(f'torchlap.experiments.{sys.argv[1]}').EntryPoint
return EntryPoint(phase_type)
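For example, with sys.argv[1] set to 'celeba.ws_lap_resnet', get_entry imports torchlap.experiments.celeba.ws_lap_resnet and instantiates its EntryPoint with the requested phase type.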

+ 68
- 0
torchlap/experiments/_model_loading.py View File

@@ -0,0 +1,68 @@
from os import path
from typing import TYPE_CHECKING
import torch
from ..models.model import Model
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


def load_model(the_model: Model, conf: 'BaseConfig') -> Model:

dev_name = conf.dev_name

if conf.final_model_dir is not None:
load_dir = conf.final_model_dir
if not path.exists(load_dir):
raise Exception('No path %s' % load_dir)
else:
load_dir = conf.save_dir

print('>>> loading from: ' + load_dir, flush=True)
if not path.exists(load_dir):
raise Exception('Problem in loading the model. %s does not exist!' % load_dir)

map_location = None
if dev_name == 'cpu':
map_location = 'cpu'
elif dev_name is not None and ':' in dev_name:
map_location = dev_name

the_model.to(conf.device)

if path.isfile(load_dir):

if not path.exists(load_dir):
raise Exception('Problem in loading the model. %s does not exist!' % load_dir)

the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
print('Loaded the model at %s' % load_dir, flush=True)

else:
if conf.epoch is not None:
epoch = int(conf.epoch)
elif path.exists(load_dir + '/GeneralInfo'):
checkpoint = torch.load(load_dir + '/GeneralInfo', map_location=map_location)
epoch = checkpoint['best_val_epoch']
print(f'Epoch has not been specified in the config, '
f'using {epoch} which is best_val_epoch instead')
else:
raise Exception('Either epoch or pt dir (final_model) must be given as GeneralInfo was not found')

if not path.exists('%s/%d' % (load_dir, epoch)):
raise Exception('Problem in loading the model. %s/%d does not exist!' %
(load_dir, epoch))

checkpoint = torch.load('%s/%d' % (load_dir, epoch),
map_location=map_location)

# backward compatibility
if 'model_state_dict' in checkpoint:
checkpoint = checkpoint['model_state_dict']

the_model.load_state_dict(checkpoint)

print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True)

the_model.to(conf.device)

return the_model
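A brief usage sketch (hypothetical model and config values): load_model prefers conf.final_model_dir when it is set and falls back to conf.save_dir; for a checkpoint directory it loads the checkpoint of conf.epoch when given, and otherwise the best_val_epoch recorded in the GeneralInfo file.

#     conf.final_model_dir = None          # fall back to conf.save_dir
#     conf.epoch = None                    # fall back to best_val_epoch from GeneralInfo
#     model = load_model(MyModel(), conf)  # MyModel is a placeholder for any Model subclass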

+ 0
- 0
torchlap/experiments/celeba/__init__.py View File


+ 20
- 0
torchlap/experiments/celeba/org_inception.py View File

@@ -0,0 +1,20 @@
from typing import Tuple
from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.single_tag_org_inception import CelebAORGInception


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
conf = CelebAConfigs('CelebA_Inception_Org', 8, 299, self.phase_type)
conf.max_epochs = 12
conf.main_tag = CelebATag.SmilingTag
conf.tags = [conf.main_tag.name]

model = CelebAORGInception(conf.main_tag.name, 0.4)

return conf, model


+ 20
- 0
torchlap/experiments/celeba/org_resnet.py View File

@@ -0,0 +1,20 @@
from typing import Tuple
from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.org_resnet import CelebAORGResNet18


class EntryPoint(BaseEntrypoint):

def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
conf = CelebAConfigs('CelebA_ResNet_Org', 9, 224, self.phase_type)
conf.max_epochs = 12
conf.main_tag = CelebATag.SmilingTag
conf.tags = [conf.main_tag.name]

model = CelebAORGResNet18(conf.main_tag.name)

return conf, model


+ 143
- 0
torchlap/experiments/celeba/ws_lap_inception.py View File

@@ -0,0 +1,143 @@
from typing import Tuple
from collections import OrderedDict
from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.lap_inception import CelebALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
return LAP(
channel, 2, 2, hidden_channels=[8], n_attention_heads=3,
sigmoid_scale=0.1, discriminative_attention=True)


def adaptive_lap_factory(channel):
return AdaptiveLAP(channel, [], 0.1,
n_attention_heads=3, discriminative_attention=True)


class EntryPoint(BaseEntrypoint):

def __init__(self, phase_type) -> None:
self.active_min_ratio: float = 0.02
self.active_max_ratio: float = 0.4
self.inactive_ratio: float = 0.01
self.common_max_ratio = 0.02

        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type)
conf.max_epochs = 12
conf.main_tag = CelebATag.SmilingTag
conf.tags = [conf.main_tag.name]

model = CelebALAPInception(conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory)

# aux loss for free head

fws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.025 / 3, 0,
(0, self.common_max_ratio),
{2: [0, 1]}, discr_score_layer,
w_attention_in_ordering=1, w_discr_in_ordering=0)
for title, att_score_layer, discr_score_layer in [
('FPool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
('FPool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
('FPool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
('FPoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for fws_loss in fws_losses:
fws_loss.configure(conf)

# weakly supervised losses based on discriminative head and attention head

ws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.025 * 2 / 3,
self.inactive_ratio,
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
discr_score_layer,
w_attention_in_ordering=0.2, w_discr_in_ordering=1)
for title, att_score_layer, discr_score_layer in [
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for ws_loss in ws_losses:
ws_loss.configure(conf)

# concordance loss for attention

concordance_loss = PoolConcordanceLossCalculator(
'AC', model, OrderedDict([
('att2', model.maxpool2.attention_layer),
('att6', model.Mixed_6a.pool.attention_layer),
('att7', model.Mixed_7a.pool.attention_layer),
('attG', model.avgpool[0].attention_layer),
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
labels_by_channel={0: [0], 1: [1]})
concordance_loss.configure(conf)

# concordance loss for discrimination head

concordance_loss2 = PoolConcordanceLossCalculator(
'DC', model, OrderedDict([
('D-att2', model.maxpool2.discrimination_layer),
('D-att6', model.Mixed_6a.pool.discrimination_layer),
('D-att7', model.Mixed_7a.pool.discrimination_layer),
('D-attG', model.avgpool[0].discrimination_layer),
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
labels_by_channel={0: [0], 1: [1]})
concordance_loss2.configure(conf)


conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.maxpool2.attention_layer,
foretell_pool6=model.Mixed_6a.pool.attention_layer,
foretell_pool7=model.Mixed_7a.pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(conf)


output_modifier = OutputModifier(model,
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
'foretell_pool2',
'foretell_pool6',
'foretell_pool7',
'foretell_avgpool',
)
output_modifier.configure(conf)

return conf, model


+ 131
- 0
torchlap/experiments/celeba/ws_lap_resnet.py View File

@@ -0,0 +1,131 @@
from typing import Tuple
from collections import OrderedDict

from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.lap_resnet import CelebALAPResNet18
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


class EntryPoint(BaseEntrypoint):

def __init__(self, phase_type) -> None:
self.active_min_ratio: float = 0.02
self.active_max_ratio: float = 0.4
self.inactive_ratio: float = 0.01
self.common_max_ratio = 0.02

        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
conf = CelebAConfigs('CelebA_ResNet_WS', 11, 224, self.phase_type)
conf.max_epochs = 12
conf.main_tag = CelebATag.SmilingTag
conf.tags = [conf.main_tag.name]

model = CelebALAPResNet18(conf.main_tag.name, sigmoid_scale=0.1)

# aux loss for free head

fws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.025 / 3, 0,
(0, self.common_max_ratio),
{2: [0, 1]}, discr_score_layer,
w_attention_in_ordering=1, w_discr_in_ordering=0)
for title, att_score_layer, discr_score_layer in [
('Fatt2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
('Fatt3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
('Fatt4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
('Fatt5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for fws_loss in fws_losses:
fws_loss.configure(conf)

# weakly supervised losses based on discriminative head and attention head

ws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.025 * 2 / 3,
self.inactive_ratio,
(self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
discr_score_layer,
w_attention_in_ordering=0.2, w_discr_in_ordering=1)
for title, att_score_layer, discr_score_layer in [
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for ws_loss in ws_losses:
ws_loss.configure(conf)

# concordance loss for attention

concordance_loss = PoolConcordanceLossCalculator(
'AC', model, OrderedDict([
('att2', model.layer2[0].pool.attention_layer),
('att3', model.layer3[0].pool.attention_layer),
('att4', model.layer4[0].pool.attention_layer),
('att5', model.avgpool[0].attention_layer),
]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
labels_by_channel={0: [0], 1: [1]})
concordance_loss.configure(conf)

# concordance loss for discrimination head

concordance_loss2 = PoolConcordanceLossCalculator(
'DC', model, OrderedDict([
('att2', model.layer2[0].pool.discrimination_layer),
('att3', model.layer3[0].pool.discrimination_layer),
('att4', model.layer4[0].pool.discrimination_layer),
('att5', model.avgpool[0].discrimination_layer),
]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
labels_by_channel={0: [0], 1: [1]})
concordance_loss2.configure(conf)


conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.layer2[0].pool.attention_layer,
foretell_pool3=model.layer3[0].pool.attention_layer,
foretell_pool4=model.layer4[0].pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(conf)


output_modifier = OutputModifier(model,
lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
'foretell_pool2',
'foretell_pool3',
'foretell_pool4',
'foretell_avgpool',
)
output_modifier.configure(conf)

return conf, model


+ 113
- 0
torchlap/experiments/entrypoint.py View File

@@ -0,0 +1,113 @@
import sys
from typing import Tuple
import argparse
import os
from abc import ABC, abstractmethod
from ..models.model import Model
from ..configs.base_config import BaseConfig, PhaseType


class BaseEntrypoint(ABC):
"""Base class for all entrypoints.
"""
description = ''

def __init__(self, phase_type: PhaseType) -> None:
super().__init__()
self.phase_type = phase_type

self.conf, self.model = self._get_conf_model()
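        # Expected argv layout: <command> <entrypoint_name> [--options ...]; the first two tokens name
        # the program and the experiment entrypoint, the remainder is parsed by the parser created below.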
self.parser = self._create_parser(sys.argv[0], sys.argv[1])
self.conf.update(self.parser.parse_args(sys.argv[2:]))

def _create_parser(self, command: str, entrypoint_name: str) -> argparse.ArgumentParser:
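        """Builds the shared command-line parser; `command` and `entrypoint_name` are only used to
        compose the program name shown in --help."""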
parser = argparse.ArgumentParser(
prog=f'{os.path.basename(command)} {entrypoint_name}',
description=self.description or None
)
parser.add_argument('--device', default='cpu', type=str,
                            help='The device to run the analysis on: cpu, cuda:#, or cuda (for all GPUs).')

parser.add_argument('--samples-dir', default=None, type=str,
                            help='The directory or data group to be evaluated. A data group can be train/test/val; a directory must contain 0 and 1 subdirectories. Use :FilterName to run only over samples whose path contains /FilterName/.')

parser.add_argument('--final-model-dir', default=None, type=str,
                            help='The directory to load the model from; when not given, it will be calculated automatically.')

parser.add_argument('--save-dir', default=None, type=str,
                            help='The directory to save the model to; when not given, it will be calculated automatically.')

parser.add_argument('--report-dir', default=None, type=str,
                            help='The directory to save per-slice, per-sample reports in.')

parser.add_argument('--epoch', default=None, type=str,
help='The epoch to load.')

parser.add_argument('--try-name', default=None, type=str,
                            help='The run name describing what this run does.')

parser.add_argument('--try-num', default=None, type=int,
help='The try number to load')

parser.add_argument('--data-separation', default=None, type=str,
help='The data_separation to be used.')

parser.add_argument('--batch-size', default=None, type=int,
help='The batch size to be used.')

parser.add_argument('--pretrained-model-file', default=None, type=str,
help='Address of .pt pretrained model')

        parser.add_argument('--max-epochs', default=None, type=int,
                            help='The maximum number of training epochs.')

        parser.add_argument('--big-batch-size', default=None, type=int,
                            help='The big batch size (iterations per optimization step).')

        parser.add_argument('--iters-per-epoch', default=None, type=int,
                            help='The number of big batches per epoch.')

        parser.add_argument('--interpretation-method', default=None, type=str,
                            help='The method used for interpreting the results.')

        parser.add_argument('--cut-threshold', default=None, type=float,
                            help='The threshold for cutting interpretations.')

        parser.add_argument('--global-threshold', action='store_true',
                            help='Whether the given cut threshold is applied globally or to the relative values.')

        parser.add_argument('--dynamic-threshold', action='store_true',
                            help='Whether to use a dynamic threshold in interpretation.')

        parser.add_argument('--class-label-for-interpretation', default=None, type=int,
                            help='The class label to explain. When not given, the decision of the model is explained.')

        parser.add_argument('--interpret-predictions-vs-gt', default='1', type=(lambda x: x == '1'),
                            help='Only considered when --class-label-for-interpretation is not given. If 1, interpretations are computed for the predicted label, otherwise for the ground-truth label.')

parser.add_argument('--mapped-labels-to-use', default=None, type=(lambda x: [int(y) for y in x.split(',')]),
help='The labels to do the analysis on them only (comma separated), default is all the labels.')

parser.add_argument('--skip-overlay', action='store_true', default=None,
help='Passing this flag prevents the interpretation phase from storing overlay images.')

parser.add_argument('--skip-raw', action='store_true', default=None,
help='Passing this flag prevents the interpretation phase from storing the raw interpretation values.')

parser.add_argument('--overlay-only', action='store_true',
                            help='Passing this flag makes the interpretation phase store just the overlay images '
'and not the `.npy` files.')

parser.add_argument('--save-by-file-name', action='store_true',
help='Saves sample-specific files by their file names only, not the whole path.')

parser.add_argument('--n-interpretation-samples', default=None, type=int,
help='The max number of samples to be interpreted.')
parser.add_argument('--interpretation-tag-to-evaluate', default=None, type=str,
                            help='The interpretation tag to be used in the evaluation phase; otherwise the first one is used.')
return parser

@abstractmethod
def _get_conf_model(self) -> Tuple[BaseConfig, Model]:
pass

+ 0
- 0
torchlap/experiments/imagenet/__init__.py View File


+ 17
- 0
torchlap/experiments/imagenet/lap_resnet50_ft.py View File

@@ -0,0 +1,17 @@
from typing import Tuple

from torchlap.configs.imagenet_configs import ImagenetConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]:
config = ImagenetConfigs('ImagenetFT', 2, 224, self.phase_type)
model = ImagenetLAPResNet50(sigmoid_scale=0.1)
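        # Fine-tuning setup: the regex matches (and thus freezes) every parameter whose name does not
        # start with 'layer4.' or 'fc.', so only the last residual block and the classifier are trained.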
config.freezing_regexes = [
r'(?!^layer4\..*$)(?!^fc\..*$)(^.*$)'
]
return config, model

+ 17
- 0
torchlap/experiments/imagenet/lap_resnet50_nft.py View File

@@ -0,0 +1,17 @@
from typing import Tuple

from torchlap.configs.imagenet_configs import ImagenetConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]:
config = ImagenetConfigs('ImagenetNFT', 1, 224, self.phase_type)
model = ImagenetLAPResNet50(sigmoid_scale=0.1)
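        # No-fine-tuning setup: the regex matches (and thus freezes) every parameter whose name does not
        # contain '.pool.', so only the LAP pooling modules are trained.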
config.freezing_regexes = [
r'(?!^.*\.pool\..*$)(^.*$)'
]
return config, model

+ 0
- 0
torchlap/experiments/rsna/__init__.py View File


+ 77
- 0
torchlap/experiments/rsna/bb_lap_inception.py View File

@@ -0,0 +1,77 @@
from collections import OrderedDict
from typing import Tuple
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_inception import RSNALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.bb_supervised import BBLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
return LAP(
channel, 2, 2, hidden_channels=[8],
sigmoid_scale=0.1, discriminative_attention=False)


def adaptive_lap_factory(channel):
return AdaptiveLAP(channel, [], 0.1, discriminative_attention=False)


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type)
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

bb_losses = BBLoss(
'BB', model, {
'Pool2': model.maxpool2.attention_layer,
'Pool6': model.Mixed_6a.pool.attention_layer,
'Pool7': model.Mixed_7a.pool.attention_layer,
'PoolG': model.avgpool[0].attention_layer,
}
, 1,
'interpretations', 1, 0.5)
bb_losses.configure(conf)

conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.maxpool2.attention_layer,
foretell_pool6=model.Mixed_6a.pool.attention_layer,
foretell_pool7=model.Mixed_7a.pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(conf)


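        # Reduce each foretold attention map to its per-sample maximum activation value.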
output_modifier = OutputModifier(model,
lambda x: x.flatten(1).max(dim=-1).values,
'foretell_pool2',
'foretell_pool6',
'foretell_pool7',
'foretell_avgpool',
)
output_modifier.configure(conf)

return conf, model


+ 69
- 0
torchlap/experiments/rsna/bb_lap_resnet.py View File

@@ -0,0 +1,69 @@
from collections import OrderedDict
from typing import Tuple

from ...utils.output_modifier import OutputModifier
from ...utils.aux_output import AuxOutput
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...criteria.bb_supervised import BBLoss
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
config = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type)
model = RSNALAPResNet18(
pool_factory=get_pool_factory(discriminative_attention=False),
adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
sigmoid_scale=0.1,
)

bb_losses = BBLoss(
'BB', model, {
'att2': model.layer2[0].pool.attention_layer,
'att3': model.layer3[0].pool.attention_layer,
'att4': model.layer4[0].pool.attention_layer,
'att5': model.avgpool[0].attention_layer,
}
, 1,
'interpretations', 1, 0.5)
bb_losses.configure(config)

config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.layer2[0].pool.attention_layer,
foretell_pool3=model.layer3[0].pool.attention_layer,
foretell_pool4=model.layer4[0].pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(config)

output_modifier = OutputModifier(model,
lambda x: x.flatten(1).max(dim=-1).values,
'foretell_pool2',
'foretell_pool3',
'foretell_pool4',
'foretell_avgpool',
)
output_modifier.configure(config)


return config, model

+ 14
- 0
torchlap/experiments/rsna/org_inception.py View File

@@ -0,0 +1,14 @@
from typing import Tuple
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.org_inception import RSNAORGInception


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
config = RSNAConfigs('RSNA_Inception_Org', 2, 299, self.phase_type)
model = RSNAORGInception(0.4)
return config, model


+ 14
- 0
torchlap/experiments/rsna/org_resnet.py View File

@@ -0,0 +1,14 @@
from typing import Tuple
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.org_resnet import RSNAORGResNet18


class EntryPoint(BaseEntrypoint):
def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
config = RSNAConfigs('RSNA_ResNet_Org', 1, 224, self.phase_type)
model = RSNAORGResNet18()

return config, model

+ 117
- 0
torchlap/experiments/rsna/ws_lap_inception.py View File

@@ -0,0 +1,117 @@
from typing import Tuple
from collections import OrderedDict
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_inception import RSNALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
return LAP(
channel, 2, 2, hidden_channels=[8],
sigmoid_scale=0.1, discriminative_attention=True)


def adaptive_lap_factory(channel):
return AdaptiveLAP(channel, sigmoid_scale=0.1, discriminative_attention=True)

class EntryPoint(BaseEntrypoint):

def __init__(self, phase_type) -> None:
self.active_min_ratio: float = 0.1
self.active_max_ratio: float = 0.5
self.inactive_ratio: float = 0.1

        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
conf = RSNAConfigs('RSNA_Inception_WS', 4, 256, self.phase_type)
model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

# weakly supervised losses based on discriminative head and attention head

ws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.25,
self.inactive_ratio,
(self.active_min_ratio, self.active_max_ratio), {0: [1]},
discr_score_layer,
w_attention_in_ordering=0.2, w_discr_in_ordering=1)
for title, att_score_layer, discr_score_layer in [
('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for ws_loss in ws_losses:
ws_loss.configure(conf)

# concordance loss for attention

concordance_loss = PoolConcordanceLossCalculator(
'AC', model, OrderedDict([
('att2', model.maxpool2.attention_layer),
('att6', model.Mixed_6a.pool.attention_layer),
('att7', model.Mixed_7a.pool.attention_layer),
('attG', model.avgpool[0].attention_layer),
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
labels_by_channel={0: [1]})
concordance_loss.configure(conf)

# concordance loss for discrimination head

concordance_loss2 = PoolConcordanceLossCalculator(
'DC', model, OrderedDict([
('D-att2', model.maxpool2.discrimination_layer),
('D-att6', model.Mixed_6a.pool.discrimination_layer),
('D-att7', model.Mixed_7a.pool.discrimination_layer),
('D-attG', model.avgpool[0].discrimination_layer),
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
labels_by_channel={0: [1]})
concordance_loss2.configure(conf)


conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.maxpool2.attention_layer,
foretell_pool6=model.Mixed_6a.pool.attention_layer,
foretell_pool7=model.Mixed_7a.pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(conf)


output_modifier = OutputModifier(model,
lambda x: x.flatten(1).max(dim=-1).values,
'foretell_pool2',
'foretell_pool6',
'foretell_pool7',
'foretell_avgpool',
)
output_modifier.configure(conf)

return conf, model


+ 105
- 0
torchlap/experiments/rsna/ws_lap_resnet.py View File

@@ -0,0 +1,105 @@
from collections import OrderedDict
from typing import Tuple

from ...utils.output_modifier import OutputModifier
from ...utils.aux_output import AuxOutput
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_resnet import RSNALAPResNet18


class EntryPoint(BaseEntrypoint):

def __init__(self, phase_type) -> None:
self.active_min_ratio: float = 0.1
self.active_max_ratio: float = 0.5
self.inactive_ratio: float = 0.1

        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
config = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type)
model = RSNALAPResNet18()

# weakly supervised losses based on discriminative head and attention head

ws_losses = [
DiscriminativeWeaklySupervisedLoss(
title, model, att_score_layer, 0.25,
self.inactive_ratio,
(self.active_min_ratio, self.active_max_ratio), {0: [1]},
discr_score_layer,
w_attention_in_ordering=0.2, w_discr_in_ordering=1)
for title, att_score_layer, discr_score_layer in [
('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
]
]
for ws_loss in ws_losses:
ws_loss.configure(config)

# concordance loss for attention

concordance_loss = PoolConcordanceLossCalculator(
'AC', model, OrderedDict([
('att2', model.layer2[0].pool.attention_layer),
('att3', model.layer3[0].pool.attention_layer),
('att4', model.layer4[0].pool.attention_layer),
('att5', model.avgpool[0].attention_layer),
]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
labels_by_channel={0: [1]})
concordance_loss.configure(config)

# concordance loss for discrimination head

concordance_loss2 = PoolConcordanceLossCalculator(
'DC', model, OrderedDict([
('att2', model.layer2[0].pool.discrimination_layer),
('att3', model.layer3[0].pool.discrimination_layer),
('att4', model.layer4[0].pool.discrimination_layer),
('att5', model.avgpool[0].discrimination_layer),
]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
labels_by_channel={0: [1]})
concordance_loss2.configure(config)


config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
'b': BinaryEvaluator,
'l': LossEvaluator,
'f': BinForetellerEvaluator.standard_creator('foretell'),
'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
}))
config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

###################################
########### Foreteller ############
###################################

aux = AuxOutput(model, dict(
foretell_pool2=model.layer2[0].pool.attention_layer,
foretell_pool3=model.layer3[0].pool.attention_layer,
foretell_pool4=model.layer4[0].pool.attention_layer,
foretell_avgpool=model.avgpool[0].attention_layer,
))
aux.configure(config)

output_modifier = OutputModifier(model,
lambda x: x.flatten(1).max(dim=-1).values,
'foretell_pool2',
'foretell_pool3',
'foretell_pool4',
'foretell_avgpool',
)
output_modifier.configure(config)

return config, model

+ 7
- 0
torchlap/interpreting/__init__.py View File

@@ -0,0 +1,7 @@
"""
Interpretation Modules
"""
from .interpretable import InterpretableModel, CamInterpretableModel
from .interpreter import Interpreter
from .interpretable_wrapper import InterpretableWrapper
from .interpreter_maker import create_interpreter, InterpretationType

+ 148
- 0
torchlap/interpreting/attention_interpreter.py View File

@@ -0,0 +1,148 @@
from abc import ABC, abstractmethod
from collections import OrderedDict as OrdDict
from typing import Dict, List, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from .interpreter import Interpreter
from .interpretable import InterpretableModel
from ..modules import LAP
from ..utils.hooker import Hooker, Hook


AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpretableModel(InterpretableModel, ABC):

@property
@abstractmethod
def attention_layers(self) -> Dict[str, List[LAP]]:
"""
List of attention groups
"""


class AttentionInterpreter(Interpreter):

def __init__(self, model: AttentionInterpretableModel):
assert isinstance(model, AttentionInterpretableModel),\
f"Expected AttentionInterpretableModel, got {type(model)}"
super().__init__(model)
layer_hooks = sum([
[
(layer.attention_layer, self._generate_forward_hook(group_name, layer))
for layer in group_layers
] for group_name, group_layers in model.attention_layers.items()
], [])
self._hooker = Hooker(*layer_hooks)
self._attention_groups: Dict[str, AttentionDict] = {
group_name: {} for group_name in model.attention_layers}
self._impact = 1.
self._impact_decay = .8
self._input_size: torch.Size = None
self._device = None

def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook:
def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
processed = out.clone()
shape = processed.shape[2:]
if shape not in self._attention_groups[group_name]:
self._attention_groups[group_name][shape] = torch.zeros_like(
processed)
self._attention_groups[group_name][shape] += processed

return forward_hook

def _reset(self) -> None:
self._attention_groups = {
group_name: {} for group_name in self._model.attention_layers}
self._impact = 1.

def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
""" Process current attention level.

Args:
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1)
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

Returns:
torch.Tensor: New attention map. (B, 1, H2, W2)
"""
self._impact *= self._impact_decay

# smallest attention layer
if result is None:
return attention

# attention is larger than result
scale = torch.ceil(
torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:])
).int().tolist()

# find maximum of attention in each kernel
max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale)
max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest')
is_any_active = max_map > 0

# interpolate result to attention size
result = F.interpolate(result, attention.shape[2:], mode='nearest')

# # maximum between result and attention
# impacted = torch.max(result, max_map * self._impact)
impacted = torch.max(result, attention * self._impact)
# when attention is zero, we use zero
impacted[attention == 0] = 0
# impacted = torch.where(attention > 0, impacted, 0)
# when result is non-zero, we will use impacted
impacted[result == 0] = 0
# impacted = torch.where(result > 0, impacted, 0)
# if max_map is zero, we use result, else we use impacted
result[is_any_active] = impacted[is_any_active]
# result = torch.where(is_any_active, impacted, result)

# mask = (max_map > 0) & (result > 0)
# result[mask] = impacted[mask]
# result[mask & (attention == 0)] = 0
return result

def _process_attentions(self) -> Dict[str, torch.Tensor]:
sorted_attentions = {
group_name: OrdDict(sorted(group_attention.items(),
key=lambda pair: pair[0]))
for group_name, group_attention in self._attention_groups.items()
}

results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions}
for group_name, group_attention in sorted_attentions.items():
self._impact = 1.0
# from smallest to largest
for attention in group_attention.values():
attention = (attention - 0.5).relu()
results[group_name] = self._process_attention(results[group_name], attention)

interpretations = {}
for group_name, group_result in results.items():
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True)
interpretations[group_name] = group_result
'''
if group_result.shape[1] == 1:
interpretations[group_name] = group_result
else:
interpretations.update({
f"{group_name}_{i}": group_result[:, i:i+1]
for i in range(group_result.shape[1])
})
'''
return interpretations

def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
self._reset()
self._batch_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
self._input_size = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:]
self._device = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].device
with self._hooker:
self._model(**inputs)

return self._process_attentions()

+ 48
- 0
torchlap/interpreting/attention_interpreter_smooth_integrator.py View File

@@ -0,0 +1,48 @@
from typing import Dict
from collections import OrderedDict as OrdDict

import torch
import torch.nn.functional as F

from .attention_interpreter import AttentionInterpreter


AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpreterSmoothIntegrator(AttentionInterpreter):

def _process_attentions(self) -> Dict[str, torch.Tensor]:
sorted_attentions = {
group_name: OrdDict(sorted(group_attention.items(),
key=lambda pair: pair[0]))
for group_name, group_attention in self._attention_groups.items()
}

results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions}
for group_name, group_attention in sorted_attentions.items():
self._impact = 1.0
# from smallest to largest
sum_ = 0
for attention in group_attention.values():
if torch.is_tensor(sum_):
sum_ = F.interpolate(sum_, size=attention.shape[-2:])
sum_ += attention
attention = (attention - 0.5).relu()
results[group_name] = self._process_attention(results[group_name], attention)
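            # where the propagated (thresholded) attention is active, keep the raw accumulated sum;
            # elsewhere damp it by a factor of 2 * number of pooling levels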
results[group_name] = torch.where(results[group_name] > 0, sum_, sum_ / (2 * len(group_attention)))

interpretations = {}
for group_name, group_result in results.items():
group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True)
interpretations[group_name] = group_result
'''
if group_result.shape[1] == 1:
interpretations[group_name] = group_result
else:
interpretations.update({
f"{group_name}_{i}": group_result[:, i:i+1]
for i in range(group_result.shape[1])
})
'''
return interpretations

+ 31
- 0
torchlap/interpreting/attention_sum_interpreter.py View File

@@ -0,0 +1,31 @@
import torch
import torch.nn.functional as F
from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel

class AttentionSumInterpreter(AttentionInterpreter):

def __init__(self, model: AttentionInterpretableModel):
assert isinstance(model, AttentionInterpretableModel),\
f"Expected AttentionInterpretableModel, got {type(model)}"
super().__init__(model)

def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
""" Process current attention level.

Args:
result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1)
attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

Returns:
torch.Tensor: New attention map. (B, 1, H2, W2)
"""
self._impact *= self._impact_decay

# smallest attention layer
if result is None:
return attention

# attention is larger than result

result = F.interpolate(result, attention.shape[2:], mode='bilinear', align_corners=True)
return result + attention

+ 166
- 0
torchlap/interpreting/binary_interpretation_evaluator_2d.py View File

@@ -0,0 +1,166 @@
from typing import List, TYPE_CHECKING, Tuple

import numpy as np
from skimage import measure

from .utils import process_interpretations
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig
from . import Interpreter


class BinaryInterpretationEvaluator2D:

def __init__(self, n_samples: int, config: 'BaseConfig'):

self._config = config

self._n_samples = n_samples
self._min_intersection_threshold = config.acceptable_min_intersection_threshold

self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)

def reset(self):

self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
        self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)

    def update_summaries(
self,
m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray,
net_preds: np.ndarray, ground_truth_labels: np.ndarray,
batch_inds: np.ndarray, interpreter: 'Interpreter'
) -> None:

assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}'
assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}'

# skipping samples without interpretations!
has_interpretation_mask = np.logical_not(np.any(np.isnan(ground_truth_interpretations), axis=(1, 2)))
m_interpretations = m_interpretations[has_interpretation_mask]
ground_truths = ground_truth_interpretations[has_interpretation_mask]
        net_preds = net_preds[has_interpretation_mask]

# finding class labels

if len(net_preds.shape) == 1:
net_preds = (net_preds >= 0.5).astype(int)
elif net_preds.shape[1] == 1:
net_preds = (net_preds[:, 0] >= 0.5).astype(int)
else:
net_preds = net_preds.argmax(axis=-1)

ground_truth_labels = ground_truth_labels[has_interpretation_mask]
batch_inds = batch_inds[has_interpretation_mask]

# Checking shapes
if net_preds.shape == ground_truth_labels.shape:
net_preds = np.round(net_preds, 0).astype(int)
else:
net_preds = np.argmax(net_preds, axis=1)

# calculating soundness
soundnesses = (net_preds == ground_truth_labels).astype(int)

c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations))

b_interpretations = np.stack(tuple([
process_interpretations(m_interpretations[ind][None, ...], self._config, interpreter)
for ind in range(len(m_interpretations))
]), axis=0)[:, 0, ...]
b_interpretations = (b_interpretations > 0).astype(bool)
ground_truths = (ground_truths >= 0.5).astype(bool) #making sure values are 0 and 1 even if resize has been applied

assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}'

norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2]

np.add.at(self._normed_intersection_per_soundness, soundnesses,
np.sum(b_interpretations & ground_truths, axis=(1, 2)) * 1.0 / norm_factor)
np.add.at(self._normed_union_per_soundness, soundnesses,
np.sum(b_interpretations | ground_truths, axis=(1, 2)) * 1.0 / norm_factor)
for i in range(len(b_interpretations)):

has_nonzero_captured_bbs = False
has_nonzero_captured_bbs_by_topk = False
s = soundnesses[i]
org_labels = measure.label(ground_truths[i, :, :])
check_labels = measure.label(b_interpretations[i, :, :])

# keeping topK interpretations with k = n_GT! = finding a threshold by quantile! calculating quantile by GT
n_on_gt = np.sum(ground_truths[i])
q = (1 + n_on_gt) * 1.0 / (ground_truths.shape[-1] * ground_truths.shape[-2])
# 1 is added because we have > in thresholding not >=
if q < 1:
tints = c_interpretations[i]
th = max(0, np.quantile(tints.reshape(-1), 1 - q))
tints = (tints > th)
else:
tints = (c_interpretations[i] > 0)

# TOPK METRICS

tk_intersection = np.sum(tints & ground_truths[i])
tk_union = np.sum(tints | ground_truths[i])

self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor
self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor

@staticmethod
def get_titles_of_evaluation_metrics() -> List[str]:
return ['S-IOU', 'S-TK-IOU',
'M-IOU', 'M-TK-IOU',
'A-IOU', 'A-TK-IOU']

@staticmethod
def _get_eval_metrics(normed_intersection, normed_union, title,
tk_normed_intersection, tk_normed_union) -> \
            Tuple[str, str]:
iou = (1e-6 + normed_intersection) / (1e-6 + normed_union)
tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union)

return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,)

def get_values_of_evaluation_metrics(self) -> List[str]:
return \
list(self._get_eval_metrics(
self._normed_intersection_per_soundness[1],
self._normed_union_per_soundness[1],
'Sounds',
self._tk_normed_intersection_per_soundness[1],
self._tk_normed_union_per_soundness[1],
)) + \
list(self._get_eval_metrics(
self._normed_intersection_per_soundness[0],
self._normed_union_per_soundness[0],
'Mistakes',
self._tk_normed_intersection_per_soundness[0],
self._tk_normed_union_per_soundness[0],
)) + \
list(self._get_eval_metrics(
sum(self._normed_intersection_per_soundness),
sum(self._normed_union_per_soundness),
'All',
sum(self._tk_normed_intersection_per_soundness),
sum(self._tk_normed_union_per_soundness),
))

def print_summaries(self):
titles = self.get_titles_of_evaluation_metrics()
vals = self.get_values_of_evaluation_metrics()

nc = 2
for r in range(3):
print(', '.join(['%s: %s' % (titles[nc * r + i], vals[nc * r + i]) for i in range(nc)]))

+ 43
- 0
torchlap/interpreting/dataflow.py View File

@@ -0,0 +1,43 @@
from typing import Dict, Union, Tuple

import torch

from . import Interpreter
from ..data.dataflow import DataFlow
from ..models.model import ModelIO
from ..data.data_loader import DataLoader


class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]):

def __init__(self, interpreter: Interpreter,
dataloader: DataLoader,
phase: str,
device: torch.device,
dtype: torch.dtype,
print_debug_info: bool,
label_for_interpreter: Union[int, None],
give_gt_as_label: bool):
super().__init__(interpreter._model, dataloader, phase, device,
dtype, print_debug_info=print_debug_info)
self._interpreter = interpreter
self._label_for_interpreter = label_for_interpreter
self._give_gt_as_label = give_gt_as_label
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {
name: o.detach() for name, o in output.items()
}

def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]:
if self._give_gt_as_label:
return model_input, self._detach_output(self._interpreter.interpret(
torch.Tensor(self._sample_labels[
self._dataloader.get_current_batch_sample_indices()]),
**model_input))

return model_input, self._detach_output(self._interpreter.interpret(
self._label_for_interpreter, **model_input
))

+ 52
- 0
torchlap/interpreting/deep_lift.py View File

@@ -0,0 +1,52 @@
"""
DeepLift Interpretation
"""
from typing import Dict, Union

import torch
import torch.nn.functional as F
from captum import attr

from . import Interpreter, InterpretableModel, InterpretableWrapper


class DeepLift(Interpreter): # pylint: disable=too-few-public-methods
"""
    Produces DeepLift attribution maps
"""

def __init__(self, model: InterpretableModel):
super().__init__(model)
self._model_wrapper = InterpretableWrapper(model)
self._interpreter = attr.DeepLift(self._model_wrapper)
self.__B = None

def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
'''
interprets given input and target class
'''
B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
if self.__B is None:
self.__B = B
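        # Pad smaller (final) batches up to the first batch's size; the padded rows are stripped from
        # the attribution result below.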
if B < self.__B:
padding = self.__B - B
inputs = {
k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding))
for k, v in inputs.items()
if torch.is_tensor(v) and v.shape[0] == B
}
if torch.is_tensor(labels) and labels.shape[0] == B:
labels = F.pad(labels, (0, padding))

labels = self._get_target(labels, **inputs)
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
result = self._interpreter.attribute(
separated_inputs.inputs,
target=labels,
additional_forward_args=separated_inputs.additional_inputs)[0]
result = result[:B]

return {
'default': result
} # TODO: must return and support tuple

+ 38
- 0
torchlap/interpreting/gradcam.py View File

@@ -0,0 +1,38 @@
from typing import Dict, Union

import torch
import torch.nn.functional as F
from captum import attr

from . import Interpreter, CamInterpretableModel, InterpretableWrapper


class GradCam(Interpreter):

def __init__(self, model: CamInterpretableModel):
super().__init__(model)
self.__model_wrapper = InterpretableWrapper(model)
self.__interpreters = [
attr.LayerGradCam(self.__model_wrapper, conv)
for conv in model.target_conv_layers
]

def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
inp_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[2:]
labels = self._get_target(labels, **inputs)
separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs)
gradcams = [interpreter.attribute(
separated_inputs.inputs,
target=labels,
additional_forward_args=separated_inputs.additional_inputs)
for interpreter in self.__interpreters]
gradcams = [
cam if torch.is_tensor(cam) else cam[0]
for cam in gradcams
]
gradcams = torch.stack(gradcams).sum(dim=0)
gradcams = F.interpolate(gradcams, size=inp_shape, mode='bilinear', align_corners=True)

return {
'default': gradcams,
} # TODO: must return and support tuple

+ 35
- 0
torchlap/interpreting/guided_backprop.py View File

@@ -0,0 +1,35 @@
"""
Guided Backprop Interpretation
"""
from typing import Dict, Union

import torch
from captum import attr

from . import Interpreter, InterpretableModel, InterpretableWrapper


class GuidedBackprop(Interpreter): # pylint: disable=too-few-public-methods
"""
    Produces guided-backpropagation saliency maps
"""

def __init__(self, model: InterpretableModel):
super().__init__(model)
self._model_wrapper = InterpretableWrapper(model)
self._interpreter = attr.GuidedBackprop(self._model_wrapper)


def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
'''
interprets given input and target class
'''

labels = self._get_target(labels, **inputs)
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
return {
'default': self._interpreter.attribute(
separated_inputs.inputs,
target=labels,
additional_forward_args=separated_inputs.additional_inputs)[0]
} # TODO: must return and support tuple

+ 45
- 0
torchlap/interpreting/guided_gradcam.py View File

@@ -0,0 +1,45 @@
"""
Guided Grad Cam Interpretation
"""
from typing import Dict, Union

import torch
from captum import attr

from . import Interpreter, CamInterpretableModel, InterpretableWrapper


class GuidedGradCam(Interpreter):
""" Produces class activation map """

def __init__(self, model: CamInterpretableModel):
super().__init__(model)
self._model_wrapper = InterpretableWrapper(model)
self._interpreters = [
attr.GuidedGradCam(self._model_wrapper, conv)
for conv in model.target_conv_layers
]

def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
""" Interprets given input and target class

Args:
            labels (Union[int, torch.Tensor]): target class
**inputs (torch.Tensor): model inputs

Returns:
Dict[str, torch.Tensor]: Interpretation results
"""

labels = self._get_target(labels, **inputs)
separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
gradcams = [interpreter.attribute(
separated_inputs.inputs,
target=labels,
additional_forward_args=separated_inputs.additional_inputs)[0]
for interpreter in self._interpreters]
gradcams = torch.stack(gradcams).sum(dim=0)

return {
'default': gradcams,
} # TODO: must return and support tuple

+ 55
- 0
torchlap/interpreting/imagenet_attention_interpreter.py View File

@@ -0,0 +1,55 @@
from typing import Callable, Type

import torch

from ..models.model import ModelIO
from ..utils.hooker import Hook, Hooker
from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .interpreter import Interpreter
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter

class ImagenetPredictionInterpreter(Interpreter):

def __init__(self, model: AttentionInterpretableModel, k: int, base_interpreter: Type[AttentionInterpreter]):
super().__init__(model)
self.__base_interpreter = base_interpreter(model)
self.__k = k
self.__topk_classes: torch.Tensor = None
self.__base_interpreter._hooker = Hooker(
(model, self._generate_prediction_hook()),
*[(attention_layer[-2], hook)
for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs])

@staticmethod
def standard_factory(k: int, base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']:
return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter)

def _generate_prediction_hook(self) -> Hook:
def hook(_, __, output: ModelIO):
topk = output['categorical_probability'].detach()\
.topk(self.__k, dim=-1)

dataloader: DataLoader = DataloaderContext.instance.dataloader
sample_names = dataloader.get_current_batch_samples_names()

for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values):
print(f'Top classes of '
f'{name}: '
f'{top_class.detach().cpu().numpy().tolist()} - '
f'{top_prob.cpu().numpy().tolist()}', flush=True)

self.__topk_classes = topk\
.indices\
.flatten()\
.cpu()
return hook

def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
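        # Keep, for each sample, only the attention channels of its top-k predicted classes before
        # delegating to the base interpreter.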
batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k)
attention = attention[batch_indices, self.__topk_classes]
attention = attention.reshape(-1, self.__k, *attention.shape[-2:])
return self.__base_interpreter._process_attention(result, attention)

def interpret(self, labels, **inputs):
return self.__base_interpreter.interpret(labels, **inputs)

+ 53
- 0
torchlap/interpreting/interpretable.py View File

@@ -0,0 +1,53 @@
"""
An Interpretable Pytorch Model
"""
from abc import abstractmethod, ABC
from typing import Iterable, List

import torch
from torch import nn

from ..models.model import Model


class InterpretableModel(Model, ABC):
"""
An Interpretable Pytorch Model
"""

@property
@abstractmethod
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
"""
Returns:
Input module for interpretation
"""

@abstractmethod
def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
"""
A method to get probabilities assigned to all the classes in the model's forward,
with shape (B, C), in which B is the batch size and C is number of classes

Args:
*inputs: Inputs to the model
**kwargs: Additional arguments

Returns:
Tensor of categorical probabilities
"""


class CamInterpretableModel(InterpretableModel, ABC):
"""
A model interpretable by gradcam
"""

@property
@abstractmethod
def target_conv_layers(self) -> List[nn.Module]:
"""
Returns:
The convolutional layers to be interpreted. The result of
the interpretation will be sum of the grad-cam of these layers
"""

+ 69
- 0
torchlap/interpreting/interpretable_wrapper.py View File

@@ -0,0 +1,69 @@
"""
An adapter wrapper to adapt our models to Captum interpreters
"""
import inspect
from typing import Iterable, Dict, Tuple
from dataclasses import dataclass
from itertools import chain

import torch
from torch import nn

from . import InterpretableModel


@dataclass
class InterpreterArgs:
inputs: Tuple[torch.Tensor, ...]
additional_inputs: Tuple[torch.Tensor, ...]

class InterpretableWrapper(nn.Module):
"""
An adapter wrapper to adapt our models to Captum interpreters
"""

def __init__(self, model: InterpretableModel):
super().__init__()

self._model = model

@property
def _additional_names(self) -> Iterable[str]:
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted

signature = inspect.signature(self._model.forward)

return [name for name in signature.parameters
if name not in to_be_interpreted_names
and signature.parameters[name].default is not None]

def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Converts an ordered *args to **kwargs
"""
to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
additional_names = self._additional_names

inputs = {}
for i, name in enumerate(chain(to_be_interpreted_names, additional_names)):
inputs[name] = args[i]

return inputs

def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs:
"""
Converts a **kwargs to ordered *args
"""
return InterpreterArgs(
tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted
if name in kwargs),
tuple(kwargs[name] for name in self._additional_names
if name in kwargs)
)

def forward(self, *args: torch.Tensor) -> torch.Tensor:
"""
Forwards the model
"""
inputs = self.convert_inputs_to_kwargs(*args)
return self._model.get_categorical_probabilities(**inputs)

+ 42
- 0
torchlap/interpreting/interpretation_dataflow.py View File

@@ -0,0 +1,42 @@
from typing import Dict, Union, Tuple

import torch

from .interpreter import Interpreter
from ..data.dataflow import DataFlow
from ..data.data_loader import DataLoader


class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]):

def __init__(self, interpreter: Interpreter,
dataloader: DataLoader,
device: torch.device,
print_debug_info: bool,
label_for_interpreter: Union[int, None],
give_gt_as_label: bool):
super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info)
self._interpreter = interpreter
self._label_for_interpreter = label_for_interpreter
self._give_gt_as_label = give_gt_as_label
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if output is None:
return None

return {
name: o.detach() for name, o in output.items()
}

def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
if self._give_gt_as_label:
return model_input, self._detach_output(self._interpreter.interpret(
torch.Tensor(self._sample_labels[
self._dataloader.get_current_batch_sample_indices()]),
**model_input))

return model_input, self._detach_output(self._interpreter.interpret(
self._label_for_interpreter, **model_input
))

+ 123
- 0
torchlap/interpreting/interpretation_evaluation_runner.py View File

@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING
import traceback
from time import time

from ..data.data_loader import DataLoader
from .interpretable import InterpretableModel
from .interpreter_maker import create_interpreter
from .interpreter import Interpreter
from .interpretation_dataflow import InterpretationDataFlow
from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class InterpretingEvalRunner:

def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
self.conf = conf
self.model = model

def evaluate(self):

interpreter = create_interpreter(
self.conf.interpretation_method,
self.model)

for test_group_info in self.conf.samples_dir.split(','):

try:

print('>> Evaluating interpretations for %s' % test_group_info, flush=True)

t1 = time()

test_data_loader = \
DataLoader(self.conf, test_group_info, 'test')

evaluator = BinaryInterpretationEvaluator2D(
test_data_loader.get_number_of_samples(),
self.conf
)

self._eval_interpretations(test_data_loader, interpreter, evaluator)

evaluator.print_summaries()

print('Evaluating Interpretations was done in %.2f secs.' % (time() - t1,), flush=True)

except Exception as e:
print('Problem in %s' % test_group_info, flush=True)
track = traceback.format_exc()
print(track, flush=True)

def _eval_interpretations(self,
data_loader: DataLoader, interpreter: Interpreter,
evaluator: BinaryInterpretationEvaluator2D) -> None:
""" CAUTION: Resets the data_loader,
Iterates over the samples (as much as and how data_loader specifies),
finds the interpretations based on the specified interpretation method
and saves the results in the received save_dir!"""

if not isinstance(self.model, InterpretableModel):
raise Exception('Model has not implemented the requirements of the InterpretableModel')

# Running the model in evaluation mode
self.model.eval()

# initiating variables for running evaluations
evaluator.reset()

label_for_interpreter = self.conf.class_label_for_interpretation
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

dataflow = InterpretationDataFlow(interpreter,
data_loader,
self.conf.device,
False,
label_for_interpreter,
give_gt_as_label)

n_interpreted_samples = 0
max_n_samples = self.conf.n_interpretation_samples

with dataflow:
for model_input, interpretation_output in dataflow.iterate():

n_batch = model_input[
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
n_samples_to_save = n_batch
if max_n_samples is not None:
n_samples_to_save = min(
max_n_samples - n_interpreted_samples,
n_batch)

model_outputs = self.model(**model_input)
model_preds = model_outputs[
self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy()

if self.conf.interpretation_tag_to_evaluate:
interpretations = interpretation_output[
self.conf.interpretation_tag_to_evaluate
].detach()
else:
interpretations = list(interpretation_output.values())[0].detach()

interpretations = interpretations.detach().cpu().numpy()

evaluator.update_summaries(
interpretations[:n_samples_to_save, 0],
data_loader.get_current_batch_samples_interpretations()[:n_samples_to_save, 0].cpu().numpy(),
model_preds[:n_samples_to_save],
data_loader.get_current_batch_samples_labels()[:n_samples_to_save],
data_loader.get_current_batch_sample_indices()[:n_samples_to_save],
interpreter
)

del interpretation_output
del model_outputs

# if the number of interpreted samples has reached the limit, break
n_interpreted_samples += n_batch
if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
break

+ 67
- 0
torchlap/interpreting/interpreter.py View File

@@ -0,0 +1,67 @@
"""
An abstract class to apply different interpretation algorithms on torch models.
"""

from abc import abstractmethod, ABC
from typing import Dict, List, Union
import torch
from torch import nn
import numpy as np
from . import InterpretableModel


class Interpreter(ABC): # pylint: disable=too-few-public-methods
"""
An abstract class to apply different interpretation algorithms on torch models.
"""

def __init__(self, model: InterpretableModel):
self._model = model

def _get_all_leaf_children(self, module: nn.Module = None) -> List[nn.Module]:
"""
Gets all leaf modules recursively
"""
if module is None:
module = self._model

children: List[nn.Module] = []

for child in module.children():
if len(list(child.children())) == 0:
children.append(child)
else:
children += self._get_all_leaf_children(child)

return children

@staticmethod
def _get_one_hot_output(output: torch.Tensor,
target: Union[int, torch.Tensor, np.ndarray] = None) -> torch.Tensor:
batch_size = output.shape[0]

if target is None:
target = output.argmax(dim=1)

one_hot_output = torch.zeros_like(output)
one_hot_output[torch.arange(batch_size), target] = 1
return one_hot_output

def _get_target(self, target: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Union[int, torch.Tensor]:
if target is not None:
return target

return self._model.get_categorical_probabilities(**inputs).argmax(dim=1)

@abstractmethod
def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
'''
An abstract method to interpret given input and target class
Returns a dictionary of interpretation results.
'''

def dynamic_threshold(self, x: np.ndarray) -> np.ndarray:
"""
A function to dynamically threshold one sample.
"""
return (x - (x.mean() + x.std())).clip(min=0)

+ 63
- 0
torchlap/interpreting/interpreter_maker.py View File

@@ -0,0 +1,63 @@
"""
Creates an interpreter object and returns it
"""
from typing import Callable
from . import InterpretableModel, Interpreter


class InterpretationType:
GuidedBackprop = 'GuidedBackprop'
GuidedGradCam = 'GuidedGradCam'
GradCam = 'GradCam'
RelCam = 'RelCam'
DeepLift = 'DeepLift'
Attention = 'Attention'
AttentionSmooth = 'AttentionSmooth'
AttentionSum = 'AttentionSum'
ImagenetAttention = 'ImagenetAttention'
GT = 'GT'


InterpreterMaker = Callable[[InterpretableModel], Interpreter]


def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter:
"""
Creates an interpreter object and returns it
:param interpretation_method: The method name
:param model: The model to be interpreted
:return: an interpreter object that can run the model and interpret its output
"""

if interpretation_method == InterpretationType.GuidedBackprop:
from .guided_backprop import GuidedBackprop as interpreter_maker

elif interpretation_method == InterpretationType.GuidedGradCam:
from .guided_gradcam import GuidedGradCam as interpreter_maker

elif interpretation_method == InterpretationType.GradCam:
from .gradcam import GradCam as interpreter_maker

elif interpretation_method == InterpretationType.RelCam:
from .relcam import RelCamInterpreter as interpreter_maker

elif interpretation_method == InterpretationType.DeepLift:
from .deep_lift import DeepLift as interpreter_maker

elif interpretation_method == InterpretationType.Attention:
from .attention_interpreter import AttentionInterpreter as interpreter_maker

elif interpretation_method == InterpretationType.AttentionSmooth:
from .attention_interpreter_smooth_integrator import AttentionInterpreterSmoothIntegrator as interpreter_maker

elif interpretation_method == InterpretationType.AttentionSum:
from .attention_sum_interpreter import AttentionSumInterpreter as interpreter_maker

elif interpretation_method == InterpretationType.ImagenetAttention:
from .imagenet_attention_interpreter import ImagenetPredictionInterpreter as interpreter_maker

else:
raise Exception(f'Unknown interpretation method: {interpretation_method}')

return interpreter_maker(model)
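As a usage sketch (assuming `model` implements InterpretableModel and a batch tensor `batch_x` plus its `labels` are already loaded):

# Hypothetical call site of the factory above.
interpreter = create_interpreter(InterpretationType.GradCam, model)
interpretation_maps = interpreter.interpret(labels, x=batch_x)   # Dict[str, torch.Tensor]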

+ 162
- 0
torchlap/interpreting/interpreting_runner.py View File

@@ -0,0 +1,162 @@
from typing import Dict, TYPE_CHECKING
from os import makedirs, path
import traceback
from time import time
import warnings
import imageio

import torch
import numpy as np

from ..data.data_loader import DataLoader
from .interpretable import InterpretableModel
from .interpreter_maker import create_interpreter
from .interpreter import Interpreter
from .interpretation_dataflow import InterpretationDataFlow
from .utils import overlay_interpretation
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class InterpretingRunner:

def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
self.conf = conf
self.model = model

def interpret(self):
interpreter = create_interpreter(
self.conf.interpretation_method,
self.model)

for test_group_info in self.conf.samples_dir.split(','):
try:

print('>> Finding interpretations for %s' % test_group_info, flush=True)

t1 = time()

labels_to_use = self.conf.mapped_labels_to_use
if labels_to_use is None:
labels_to_use = 'All'
else:
labels_to_use = ','.join([str(x) for x in self.conf.mapped_labels_to_use])

report_dir = self.conf.get_sample_group_specific_report_dir(
test_group_info,
extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}')

makedirs(report_dir, exist_ok=True)
# writing the whole config
with open(report_dir + '/conf_info.txt', 'w') as f:
    f.write(str(self.conf) + '\n')

test_data_loader = \
DataLoader(self.conf, test_group_info, 'test')
self._interpret(report_dir, test_data_loader, interpreter)

print('Interpretations were saved in %s' % report_dir)
print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True)

except Exception:
print('Problem in %s' % test_group_info, flush=True)
track = traceback.format_exc()
print(track, flush=True)

def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None:
""" Iterates over the samples (as much as and how data_loader specifies),
finds the interpretations based on the specified interpretation method
and saves the results in the received save_dir!"""

makedirs(save_dir, exist_ok=True)

if not isinstance(self.model, InterpretableModel):
raise Exception('Model has not implemented the requirements of the InterpretableModel')

# Running the model in evaluation mode
self.model.eval()

# initiating variables for running evaluations

label_for_interpreter = self.conf.class_label_for_interpretation
give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

dataflow = InterpretationDataFlow(interpreter,
data_loader,
self.conf.device,
False,
label_for_interpreter,
give_gt_as_label)

n_interpreted_samples = 0
max_n_samples = self.conf.n_interpretation_samples

with dataflow:
for model_input, interpretation_output in dataflow.iterate():
n_batch = model_input[
self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
n_samples_to_save = n_batch
if max_n_samples is not None:
n_samples_to_save = min(
max_n_samples - n_interpreted_samples,
n_batch)

self._save_interpretations_of_batch(
model_input[self.model.ordered_placeholder_names_to_be_interpreted[0]],
interpretation_output, save_dir, data_loader, n_samples_to_save,
interpreter)
del interpretation_output

# if the number of interpreted samples has reached the limit, break
n_interpreted_samples += n_batch
if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
break

def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor],
save_dir: str, data_loader: DataLoader,
n_samples_to_save: int, interpreter: Interpreter) -> None:
"""
Receives the output of the interpreter and saves the interpretations in the received directory
in a file named as the sample name.
The behaviour can be overridden in subclasses if extra steps are required.
:param interpreter_output:
:param save_dir:
:return: None
"""

batch_samples_names = data_loader.get_current_batch_samples_names()

save_dirs = [self.conf.get_save_dir_for_sample(
save_dir, batch_samples_names[bi].replace('../', ''))
for bi in range(len(batch_samples_names))]
for sd in save_dirs:
makedirs(path.dirname(sd), exist_ok=True)

interpreter_output = {name: output.cpu().numpy() for name, output in interpreter_output.items()}
""" Make inputs grayscale """
model_input = model_input.mean(dim=1).cpu().numpy()

def save(bis):
for bi in bis:
for name, output in interpreter_output.items():
output = output[bi]
filename = f'{save_dirs[bi]}_{name}'

if self.conf.dynamic_threshold:
output = interpreter.dynamic_threshold(output)

if not self.conf.skip_raw:
np.save(filename, output)
else:
warnings.warn('Skipping raw interpretations')

if not self.conf.skip_overlay:
if output.shape[0] > 3:
output = output[:3]
overlayed = overlay_interpretation(model_input[bi][np.newaxis, ...], output, self.conf)
imageio.imwrite(f'{filename}_overlay.png', overlayed)
else:
warnings.warn('Skipping overlayed interpretations')

save(np.arange(min(len(model_input), n_samples_to_save)))
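A hedged sketch of how this runner is typically driven (the `conf` and `model` objects are assumed to be built elsewhere, e.g. by an experiment entrypoint):

# Hypothetical driver for the runner above.
runner = InterpretingRunner(conf, model)
runner.interpret()   # iterates the groups in conf.samples_dir and writes raw maps / overlays per sample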

+ 1
- 0
torchlap/interpreting/relcam/__init__.py View File

@@ -0,0 +1 @@
from .interpreter import RelCamInterpreter

+ 44
- 0
torchlap/interpreting/relcam/interpreter.py View File

@@ -0,0 +1,44 @@
from typing import Dict, Union
import warnings

import numpy as np
import torch.nn.functional as F
import torch
from ..interpreter import Interpreter
from ..interpretable import CamInterpretableModel
from .relprop import RPProvider


class RelCamInterpreter(Interpreter):

def __init__(self, model: CamInterpretableModel):
super().__init__(model)
self.__targets = model.target_conv_layers
for name, module in model.named_modules():
if not RPProvider.propable(module):
warnings.warn(f"Module {name} of type {type(module)} is not propable! Hope you know what you are doing!")
continue
RPProvider.create(module)

@staticmethod
def __normalize(x: torch.Tensor) -> torch.Tensor:
fx = x.flatten(1)
minx = fx.min(dim=1).values[:, None, None, None]
maxx = fx.max(dim=1).values[:, None, None, None]
return (x - minx) / (1e-6 + maxx - minx)

def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
x_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape
with RPProvider.capture(self.__targets) as Rs:
z = self._model.get_categorical_probabilities(**inputs)
one_hot_y = self._get_one_hot_output(z, labels)
RPProvider.get(self._model)(one_hot_y)
result = list(Rs.values())
relcam = torch.stack(result).sum(dim=0)\
if len(result) > 1\
else result[0]
relcam = self.__normalize(relcam)
relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear', align_corners=False)
return {
'default': relcam,
}

+ 21
- 0
torchlap/interpreting/relcam/modules.py View File

@@ -0,0 +1,21 @@
from torch import nn
import torch


class Add(nn.Module):

def forward(self, inputs):
return torch.add(*inputs)


class Multiply(nn.Module):

def forward(self, inputs):
return torch.mul(*inputs)


class Cat(nn.Module):

def forward(self, inputs, dim):
self.dim = dim
return torch.cat(inputs, dim)
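These wrappers turn elementwise ops into nn.Modules so RPProvider can attach relevance-propagation rules to them. A minimal sketch of their intended use inside a residual-style block (illustrative, not from this commit):

# Hypothetical block using the Add wrapper instead of a functional `conv(x) + x`.
from torch import nn

class ResidualBlockSketch(nn.Module):

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.add = Add()

    def forward(self, x):
        return self.add((self.conv(x), x))   # relevance can now flow through the skip connection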

+ 339
- 0
torchlap/interpreting/relcam/relprop.py View File

@@ -0,0 +1,339 @@
from contextlib import ExitStack, contextmanager
from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar
from uuid import UUID, uuid4

import torch
import torch.nn.functional as F
from torch import nn
import torchvision

from . import modules as M


TModule = TypeVar('TModule', bound=nn.Module)
TType = TypeVar('TType', bound=type)


def safe_divide(a, b):
return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())


class classdict(Dict[type, TType], Generic[TType]):

def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]:
return next((r for r in __k.mro() if dict.__contains__(self, r)), None)

def __contains__(self, __k: Type[TModule]) -> bool:
return self.__nearest_available_resolution(__k) is not None

def __getitem__(self, __k: Type[TModule]) -> TType:
r = self.__nearest_available_resolution(__k)
if r is None:
raise KeyError(f"{__k} not found!")
return super().__getitem__(r)



class RPProvider:
__props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict()
__instances: Dict[TModule, 'RelProp[TModule]'] = {}
__target_results: Dict[nn.Module, torch.Tensor] = {}

@classmethod
def register(cls, *module_types: Type[TModule]):
def decorator(prop_cls):
cls.__props.update({
module_type: prop_cls
for module_type in module_types
})
return prop_cls
return decorator

@classmethod
def create(cls, module: TModule):
cls.__instances[module] = prop = cls.__props[type(module)](module)
return prop

@classmethod
def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None:
r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
r_cam = prop.X * r_weight
r_cam = torch.sum(r_cam, dim=1, keepdim=True)
cls.__target_results[prop.module] = r_cam

@classmethod
@contextmanager
def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]:
with ExitStack() as stack:
cls.__target_results.clear()
[stack.enter_context(prop.hook_module())
for prop in cls.__instances.values()]
[stack.enter_context(cls.get(target).register_hook(cls.__hook))
for target in target_layers]
yield cls.__target_results

@classmethod
def propable(cls, module: TModule) -> bool:
return type(module) in cls.__props

@classmethod
def get(cls, module: TModule) -> 'RelProp[TModule]':
return cls.__instances[module]


@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid)
class RelProp(Generic[TModule]):
Hook = Callable[['RelProp', torch.Tensor], None]

def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
if isinstance(input[0], (list, tuple)):
self.X = []
for i in input[0]:
x = i.detach()
x.requires_grad = True
self.X.append(x)
else:
self.X = input[0].detach()
self.X.requires_grad = True
self.Y = output

def __init__(self, module: TModule) -> None:
self.module = module
self.__hooks: Dict[UUID, RelProp.Hook] = {}

@contextmanager
def hook_module(self):
handle = self.module.register_forward_hook(self.__forward_hook)
try:
yield
finally:
handle.remove()

@contextmanager
def register_hook(self, hook: 'RelProp.Hook'):
uuid = uuid4()
self.__hooks[uuid] = hook
try:
yield
finally:
self.__hooks.pop(uuid)

def grad(self, Z, X, S):
C = torch.autograd.grad(Z, X, S, retain_graph=True)
return C

def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
R = self.rel(R, alpha=alpha)
[hook(self, R) for hook in self.__hooks.values()]
return R

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
return R


@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize)
class RelPropSimple(RelProp[TModule]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
Z = self.module.forward(self.X)
S = safe_divide(R, Z)
C = self.grad(Z, self.X, S)

if not torch.is_tensor(self.X):
outputs = []
outputs.append(self.X[0] * C[0])
outputs.append(self.X[1] * C[1])
else:
outputs = self.X * C[0]
return outputs


@RPProvider.register(nn.Flatten)
class RelPropFlatten(RelProp[TModule]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
return R.reshape_as(self.X)


@RPProvider.register(nn.AdaptiveAvgPool2d)
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
px = torch.clamp(self.X, min=0)

def f(x1):
Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size)
S1 = safe_divide(R, Z1)
C1 = x1 * self.grad(Z1, x1, S1)[0]
return C1

activator_relevances = f(px)
out = activator_relevances
return out


@RPProvider.register(nn.ZeroPad2d)
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
Z = self.module.forward(self.X)
S = safe_divide(R, Z)
C = self.grad(Z, self.X, S)
outputs = self.X * C[0]
return outputs


@RPProvider.register(M.Multiply)
class MultiplyRelProp(RelProp[M.Multiply]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
x0 = torch.clamp(self.X[0], min=0)
x1 = torch.clamp(self.X[1], min=0)
x = [x0, x1]
Z = self.module.forward(x)
S = safe_divide(R, Z)
C = self.grad(Z, x, S)

outputs = []
outputs.append(x[0] * C[0])
outputs.append(x[1] * C[1])

return outputs


@RPProvider.register(M.Cat)
class CatRelProp(RelProp[M.Cat]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
Z = self.module.forward(self.X, self.module.dim)
S = safe_divide(R, Z)
C = self.grad(Z, self.X, S)

outputs = []
for x, c in zip(self.X, C):
outputs.append(x * c)

return outputs


@RPProvider.register(nn.Sequential)
class SequentialRelProp(RelProp[nn.Sequential]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
for m in reversed(self.module):
R = RPProvider.get(m)(R, alpha=alpha)
return R


@RPProvider.register(nn.BatchNorm2d)
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
X = self.X
beta = 1 - alpha
weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
(self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.module.eps).pow(0.5))
Z = X * weight + 1e-9
S = R / Z
Ca = S * weight
R = self.X * (Ca)
return R


@RPProvider.register(nn.Linear)
class LinearRelProp(RelProp[nn.Linear]):
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
beta = alpha - 1
pw = torch.clamp(self.module.weight, min=0)
nw = torch.clamp(self.module.weight, max=0)
px = torch.clamp(self.X, min=0)
nx = torch.clamp(self.X, max=0)

def f(w1, w2, x1, x2):
Z1 = F.linear(x1, w1)
Z2 = F.linear(x2, w2)
Z = Z1 + Z2
S = safe_divide(R, Z)
C1 = x1 * self.grad(Z1, x1, S)[0]
C2 = x2 * self.grad(Z2, x2, S)[0]
return C1 + C2

activator_relevances = f(pw, nw, px, nx)
inhibitor_relevances = f(nw, pw, px, nx)
out = alpha * activator_relevances - beta*inhibitor_relevances
return out


@RPProvider.register(nn.Conv2d)
class Conv2dRelProp(RelProp[nn.Conv2d]):

def gradprop2(self, DY, weight):
Z = self.module.forward(self.X)

output_padding = self.X.size()[2] - (
(Z.size()[2] - 1) * self.module.stride[0] - 2 * self.module.padding[0] + self.module.kernel_size[0])

return F.conv_transpose2d(DY, weight, stride=self.module.stride, padding=self.module.padding, output_padding=output_padding)

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
if self.X.shape[1] == 3:
pw = torch.clamp(self.module.weight, min=0)
nw = torch.clamp(self.module.weight, max=0)
X = self.X
L = self.X * 0 + \
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
keepdim=True)[0]
H = self.X * 0 + \
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
keepdim=True)[0]
Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride, padding=self.module.padding) - \
torch.conv2d(L, pw, bias=None, stride=self.module.stride, padding=self.module.padding) - \
torch.conv2d(H, nw, bias=None, stride=self.module.stride,
padding=self.module.padding) + 1e-9

S = R / Za
C = X * self.gradprop2(S, self.module.weight) - L * \
self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
R = C
else:
beta = alpha - 1
pw = torch.clamp(self.module.weight, min=0)
nw = torch.clamp(self.module.weight, max=0)
px = torch.clamp(self.X, min=0)
nx = torch.clamp(self.X, max=0)

def f(w1, w2, x1, x2):
Z1 = F.conv2d(x1, w1, bias=self.module.bias, stride=self.module.stride,
padding=self.module.padding, groups=self.module.groups)
Z2 = F.conv2d(x2, w2, bias=self.module.bias, stride=self.module.stride,
padding=self.module.padding, groups=self.module.groups)
Z = Z1 + Z2
S = safe_divide(R, Z)
C1 = x1 * self.grad(Z1, x1, S)[0]
C2 = x2 * self.grad(Z2, x2, S)[0]
return C1 + C2

activator_relevances = f(pw, nw, px, nx)
inhibitor_relevances = f(nw, pw, px, nx)

R = alpha * activator_relevances - beta * inhibitor_relevances
return R


@RPProvider.register(nn.ConvTranspose2d)
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]):
def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:

pw = torch.clamp(self.module.weight, min=0)
px = torch.clamp(self.X, min=0)

def f(w1, x1):
Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride,
padding=self.module.padding, output_padding=self.module.output_padding)
S1 = safe_divide(R, Z1)
C1 = x1 * self.grad(Z1, x1, S1)[0]
return C1

activator_relevances = f(pw, px)
R = activator_relevances
return R
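A minimal sketch of the intended flow of the registry above, mirroring how RelCamInterpreter uses it: wrap every supported module, run the forward pass inside capture(), then push a one-hot relevance signal back from the top-level module. This assumes the top-level model type itself has a registered rule (e.g. an nn.Sequential) and that `model`, `target_layer` and an input batch `x` are in scope:

# Hypothetical end-to-end relevance pass with RPProvider.
import torch

for module in model.modules():
    if RPProvider.propable(module):
        RPProvider.create(module)

with RPProvider.capture([target_layer]) as cams:
    probs = model(x)                                      # forward hooks record X for each wrapped module
    one_hot = torch.zeros_like(probs)
    one_hot[torch.arange(probs.shape[0]), probs.argmax(dim=1)] = 1
    RPProvider.get(model)(one_hot)                        # relevance flows back through the registered rules
cam = cams[target_layer]                                  # B x 1 x h x w map at the target layer's resolution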

+ 102
- 0
torchlap/interpreting/utils.py View File

@@ -0,0 +1,102 @@
from typing import TYPE_CHECKING, Optional

import numpy as np
import cv2
from skimage.morphology import dilation, erosion

if TYPE_CHECKING:
from ..configs.base_config import BaseConfig
from . import Interpreter


def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray:
"""
image (np.ndarray): The image to overlay interpretations over. Of shape C W H. The numbers are assumed to be in range [0, 1]
int_map (np.ndarray): The map containing interpretation scores.
config ('Config'): An object containing required information about the configurations.
Returns:
(np.ndarray): The overlayed image
"""

int_map = process_interpretations(int_map, config) # C W H
int_map = np.moveaxis(int_map, 0, -1) # W H C

# if there are more than 3 channels, use the first 3!
if int_map.shape[-1] > 3:
int_map = int_map[..., :3]

# norm by max to make sure!
int_map = int_map * 1.0 / (1e-6 + np.amax(int_map))

alpha = 0.5

if np.amin(image) < 0:
image += 0.5

image = np.moveaxis(image, 0, -1) # W H C
image = image * 255

o_color_by_c = {
1: [255, 0, 0],
2: [255, 255, 0],
3: [255, 255, 255],
}
o_color = o_color_by_c.get(int_map.shape[-1], None)
if o_color is None:
    raise Exception("Trying to overlay more than 3 maps on the image.")
o_color = np.asarray(o_color)

# if int map has two channels, append one zeros to the end!
if int_map.shape[-1] == 2:
int_map = np.concatenate((int_map, np.zeros((int_map.shape[-3], int_map.shape[-2], 1))), axis=-1)

overlayed = (1 - int_map) * image + \
int_map * (1 - alpha) * image + \
int_map * alpha * o_color[np.newaxis, np.newaxis, :]

overlayed = np.round(overlayed).astype(np.uint8)
return overlayed


def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray:
"""
int_map (np.ndarray): The map of interpretation scores. Is assumed to have a shape of C ...
config ('Config'): An object containing the information about the required configurations

Returns:
np.ndarray: Processed interpretation maps, with numbers in the range of [0, 1]
"""

if interpreter and config.dynamic_threshold:
int_map = interpreter.dynamic_threshold(int_map)

if not config.global_threshold:

# Treating map as negative=reverse effect; discarding negatives

int_map[int_map < 0] = 0

qval = np.amax(int_map)
if qval > 0:
int_map[int_map > qval] = qval
int_map = int_map / qval

# applying cut threshold!
int_map[int_map <= config.cut_threshold] = config.cut_threshold
int_map -= config.cut_threshold

kw = int(np.round(0.1 * int_map.shape[-1]))

mv = np.amax(int_map)
int_map /= (mv + 1e-6)

# dilation and erosion to make continuous objects and erase noises
for c in range(int_map.shape[0]):
int_map[c] = dilation(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3)))
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw)))
int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))

int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0

return int_map # [0, 1]
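A short usage sketch of the two helpers above (the image is assumed to be a C x W x H float array in [0, 1] and int_map a matching map of interpretation scores):

# Hypothetical overlay of a processed interpretation map on its input image.
import imageio

overlayed = overlay_interpretation(image, int_map, config)   # uint8, W x H x 3
imageio.imwrite('sample_overlay.png', overlayed)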



+ 100
- 0
torchlap/model_evaluation/binary_evaluator.py View File

@@ -0,0 +1,100 @@
from typing import List, TYPE_CHECKING, Dict

import numpy as np
import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class BinaryEvaluator(Evaluator):

def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(BinaryEvaluator, self).__init__(model, data_loader, conf)

# for summarized information
self.tp: int = 0
self.tn: int = 0
self.fp: int = 0
self.fn: int = 0
self.avg_loss: float = 0
self.n_received_samples: int = 0

def reset(self):
self.tp = 0
self.tn = 0
self.fp = 0
self.fn = 0
self.avg_loss = 0
self.n_received_samples = 0

@property
def _prob_key(self) -> str:
return 'positive_class_probability'

@property
def _loss_key(self) -> str:
return 'loss'

def _get_current_batch_gt(self) -> np.ndarray:
return self.data_loader.get_current_batch_samples_labels()

def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
assert self._prob_key in model_output, \
f"model's output must contain {self._prob_key}"

gt = self._get_current_batch_gt()
prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int)

ntp = int(np.sum(np.logical_and(gt == 1, prediction == 1)))
ntn = int(np.sum(np.logical_and(gt == 0, prediction == 0)))
nfp = int(np.sum(np.logical_and(gt == 0, prediction == 1)))
nfn = int(np.sum(np.logical_and(gt == 1, prediction == 0)))

self.tp += ntp
self.tn += ntn
self.fp += nfp
self.fn += nfn

new_n = self.n_received_samples + len(gt)
self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
model_output.get(self._loss_key, 0.0) * (float(len(gt)) / new_n)

self.n_received_samples = new_n

def get_titles_of_evaluation_metrics(self) -> List[str]:
return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N']

def get_values_of_evaluation_metrics(self) -> List[str]:

p = self.tp + self.fn
n = self.tn + self.fp

if p + n > 0:
accuracy = (self.tp + self.tn) * 100.0 / (n + p)
else:
accuracy = -1

if p > 0:
sensitivity = 100.0 * self.tp / p
else:
sensitivity = -1

if n > 0:
specificity = 100.0 * self.tn / max(n, 1)
else:
specificity = -1

if sensitivity > -1 and specificity > -1:
avg_ss = 0.5 * (sensitivity + specificity)
elif sensitivity > -1:
avg_ss = sensitivity
else:
avg_ss = specificity

return ['%.4f' % self.avg_loss, '%.2f' % accuracy, '%.2f' % sensitivity,
'%.2f' % specificity, '%.2f' % avg_ss, str(self.n_received_samples)]
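As a quick worked check of the reported metrics (illustrative numbers only): with tp=40, fn=10, tn=30, fp=20, Acc = 70/100 = 70.00, Sens = 40/50 = 80.00, Spec = 30/50 = 60.00 and BAcc = (80 + 60)/2 = 70.00.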


+ 109
- 0
torchlap/model_evaluation/binary_faithfulness.py View File

@@ -0,0 +1,109 @@
from functools import partial
from typing import List, TYPE_CHECKING, Dict, Callable

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class BinFaithfulnessEvaluator(Evaluator):

def __init__(self, kw_prefix: str, gt_kw: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(BinFaithfulnessEvaluator, self).__init__(model, data_loader, conf)

self._tp_by_kw: Dict[str, int] = dict()
self._fp_by_kw: Dict[str, int] = dict()
self._tn_by_kw: Dict[str, int] = dict()
self._fn_by_kw: Dict[str, int] = dict()

self._kw_prefix = kw_prefix
self._gt_kw = gt_kw

def reset(self):
self._tp_by_kw: Dict[str, int] = dict()
self._fp_by_kw: Dict[str, int] = dict()
self._tn_by_kw: Dict[str, int] = dict()
self._fn_by_kw: Dict[str, int] = dict()

def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:

gt = (model_output[self._gt_kw].detach().cpu().numpy() >= 0.5).astype(int)

# looking for prefixes
for k in model_output.keys():
if k.startswith(self._kw_prefix):

if k not in self._tp_by_kw:
self._tp_by_kw[k] = 0
self._tn_by_kw[k] = 0
self._fp_by_kw[k] = 0
self._fn_by_kw[k] = 0

pred = (model_output[k].cpu().numpy() >= 0.5).astype(int)
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1))
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1))
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0))
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0))

def get_titles_of_evaluation_metrics(self) -> List[str]:
return [f'Loyalty_{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

def _get_values_of_evaluation_metrics(self, kw) -> List[str]:

tp = self._tp_by_kw[kw]
tn = self._tn_by_kw[kw]
fp = self._fp_by_kw[kw]
fn = self._fn_by_kw[kw]

p = tp + fn
n = tn + fp

if p + n > 0:
accuracy = (tp + tn) * 100.0 / (n + p)
else:
accuracy = -1

if p > 0:
sensitivity = 100.0 * tp / p
else:
sensitivity = -1

if n > 0:
specificity = 100.0 * tn / max(n, 1)
else:
specificity = -1

if sensitivity > -1 and specificity > -1:
avg_ss = 0.5 * (sensitivity + specificity)
elif sensitivity > -1:
avg_ss = sensitivity
else:
avg_ss = specificity

return ['%.2f' % accuracy, '%.2f' % sensitivity,
'%.2f' % specificity, '%.2f' % avg_ss]

def get_values_of_evaluation_metrics(self) -> List[str]:
return [
v
for k in self._tp_by_kw.keys()
for v in self._get_values_of_evaluation_metrics(k)]

@classmethod
def standard_creator(cls, prefix_kw: str, pred_kw: str = 'positive_class_probability') -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinFaithfulnessEvaluator']:
return partial(BinFaithfulnessEvaluator, prefix_kw, pred_kw)

def print_evaluation_metrics(self, title: str) -> None:
""" For more readable printing! """

print(f'{title}:')
for k in self._tp_by_kw.keys():
print(
f'\t{k}:: ' +
', '.join([f'{m_name}: {m_val}'
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))]))

+ 108
- 0
torchlap/model_evaluation/binary_fortelling.py View File

@@ -0,0 +1,108 @@
from typing import List, TYPE_CHECKING, Dict, Callable
from functools import partial

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class BinForetellerEvaluator(Evaluator):

def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(BinForetellerEvaluator, self).__init__(model, data_loader, conf)

self._tp_by_kw: Dict[str, int] = dict()
self._fp_by_kw: Dict[str, int] = dict()
self._tn_by_kw: Dict[str, int] = dict()
self._fn_by_kw: Dict[str, int] = dict()

self._kw_prefix = kw_prefix

def reset(self):
self._tp_by_kw: Dict[str, int] = dict()
self._fp_by_kw: Dict[str, int] = dict()
self._tn_by_kw: Dict[str, int] = dict()
self._fn_by_kw: Dict[str, int] = dict()

def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:

gt = self.data_loader.get_current_batch_samples_labels()

# looking for prefixes
for k in model_output.keys():
if k.startswith(self._kw_prefix):

if k not in self._tp_by_kw:
self._tp_by_kw[k] = 0
self._tn_by_kw[k] = 0
self._fp_by_kw[k] = 0
self._fn_by_kw[k] = 0

pred = (model_output[k].cpu().numpy() >= 0.5).astype(int)
self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1))
self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1))
self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0))
self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0))

def get_titles_of_evaluation_metrics(self) -> List[str]:
return [f'{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

def _get_values_of_evaluation_metrics(self, kw) -> List[str]:

tp = self._tp_by_kw[kw]
tn = self._tn_by_kw[kw]
fp = self._fp_by_kw[kw]
fn = self._fn_by_kw[kw]

p = tp + fn
n = tn + fp

if p + n > 0:
accuracy = (tp + tn) * 100.0 / (n + p)
else:
accuracy = -1

if p > 0:
sensitivity = 100.0 * tp / p
else:
sensitivity = -1

if n > 0:
specificity = 100.0 * tn / max(n, 1)
else:
specificity = -1

if sensitivity > -1 and specificity > -1:
avg_ss = 0.5 * (sensitivity + specificity)
elif sensitivity > -1:
avg_ss = sensitivity
else:
avg_ss = specificity

return ['%.2f' % accuracy, '%.2f' % sensitivity,
'%.2f' % specificity, '%.2f' % avg_ss]

def get_values_of_evaluation_metrics(self) -> List[str]:
return [
v
for k in self._tp_by_kw.keys()
for v in self._get_values_of_evaluation_metrics(k)]

@classmethod
def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']:
return partial(BinForetellerEvaluator, prefix_kw)

def print_evaluation_metrics(self, title: str) -> None:
""" For more readable printing! """

print(f'{title}:')
for k in self._tp_by_kw.keys():
print(
f'\t{k}:: ' +
', '.join([f'{m_name}: {m_val}'
for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'], self._get_values_of_evaluation_metrics(k))]))

+ 35
- 0
torchlap/model_evaluation/binary_tag_evaluator.py View File

@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING, Type

import numpy as np

from ..data.data_loader import DataLoader
from .binary_evaluator import BinaryEvaluator
from ..models.model import Model
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig
from ..data.content_loaders.content_loader import ContentLoader


class TagBinaryEvaluator(BinaryEvaluator):

def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig', tag: str, cl_type: Type[ContentLoader]):
super().__init__(model, data_loader, conf)
self._tag = tag
self._content_loader = data_loader.get_content_loader_of_interest(cl_type)

@property
def _prob_key(self) -> str:
return f'{self._tag}_positive_class_probability'

@property
def _gt_key(self) -> str:
return self._tag

@property
def _loss_key(self) -> str:
return f'{self._tag}_loss'

def _get_current_batch_gt(self) -> np.ndarray:
return self._content_loader.get_placeholder_name_to_fill_function_dict()[self._gt_key](
self.data_loader.get_current_batch_sample_indices(), None
)

+ 84
- 0
torchlap/model_evaluation/evaluator.py View File

@@ -0,0 +1,84 @@
""" The evaluator """
from abc import abstractmethod, ABC
from typing import List, TYPE_CHECKING, Dict

import torch
from tqdm import tqdm

from ..data.dataflow import DataFlow
if TYPE_CHECKING:
from ..models.model import Model
from ..data.data_loader import DataLoader
from ..configs.base_config import BaseConfig


class Evaluator(ABC):
""" The evaluator """

def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'):
""" Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to
iterate over them. """
self.model = model
self.data_loader = data_loader
self.conf = conf
self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device)

def evaluate(self, max_iters: int = None):
""" CAUTION: Resets the data_loader,
Iterates over the samples (as much as and how data_loader specifies),
calculates the overall evaluation requirements and prints them.
Title specified the title of the string for printing evaluation metrics.
classes_to_use specifies the labels of the samples to do the function on them,
None means all"""

# setting in dataflow
max_iters = float('inf') if max_iters is None else max_iters

with torch.no_grad():

# Running the model in evaluation mode
self.model.eval()

# initiating variables for running evaluations
self.reset()

with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar:
for iters, model_output in pbar:
self.update_summaries_based_on_model_output(model_output)
del model_output

pbar.set_description(self.get_evaluation_metrics())

if iters + 1 >= max_iters:
break

@abstractmethod
def update_summaries_based_on_model_output(
self, model_output: Dict[str, torch.Tensor]) -> None:
""" Updates the inner variables responsible of keeping a summary over data
based on the new outputs of the model, so when needed evaluation metrics
can be calculated based on these summaries. Mostly used in train and validation phase
or eval phase in which we only need evaluation metrics."""

@abstractmethod
def reset(self):
""" Resets the held information for a new evaluation round!"""

@abstractmethod
def get_titles_of_evaluation_metrics(self) -> List[str]:
""" Returns a list of strings containing the titles of the evaluation metrics"""

@abstractmethod
def get_values_of_evaluation_metrics(self) -> List[str]:
""" Returns a list of values for the calculated evaluation metrics,
converted to string with the desired format!"""

def get_evaluation_metrics(self) -> str:
return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in
zip(
self.get_titles_of_evaluation_metrics(),
self.get_values_of_evaluation_metrics())])

def print_evaluation_metrics(self, title: str) -> None:
""" Prints the values of the evaluation metrics"""
print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True)

+ 57
- 0
torchlap/model_evaluation/loss_evaluator.py View File

@@ -0,0 +1,57 @@
from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict
from collections import OrderedDict

import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


class LossEvaluator(Evaluator):

def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
super(LossEvaluator, self).__init__(model, data_loader, conf)

self.phase = conf.phase_type

# for summarized information
self.avg_loss: float = 0
self._avg_other_losses: OrdDict[str, float] = OrderedDict()
self.n_received_samples: int = 0
self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict()

def reset(self):
self.avg_loss = 0
self._avg_other_losses = OrderedDict()
self.n_received_samples = 0
self._n_received_samples_other_losses = OrderedDict()

def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:

n_batch = self.data_loader.get_current_batch_size()

new_n = self.n_received_samples + n_batch
self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
model_output.get('loss', 0.0) * (float(n_batch) / new_n)

for kw in model_output.keys():
if 'loss' in kw and kw != 'loss':

old_avg = self._avg_other_losses.get(kw, 0.0)
old_n = self._n_received_samples_other_losses.get(kw, 0)

new_n = old_n + n_batch

self._avg_other_losses[kw] = \
old_avg * (float(old_n) / new_n) + \
model_output[kw].detach().cpu() * (float(n_batch) / new_n)
self._n_received_samples_other_losses[kw] = new_n

def get_titles_of_evaluation_metrics(self) -> List[str]:
return ['Loss'] + list(self._avg_other_losses.keys())

def get_values_of_evaluation_metrics(self) -> List[str]:
return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()]
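The running mean above follows avg_new = avg_old * (n_old / n_new) + batch_loss * (n_batch / n_new); for example, an average of 0.50 over 100 samples combined with a 25-sample batch at 0.70 gives 0.50 * 0.8 + 0.70 * 0.2 = 0.54.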

+ 122
- 0
torchlap/model_evaluation/multiclass_evaluator.py View File

@@ -0,0 +1,122 @@
from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional
from functools import partial

import numpy as np
import torch
from sklearn.metrics import confusion_matrix

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig


def _precision_score(cm: np.ndarray):
numerator = np.diag(cm)
denominator = cm.sum(axis=0)
return (numerator[denominator != 0] / denominator[denominator != 0]).mean()


def _recall_score(cm: np.ndarray):
numerator = np.diag(cm)
denominator = cm.sum(axis=1)
return (numerator[denominator != 0] / denominator[denominator != 0]).mean()


def _accuracy_score(cm: np.ndarray):
return np.diag(cm).sum() / cm.sum()


class MulticlassEvaluator(Evaluator):

@classmethod
def standard_creator(cls, class_probability_key: str = 'categorical_probability',
include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']:
return partial(MulticlassEvaluator, class_probability_key=class_probability_key, include_top5=include_top5)

def __init__(self, model: Model,
data_loader: DataLoader,
conf: 'BaseConfig',
class_probability_key: str = 'categorical_probability',
include_top5: bool = False):
super().__init__(model, data_loader, conf)

# for summarized information
self._avg_loss: float = 0
self._n_received_samples: int = 0
self._trues = {}
self._falses = {}
self._top5_trues: int = 0
self._top5_falses: int = 0
self._num_iters = 0
self._predictions: Dict[int, np.ndarray] = {}
self._prediction_weights: Dict[int, np.ndarray] = {}
self._sample_details: Dict[str, Dict[str, Any]] = {}
self._cm: Optional[np.ndarray] = None
self._c = None
self._class_probability_key = class_probability_key
self._include_top5 = include_top5

def reset(self):
# for summarized information
self._avg_loss: float = 0
self._n_received_samples: int = 0
self._trues = {}
self._falses = {}
self._top5_trues: int = 0
self._top5_falses: int = 0
self._num_iters = 0
self._cm: Optional[np.ndarray] = None

def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:

assert self._class_probability_key in model_output, \
f"model's output dictionary must contain {self._class_probability_key}"

prediction: np.ndarray = model_output[self._class_probability_key]\
.cpu().numpy() # B C
c = prediction.shape[1]
ground_truth = self.data_loader.get_current_batch_samples_labels() # B

if self._include_top5:
top5_p = prediction.argsort(axis=1)[:, -5:]
trues = (top5_p == ground_truth[:, None]).any(axis=1).sum()
self._top5_trues += trues
self._top5_falses += len(ground_truth) - trues

prediction = prediction.argmax(axis=1) # B

for i in range(c):
nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i)))
nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i)))
if i not in self._trues:
self._trues[i] = 0
self._falses[i] = 0
self._trues[i] += nt
self._falses[i] += nf

self._cm = (0 if self._cm is None else self._cm)\
+ confusion_matrix(ground_truth, prediction,
labels=np.arange(c)).astype(float)
self._num_iters += 1

new_n = self._n_received_samples + len(ground_truth)
self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \
model_output.get('loss', 0.0) * (float(len(ground_truth)) / new_n)

self._n_received_samples = new_n

def get_titles_of_evaluation_metrics(self) -> List[str]:
return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \
(['Top5Acc'] if self._include_top5 else [])

def get_values_of_evaluation_metrics(self) -> List[str]:
accuracy = _accuracy_score(self._cm) * 100
precision = _precision_score(self._cm) * 100
recall = _recall_score(self._cm) * 100
top5_acc = self._top5_trues * 100.0 / (self._top5_trues + self._top5_falses)\
if self._include_top5 else None

return [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}'] + \
([f'{top5_acc:8.4f}'] if self._include_top5 else [])
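An illustrative check of the macro-averaged metrics above on a hand-made confusion matrix (rows are ground truth, columns are predictions):

# Hypothetical sanity check of the helper functions defined at the top of this file.
import numpy as np

cm = np.array([[8., 2.],
               [1., 9.]])
_accuracy_score(cm)    # 17 / 20 = 0.85
_precision_score(cm)   # mean(8/9, 9/11) ≈ 0.854
_recall_score(cm)      # mean(8/10, 9/10) = 0.85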

+ 63
- 0
torchlap/model_evaluation/multieval_evaluator.py View File

@@ -0,0 +1,63 @@
from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type
from collections import OrderedDict as ODict

import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator
if TYPE_CHECKING:
from ..configs.base_config import BaseConfig

# If you want your model to be analyzed from different points of view by different evaluators, you can use this holder!


class MultiEvaluatorEvaluator(Evaluator):

def __init__(
self, model: Model, data_loader: DataLoader, conf: 'BaseConfig',
evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
"""
evaluators_cls_by_name: The key is an arbitrary name to call the instance,
cls is the constructor
"""
super(MultiEvaluatorEvaluator, self).__init__(model, data_loader, conf)

# Making all evaluators!
self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict()
for eval_name, eval_cls in evaluators_cls_by_name.items():
self._evaluators_by_name[eval_name] = eval_cls(model, data_loader, conf)

def reset(self):
for evaluator in self._evaluators_by_name.values():
evaluator.reset()

def get_titles_of_evaluation_metrics(self) -> List[str]:
titles = []
for eval_name, evaluator in self._evaluators_by_name.items():
titles += [eval_name + '_' + t for t in evaluator.get_titles_of_evaluation_metrics()]
return titles

def get_values_of_evaluation_metrics(self) -> List[str]:
metrics = []
for evaluator in self._evaluators_by_name.values():
metrics += evaluator.get_values_of_evaluation_metrics()
return metrics

def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
for evaluator in self._evaluators_by_name.values():
evaluator.update_summaries_based_on_model_output(model_output)

@staticmethod
def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
""" For making a constructor, consistent with the known Evaluator"""
def typical_maker(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator:
return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name)
return typical_maker

def print_evaluation_metrics(self, title: str) -> None:
""" For more readable printing! """

print(f'{title}:')
for e_name, e_obj in self._evaluators_by_name.items():
e_obj.print_evaluation_metrics('\t' + e_name)
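A hedged composition sketch for the holder above (the display names are arbitrary; LossEvaluator and BinaryEvaluator are assumed to be imported from this package, and model, data_loader and conf to exist):

# Hypothetical combination of the loss and binary evaluators defined in this package.
from collections import OrderedDict as ODict

maker = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(ODict([
    ('loss', LossEvaluator),
    ('binary', BinaryEvaluator),
]))
evaluator = maker(model, data_loader, conf)
evaluator.evaluate()
evaluator.print_evaluation_metrics('val')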

+ 0
- 0
torchlap/models/__init__.py View File


+ 0
- 0
torchlap/models/celeba/__init__.py View File


+ 61
- 0
torchlap/models/celeba/lap_inception.py View File

@@ -0,0 +1,61 @@
import typing
from typing import Dict, Iterable
from collections import OrderedDict

import torch
from torch.nn import functional as F
import torchvision

from ..lap_inception import LAPInception


class CelebALAPInception(LAPInception):

def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory)
self._tag = tag

@property
def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
r""" Returns a dictionary from additional `kwargs` names to their optionality """
return OrderedDict({
f'{self._tag}': True,
})

def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
# x: B 3 224 224

if self.training:
out, aux = torchvision.models.Inception3.forward(self, x) # B 1
out, aux = out.flatten(), aux.flatten() # B
else:
out = torchvision.models.Inception3.forward(self, x).flatten() # B
aux = None

output = dict()
output['positive_class_probability'] = out

if f'{self._tag}' not in gts:
return output

gt = gts[f'{self._tag}']

r""" Class weighted loss """
loss = torch.mean(torch.stack(tuple(
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
)))

output['loss'] = loss

return output

""" INTERPRETATION """

@property
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
"""
:return: input module for interpretation
"""
return ['x']
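The class-weighted loss above computes BCE separately on the positive and negative subsets of the batch and averages the per-class terms, so the majority class cannot dominate: e.g. with 90 negatives at mean BCE 0.1 and 10 positives at mean BCE 0.9, the loss is (0.1 + 0.9) / 2 = 0.5 rather than the sample-weighted 0.18.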


+ 53
- 0
torchlap/models/celeba/lap_resnet.py View File

@@ -0,0 +1,53 @@
from collections import OrderedDict
from typing import Dict, List
import typing

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def pool_factory(channels, sigmoid_scale=1.0):
return LAP(channels,
sigmoid_scale=sigmoid_scale,
n_attention_heads=3,
discriminative_attention=True)


def adaptive_pool_factory(channels, sigmoid_scale=1.0):
return AdaptiveLAP(channels,
sigmoid_scale=sigmoid_scale,
n_attention_heads=3,
discriminative_attention=True)


class CelebALAPResNet18(LapResNet):

def __init__(self, tag: str,
pool_factory: PoolFactory = pool_factory,
adaptive_factory: PoolFactory = adaptive_pool_factory,
sigmoid_scale: float = 1.0):
super().__init__(LapBasicBlock, [2, 2, 2, 2],
pool_factory=pool_factory,
adaptive_factory=adaptive_factory,
sigmoid_scale=sigmoid_scale,
binary=True)
self._tag = tag

@property
def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
r""" Returns a dictionary from additional `kwargs` names to their optionality """
return OrderedDict({
f'{self._tag}': True,
})

def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
y = gts[f'{self._tag}']
return super().forward(x, y)

@property
def attention_layers(self) -> Dict[str, List[LAP]]:
res = super().attention_layers
res.pop('4_overall')
return res

+ 25
- 0
torchlap/models/celeba/org_resnet.py View File

@@ -0,0 +1,25 @@
from collections import OrderedDict
from typing import Dict
import typing

import torch

from ..tv_resnet import BasicBlock, ResNet


class CelebAORGResNet18(ResNet):

def __init__(self, tag: str):
super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)
self._tag = tag

@property
def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
r""" Returns a dictionary from additional `kwargs` names to their optionality """
return OrderedDict({
f'{self._tag}': True,
})

def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
y = gts[f'{self._tag}']
return super().forward(x, y)

+ 67
- 0
torchlap/models/celeba/single_tag_org_inception.py View File

@@ -0,0 +1,67 @@
import typing
from typing import Dict, List, Iterable
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from ..tv_inception import Inception3


class CelebAORGInception(Inception3):

def __init__(self, tag: str, aux_weight: float):
super().__init__(aux_weight, n_classes=1)
self._tag = tag

@property
def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
r""" Returns a dictionary from additional `kwargs` names to their optionality """
return OrderedDict({
f'{self._tag}': True,
})

def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:

if self.training:
out, aux = torchvision.models.Inception3.forward(self, x) # B 1
out, aux = out.flatten(), aux.flatten() # B
else:
out = torchvision.models.Inception3.forward(self, x).flatten() # B
aux = None

output = dict()
output['positive_class_probability'] = out

if f'{self._tag}' not in gts:
return output

gt = gts[f'{self._tag}']

r""" Class weighted loss """
loss = torch.mean(torch.stack(tuple(
F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
)))

output['loss'] = loss

return output

@property
def target_conv_layers(self) -> List[nn.Module]:
return [
self.Mixed_7c.branch1x1,
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
self.Mixed_7c.branch_pool
]

@property
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
return ['x']

def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
p = self.forward(*inputs, **kwargs)['positive_class_probability']
return torch.stack([1 - p, p], dim=1)

+ 0
- 0
torchlap/models/imagenet/__init__.py View File


+ 34
- 0
torchlap/models/imagenet/lap_resnet.py View File

@@ -0,0 +1,34 @@
from typing import Dict, List
import torch

from torchvision import transforms

from ...modules import LAP, AdaptiveLAP, ScaledSigmoid
from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory


def pool_factory(channels, sigmoid_scale=1.0):
return LAP(channels,
hidden_channels=[1000],
sigmoid_scale=sigmoid_scale,
hidden_activation=ScaledSigmoid.get_factory(sigmoid_scale))


class ImagenetLAPResNet50(LapResNet):

def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0):
super().__init__(LapBottleneck, [3, 4, 6, 3],
pool_factory=pool_factory,
sigmoid_scale=sigmoid_scale,
lap_positions=[4])
self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
x = self.normalize(x)
return super().forward(x, y)

@property
def attention_layers(self) -> Dict[str, List[LAP]]:
return {
'2_layer4': super().attention_layers['2_layer4'],
}

+ 556
- 0
torchlap/models/lap_inception.py View File

@@ -0,0 +1,556 @@
from typing import List, Optional, Callable, Any, Iterable, Dict

import torch
from torch import Tensor, nn
import torchvision

from ..modules.lap import LAP
from ..interpreting.attention_interpreter import AttentionInterpretableModel
from ..interpreting.interpretable import CamInterpretableModel
from ..interpreting.relcam.relprop import RPProvider, RelProp
from ..interpreting.relcam import modules as M


class InceptionB(nn.Module):

def __init__(
self,
in_channels: int,
pool_factory,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionB, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
#self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=1, padding=1)

self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
#self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=1, padding=1)

self.pool = pool_factory(384 + 96 + in_channels)

self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:

branch3x3 = self.branch3x3(x)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

outputs = [branch3x3, branch3x3dbl, x]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.pool(self.cat(outputs, 1))


@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
R = RPProvider.get(self.module.pool)(R, alpha=alpha)
branch3x3, branch3x3dbl, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha)

branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha)

return x1 + x2 + x3


class InceptionD(nn.Module):

def __init__(
self,
in_channels: int,
pool_factory,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionD, self).__init__()
if conv_block is None:
conv_block = BasicConv2d

self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
#self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=1, padding=1)
#self.branch3x3_2_stride = get_pooler(pooler_cls, 320, 2, pooler_hidden_layers)

self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
#self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=1, padding=1)
#self.branch7x7x3_4_stride = get_pooler(pooler_cls, 192, 2, pooler_hidden_layers)

self.pool = pool_factory(320 + 192 + in_channels)

self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:

branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3)

branch7x7x3 = self.branch7x7x3_1(x)
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

outputs = [branch3x3, branch7x7x3, x]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.pool(self.cat(outputs, 1))


@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
R = RPProvider.get(self.module.pool)(R, alpha=alpha)
branch3x3, branch7x7x3, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha)

branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha)
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha)
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha)
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha)

branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha)
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

return x1 + x2 + x3


def inception_b_maker(pool_factory):
return (
lambda in_channels, conv_block=None:
InceptionB(in_channels, pool_factory, conv_block))


def inception_d_maker(pool_factory):
return (
lambda in_channels, conv_block=None:
InceptionD(in_channels, pool_factory, conv_block))


class InceptionA(nn.Module):

def __init__(
self,
in_channels: int,
pool_features: int,
conv_block: Optional[Callable[..., nn.Module]] = None,
) -> None:
super(InceptionA, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:

branch1x1 = self.branch1x1(x)

branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

branch_pool = self.avg_pool(x)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat(outputs, 1)


@RPProvider.register(InceptionA)
class InceptionARelProp(RelProp[InceptionA]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
branch1x1, branch5x5, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

branch5x5 = RPProvider.get(self.module.branch5x5_2)(branch5x5, alpha=alpha)
x3 = RPProvider.get(self.module.branch5x5_1)(branch5x5, alpha=alpha)

x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

return x1 + x2 + x3 + x4


class InceptionC(nn.Module):

def __init__(
self,
in_channels: int,
channels_7x7: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionC, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

c7 = channels_7x7
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:

branch1x1 = self.branch1x1(x)

branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)

branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

branch_pool = self.avg_pool(x)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat(outputs, 1)


@RPProvider.register(InceptionC)
class InceptionCRelProp(RelProp[InceptionC]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
branch1x1, branch7x7, branch7x7dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_5)(branch7x7dbl, alpha=alpha)
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_4)(branch7x7dbl, alpha=alpha)
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_3)(branch7x7dbl, alpha=alpha)
branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_2)(branch7x7dbl, alpha=alpha)
x2 = RPProvider.get(self.module.branch7x7dbl_1)(branch7x7dbl, alpha=alpha)

branch7x7 = RPProvider.get(self.module.branch7x7_3)(branch7x7, alpha=alpha)
branch7x7 = RPProvider.get(self.module.branch7x7_2)(branch7x7, alpha=alpha)
x3 = RPProvider.get(self.module.branch7x7_1)(branch7x7, alpha=alpha)

x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

return x1 + x2 + x3 + x4


class InceptionE(nn.Module):

def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionE, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

self.cat1 = M.Cat()
self.cat2 = M.Cat()
self.cat3 = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:

branch1x1 = self.branch1x1(x)

branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = self.cat1(branch3x3, 1)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = self.cat2(branch3x3dbl, 1)

branch_pool = self.avg_pool(x)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat3(outputs, 1)


@RPProvider.register(InceptionE)
class InceptionERelProp(RelProp[InceptionE]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
branch1x1, branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat3)(R, alpha=alpha)

branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

branch3x3dbl_3a, branch3x3dbl_3b = RPProvider.get(self.module.cat2)(branch3x3dbl, alpha=alpha)
branch3x3dbl_1 = RPProvider.get(self.module.branch3x3dbl_3a)(branch3x3dbl_3a, alpha=alpha)
branch3x3dbl_2 = RPProvider.get(self.module.branch3x3dbl_3b)(branch3x3dbl_3b, alpha=alpha)
branch3x3dbl = branch3x3dbl_1 + branch3x3dbl_2
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

branch3x3_2a, branch3x3_2b = RPProvider.get(self.module.cat1)(branch3x3, alpha=alpha)
branch3x3_1 = RPProvider.get(self.module.branch3x3_2a)(branch3x3_2a, alpha=alpha)
branch3x3_2 = RPProvider.get(self.module.branch3x3_2b)(branch3x3_2b, alpha=alpha)
branch3x3 = branch3x3_1 + branch3x3_2
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

return x1 + x2 + x3 + x4


class InceptionAux(nn.Module):

def __init__(
self,
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionAux, self).__init__()
if conv_block is None:
conv_block = BasicConv2d

self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3)
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
self.conv1 = conv_block(128, 768, kernel_size=5, padding=2)
self.conv1.stddev = 0.01 # type: ignore[assignment]
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
self.flatten = nn.Flatten()
self.fc = nn.Linear(768, num_classes)
self.fc.stddev = 0.001 # type: ignore[assignment]

def forward(self, x: Tensor) -> Tensor:
# N x 768 x 17 x 17
x = self.avgpool1(x)
# N x 768 x 5 x 5
x = self.conv0(x)
# N x 128 x 5 x 5
x = self.conv1(x)
# N x 768 x 1 x 1
# Adaptive average pooling
x = self.avgpool2(x)
# N x 768 x 1 x 1
x = self.flatten(x)
# N x 768
x = self.fc(x)
# N x 1000
return x


@RPProvider.register(InceptionAux)
class InceptionAuxRelProp(RelProp[InceptionAux]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
R = RPProvider.get(self.module.fc)(R, alpha=alpha)
R = RPProvider.get(self.module.flatten)(R, alpha=alpha)
R = RPProvider.get(self.module.avgpool2)(R, alpha=alpha)
R = RPProvider.get(self.module.conv1)(R, alpha=alpha)
R = RPProvider.get(self.module.conv0)(R, alpha=alpha)
R = RPProvider.get(self.module.avgpool1)(R, alpha=alpha)
return R


class BasicConv2d(nn.Module):

def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
self.act = torch.nn.ReLU()

def forward(self, x: Tensor) -> Tensor:
x = self.conv(x)
x = self.bn(x)
return self.act(x)


@RPProvider.register(BasicConv2d)
class BasicConv2dRelProp(RelProp[BasicConv2d]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
R = RPProvider.get(self.module.act)(R, alpha=alpha)
R = RPProvider.get(self.module.bn)(R, alpha=alpha)
R = RPProvider.get(self.module.conv)(R, alpha=alpha)
return R


class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3):
def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory):
torchvision.models.Inception3.__init__(
self,
transform_input=False, init_weights=False,
inception_blocks=[
BasicConv2d, InceptionA, inception_b_maker(pool_factory),
InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux
])

self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
self.maxpool1 = pool_factory(64)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1,)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1)
self.maxpool2 = pool_factory(192)

if adaptive_pool_factory is not None:
self.avgpool = nn.Sequential(
adaptive_pool_factory(2048))

self.fc = nn.Sequential(
nn.Linear(2048, 1),
nn.Sigmoid()
)
self.AuxLogits.fc = nn.Sequential(
nn.Linear(768, 1),
nn.Sigmoid()
)
        self.aux_weight = aux_weight

    @property
def target_conv_layers(self) -> List[nn.Module]:
return [
self.Mixed_7c.branch1x1,
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
self.Mixed_7c.branch_pool
]

@property
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
return ['x']

def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
p = self.forward(*inputs, **kwargs)['positive_class_probability']
return torch.stack([1 - p, p], dim=1)

@property
def attention_layers(self) -> Dict[str, List[LAP]]:
attention_groups = {
'0_layer2': [self.maxpool2],
'1_layer6': [self.Mixed_6a.pool],
'2_layer7': [self.Mixed_7a.pool],
'3_avgpool': [self.avgpool[0]],
'4_all': [
self.maxpool2,
self.Mixed_6a.pool,
self.Mixed_7a.pool,
self.avgpool[0]],
}

return attention_groups


@RPProvider.register(LAPInception)
class Inception3Mo4RelProp(RelProp[LAPInception]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
if RPProvider.get(self.module.fc).Y.shape[1] == 1:
R = R[:, -1:]
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8

R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17

R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35

R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35

R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73

R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299
return R

+ 276
- 0
torchlap/models/lap_resnet.py View File

@@ -0,0 +1,276 @@
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
from typing_extensions import Protocol
import warnings

import torch
from torch import nn

from ..modules import LAP, AdaptiveLAP
from .tv_resnet import BasicBlock, Bottleneck, ResNet
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple
from ..interpreting.relcam import modules as M
from ..interpreting.attention_interpreter import AttentionInterpretableModel


class PoolFactory(Protocol):
def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ...


lap_factory: PoolFactory = \
lambda channels, sigmoid_scale=1.0: \
LAP(channels, sigmoid_scale=sigmoid_scale)


adaptive_lap_factory: PoolFactory = \
lambda channels, sigmoid_scale=1.0: \
AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale)


class LapBasicBlock(BasicBlock):

def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0):

super().__init__(inplanes, planes, stride=stride, downsample=downsample,
groups=groups, base_width=base_width, dilation=dilation,
norm_layer=norm_layer)

self.pool = None
if stride != 1:
assert downsample is not None

self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False)
self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False)
self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale)
self.relu3 = nn.ReLU()
self.cat = M.Cat()

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.conv1(x) # B P H W
out = self.bn1(out) # B P H W
out = self.relu(out) # B P H W

if self.downsample is not None:
x = self.downsample(x) # B P H W
x = self.relu2(x)

if self.pool is not None:
poolin = self.cat([out, x], 1) # B 2P H W
poolout: torch.Tensor = self.pool(poolin) # B 2P H/S W/S
out, x = poolout.chunk(2, dim=1) # B P H/S W/S

out = self.conv2(out) # B P H/S W/S
out = self.bn2(out) # B P H/S W/S

out = self.add([out, x]) # B P H/S W/S
out = self.relu3(out) # B P H/S W/S

return out


@RPProvider.register(LapBasicBlock)
class BlockRelProp(RelProp[LapBasicBlock]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

if self.module.pool is not None:
poolout = torch.cat([out, x], dim=1)
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

if self.module.downsample is not None:
x = RPProvider.get(self.module.relu2)(x, alpha=alpha)
x = RPProvider.get(self.module.downsample)(x, alpha=alpha)
out = RPProvider.get(self.module.relu)(out, alpha=alpha)
out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

return x + x1


class LapBottleneck(Bottleneck):

def __init__(self, pool_factory: PoolFactory,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
sigmoid_scale: float = 1.0) -> None:
super().__init__(inplanes, planes, stride=stride, downsample=downsample,
groups=groups, base_width=base_width, dilation=dilation,
norm_layer=norm_layer)

self.pool = None
if stride != 1:
assert downsample is not None
width = int(planes * (base_width / 64.)) * groups
self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False)
self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale)
self.cat = M.Cat()

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x # B I H W

out = self.conv1(x) # B P H W
out = self.bn1(out) # B P H W
out = self.relu(out) # B P H W

out = self.conv2(out) # B P H W
out = self.bn2(out) # B P H W
out = self.relu2(out) # B P H W

if self.downsample is not None:
x = self.downsample(x) # B 4P H W

if self.pool is not None:
poolin = self.cat([out, x], 1) # B 5P H W
poolout: torch.Tensor = self.pool(poolin) # B 5P H/S W/S
out, x = poolout.split( # B P H/S W/S
[out.shape[1], x.shape[1]], # B 4P H/S W/S
dim=1)

out = self.conv3(out) # B 4P H/S W/S
out = self.bn3(out) # B 4P H/S W/S

out = self.add([out, x]) # B P H/S W/S
out = self.relu3(out) # B P H/S W/S

return out


@RPProvider.register(LapBottleneck)
class BlockRP(RelProp[LapBottleneck]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

out = RPProvider.get(self.module.bn3)(out, alpha=alpha)
out = RPProvider.get(self.module.conv3)(out, alpha=alpha)

if self.module.pool is not None:
poolout = torch.cat([out, x], dim=1)
poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

if self.module.downsample is not None:
x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

out = RPProvider.get(self.module.relu2)(out, alpha=alpha)
out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

out = RPProvider.get(self.module.relu)(out, alpha=alpha)
out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

return x + x1


TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck)


class BlockFactory(Generic[TLapBlock]):

def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None:
self.expansion = block.expansion
self._block = block
self._pool_factory = pool_factory
self._sigmoid_scale = sigmoid_scale

def __call__(self, *args, **kwargs) -> TLapBlock:
return self._block(self._pool_factory, *args, **kwargs, sigmoid_scale=self._sigmoid_scale)


class Stride(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x[:, :, ::2, ::2]


@RPProvider.register(Stride)
class StrideRelProp(RelPropSimple[Stride]):
pass


class LapResNet(ResNet, AttentionInterpretableModel):

def __init__(self,
block: Type[TLapBlock],
layers: List[int],
pool_factory: PoolFactory = lap_factory,
lap_positions: List[int] = [2, 3, 4],
adaptive_factory: PoolFactory = None,
sigmoid_scale: float = 1.0,
binary: bool = False):
super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary)
if adaptive_factory is not None:
self.avgpool = nn.Sequential(
adaptive_factory(512, sigmoid_scale=sigmoid_scale),
)
for i in range(2, 5):
if i not in lap_positions:
warnings.warn(f'Putting stride on layer {i}')
getattr(self, 'layer{}'.format(i))[0].pool = Stride()

@property
def attention_layers(self) -> Dict[str, List[LAP]]:
"""
List of attention groups
"""
attention_groups = {
'0_layer2': [self.layer2[0].pool],
'1_layer3': [self.layer3[0].pool],
'2_layer4': [self.layer4[0].pool],
'3_avgpool': [self.avgpool[0]],
'4_overall': [
self.layer2[0].pool,
self.layer3[0].pool,
self.layer4[0].pool,
self.avgpool[0],
]
}

assert all(
all(
isinstance(layer, LAP)
for layer in attention_group)
for attention_group in attention_groups.values()), \
"Only LAP is supported for this interpretation method"

return attention_groups


def lap_resnet18(pool_factory: PoolFactory = lap_factory,
adaptive_factory: PoolFactory = None,
sigmoid_scale: float = 1.0,
lap_positions: List[int] = [2, 3, 4],
binary: bool = False) -> LapResNet:
"""Constructs a LAP-ResNet-18 model.
"""
return LapResNet(LapBasicBlock, [2, 2, 2, 2], pool_factory=pool_factory,
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
lap_positions=lap_positions, binary=binary)


def lap_resnet50(pool_factory: PoolFactory = lap_factory,
adaptive_factory: PoolFactory = None,
sigmoid_scale: float = 1.0,
lap_positions: List[int] = [2, 3, 4],
binary: bool = False) -> LapResNet:
"""Constructs a LAP-ResNet-50 model.
"""
return LapResNet(LapBottleneck, [3, 4, 6, 3], pool_factory=pool_factory,
adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
lap_positions=lap_positions, binary=binary)
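
A minimal usage sketch (not part of the committed files) for the factories and builders above, assuming the torchlap.models.lap_resnet import path and that AdaptiveLAP is a LAP subclass, as its use in attention_layers implies:

# Hypothetical usage sketch for torchlap/models/lap_resnet.py; the import path is an
# assumption, and AdaptiveLAP is assumed to be a LAP subclass (as attention_layers implies).
from torchlap.models.lap_resnet import adaptive_lap_factory, lap_resnet18

# LAP pooling on layers 2-4 (the default lap_positions); positions left out of
# lap_positions are replaced by the plain Stride module, with a warning.
model = lap_resnet18(
    adaptive_factory=adaptive_lap_factory,
    sigmoid_scale=1.0,
    binary=True,
)

# Attention groups exposed to the attention-based interpreters.
for name, layers in model.attention_layers.items():
    print(name, [type(layer).__name__ for layer in layers])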

+ 147
- 0
torchlap/models/model.py View File

@@ -0,0 +1,147 @@
"""
The Model base class
"""
from collections import OrderedDict
import inspect
import typing
from typing import Any, Dict, List, Tuple
import re
from abc import ABC

import torch
from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d


ModelIO = Dict[str, torch.Tensor]


class Model(torch.nn.Module, ABC):
"""
The Model base class
"""
_frozen_bns_list: List[torch.nn.Module] = []

def init_weights_from_other_model(self, pretrained_model_dir=None):
""" Initializes the weights from another model with the given address,
Default is to set all the parameters with the same name and shape,
but can be rewritten in subclasses.
The model to load is received via the function get_other_model_to_load_from.
The default value is the same model!"""

if pretrained_model_dir is None:
print('The model was not preinitialized.', flush=True)
return
other_model_dict = torch.load(pretrained_model_dir)
own_state = self.state_dict()

cnt = 0
skip_cnt = 0

for name, param in other_model_dict.items():

if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data

if name in own_state and own_state[name].data.shape == param.shape:
own_state[name].copy_(param)
cnt += 1

else:
skip_cnt += 1

        print(f'{cnt} out of {len(own_state)} parameters were loaded successfully; {skip_cnt} of the given parameters were skipped.', flush=True)

def freeze_parameters(self, regex_list: List[str] = None) -> None:
"""
        An auxiliary method for freezing parameters of the model by regular expressions.
        :param regex_list: All parameters whose names fully match at least one of the given regexes will be frozen.
            None means no parameter will be frozen.
        :return: None
"""

        # One option for batch norms is to register a pre-forward hook that calls
        # eval() every time before they run. That causes a problem, though: we could
        # no longer change the set of frozen batch norms dynamically during training
        # (e.g. for layer-wise unfreezing). So it is better to keep their list in a
        # dynamically set attribute rather than in an __init__, which would also
        # cause problems with multiple inheritance.

if regex_list is None:
            print('No parameter was frozen!', flush=True)
return

regex_list = [re.compile(the_pattern) for the_pattern in regex_list]

def check_freeze_conditions(param_name):
for the_regex in regex_list:
if the_regex.fullmatch(param_name) is not None:
return True
return False

frozen_cnt = 0
total_cnt = 0

frozen_modules_names = dict()

for n, p in self.named_parameters():
total_cnt += 1
if check_freeze_conditions(n):
p.requires_grad = False
frozen_cnt += 1
frozen_modules_names[n[:n.rfind('.')]] = True

self._frozen_bns_list = []

for module_name, module in self.named_modules():
if (module_name in frozen_modules_names) and \
(isinstance(module, BatchNorm2d) or
isinstance(module, BatchNorm1d) or
isinstance(module, BatchNorm3d)):
self._frozen_bns_list.append(module)

print('********************* NOTICE *********************')
print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt))
        print(
            '%d BatchNorm layers will be frozen by being kept in eval mode, IF YOU HAVE NOT OVERRIDDEN THIS IN YOUR MAIN MODEL!'
            % len(self._frozen_bns_list))
print('***************************************************', flush=True)

def train(self, mode: bool = True):
super(Model, self).train(mode)
for bn in self._frozen_bns_list:
bn.train(False)
return self

@property
def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
r""" Returns a dictionary from additional `kwargs` names to their optionality """
return OrderedDict({})

def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]:
"""Returns the names of the keyword arguments required in the forward function and the list of the default values for the ones having them.
The second list might be shorter than the first and it represents the default values from the end of the arguments.

Returns:
Tuple[List[str], List[Any], Dict[str, Any]]: The list of args names and default values for the ones having them (from some point to the end) + the dictionary of kwargs
"""

model_argspec = inspect.getfullargspec(self.forward)

model_forward_args = list(model_argspec.args)
# skipping self arg
model_forward_args = model_forward_args[1:]

args_default_values = model_argspec.defaults
if args_default_values is None:
args_default_values = []
else:
args_default_values = list(args_default_values)

additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {}

return model_forward_args, args_default_values, additional_kwargs
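
A small hypothetical example of how a Model subclass could use freeze_parameters and get_forward_required_kws_kwargs_and_defaults; TinyModel and the import path below are illustrative assumptions, not repository code:

# Hypothetical example for torchlap/models/model.py; TinyModel and the import path
# are illustrative assumptions, not repository code.
import torch
from torch import nn

from torchlap.models.model import Model


class TinyModel(Model):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None):
        h = self.bn(self.conv(x)).mean(dim=(2, 3))
        p = torch.sigmoid(self.fc(h)).flatten()
        out = {'positive_class_probability': p}
        if y is not None:
            out['loss'] = nn.functional.binary_cross_entropy(p, y)
        return out


model = TinyModel()

# Freeze the convolutional stem and its batch norm; matching BatchNorm layers are
# also kept in eval mode by the overridden train().
model.freeze_parameters([r'conv\..*', r'bn\..*'])

# Introspect the forward signature that a data loader would have to satisfy.
args, defaults, extra = model.get_forward_required_kws_kwargs_and_defaults()
print(args, defaults, extra)  # ['x', 'y'] [None] {}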

+ 0
- 0
torchlap/models/rsna/__init__.py View File


+ 40
- 0
torchlap/models/rsna/lap_inception.py View File

@@ -0,0 +1,40 @@
from typing import Dict

import torch
from torch.nn import functional as F
import torchvision

from ..lap_inception import LAPInception


class RSNALAPInception(LAPInception):

def __init__(self, aux_weight: float, pool_factory, adaptive_pool_factory):
        super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:

x = x.repeat_interleave(3, dim=1) # B 3 224 224
if self.training:
out, aux = torchvision.models.Inception3.forward(self, x) # B 1
out, aux = out.flatten(), aux.flatten() # B
else:
out = torchvision.models.Inception3.forward(self, x).flatten() # B
aux = None

if y is not None:
main_loss = F.binary_cross_entropy(out, y)
if aux is not None:
aux_loss = F.binary_cross_entropy(aux, y)
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
else:
loss = main_loss

return {
'positive_class_probability': out,
'loss': loss
}

return {
'positive_class_probability': out
}

+ 39
- 0
torchlap/models/rsna/lap_resnet.py View File

@@ -0,0 +1,39 @@
from typing import Dict

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def get_pool_factory(discriminative_attention=True):
def pool_factory(channels, sigmoid_scale=1.0):
return LAP(channels,
hidden_channels=[8],
sigmoid_scale=sigmoid_scale,
discriminative_attention=discriminative_attention)
return pool_factory


def get_adaptive_pool_factory(discriminative_attention=True):
def adaptive_pool_factory(channels, sigmoid_scale=1.0):
return AdaptiveLAP(channels,
sigmoid_scale=sigmoid_scale,
discriminative_attention=discriminative_attention)
return adaptive_pool_factory


class RSNALAPResNet18(LapResNet):

def __init__(self, pool_factory: PoolFactory = get_pool_factory(),
adaptive_factory: PoolFactory = get_adaptive_pool_factory(),
sigmoid_scale: float = 1.0):
super().__init__(LapBasicBlock, [2, 2, 2, 2],
pool_factory=pool_factory,
adaptive_factory=adaptive_factory,
sigmoid_scale=sigmoid_scale,
                         binary=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224
return super().forward(x, y)
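
A hedged, construction-only sketch of how the two factory helpers feed RSNALAPResNet18 (import path assumed; no forward pass is run, since the LAP internals live in torchlap/modules):

# Hypothetical sketch for torchlap/models/rsna/lap_resnet.py; the import path is
# an assumption based on the directory layout shown in this diff.
from torchlap.models.rsna.lap_resnet import (
    RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory)

# Both pooling variants with the discriminative-attention head switched off.
model = RSNALAPResNet18(
    pool_factory=get_pool_factory(discriminative_attention=False),
    adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
    sigmoid_scale=2.0,
)

# The LAP placed by the block factory on the first block of layer2.
print(type(model.layer2[0].pool).__name__)  # LAP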

+ 39
- 0
torchlap/models/rsna/org_inception.py View File

@@ -0,0 +1,39 @@
from typing import Dict

import torch
from torch.nn import functional as F
import torchvision

from ..tv_inception import Inception3


class RSNAORGInception(Inception3):

def __init__(self, aux_weight: float):
        super().__init__(aux_weight, n_classes=1)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224
if self.training:
out, aux = torchvision.models.Inception3.forward(self, x) # B 1
out, aux = out.flatten(), aux.flatten() # B
else:
out = torchvision.models.Inception3.forward(self, x).flatten() # B
aux = None

if y is not None:
main_loss = F.binary_cross_entropy(out, y)
if aux is not None:
aux_loss = F.binary_cross_entropy(aux, y)
loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
else:
loss = main_loss

return {
'positive_class_probability': out,
'loss': loss
}

return {
'positive_class_probability': out
}

+ 15
- 0
torchlap/models/rsna/org_resnet.py View File

@@ -0,0 +1,15 @@
from typing import Dict

import torch

from ..tv_resnet import BasicBlock, ResNet


class RSNAORGResNet18(ResNet):

def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
x = x.repeat_interleave(3, dim=1) # B 3 224 224
return super().forward(x, y)

+ 218
- 0
torchlap/models/tv_inception.py View File

@@ -0,0 +1,218 @@
from typing import Iterable, Optional, Callable, List

import torch
from torch import nn, Tensor
import torchvision

from ..interpreting.interpretable import CamInterpretableModel
from ..interpreting.relcam.relprop import RPProvider, RelProp
from ..interpreting.relcam import modules as M
from .lap_inception import (
InceptionA,
InceptionC,
InceptionE,
BasicConv2d,
InceptionAux as BaseInceptionAux,
)


class InceptionB(nn.Module):

def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionB, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:
branch3x3 = self.branch3x3(x)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

branch_pool = self.maxpool(x)

outputs = [branch3x3, branch3x3dbl, branch_pool]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat(outputs, 1)


@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha)

branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha)

return x1 + x2 + x3


class InceptionD(nn.Module):

def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionD, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
self.cat = M.Cat()

def _forward(self, x: Tensor) -> List[Tensor]:
branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3)

branch7x7x3 = self.branch7x7x3_1(x)
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

branch_pool = self.maxpool(x)
outputs = [branch3x3, branch7x7x3, branch_pool]
return outputs

def forward(self, x: Tensor) -> Tensor:
outputs = self._forward(x)
return self.cat(outputs, 1)


@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
branch3x3, branch7x7x3, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha)

branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha)
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha)
branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha)
x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha)

branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha)
x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

return x1 + x2 + x3


class InceptionAux(BaseInceptionAux):

def __init__(
self,
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__(in_channels, num_classes, conv_block=conv_block)
if conv_block is None:
conv_block = BasicConv2d
self.conv1 = conv_block(128, 768, kernel_size=5)
self.conv1.stddev = 0.01 # type: ignore[assignment]


class Inception3(CamInterpretableModel, torchvision.models.Inception3):

def __init__(self, aux_weight: float, n_classes=1):
torchvision.models.Inception3.__init__(self,
transform_input=False, init_weights=False,
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
])

self.fc = nn.Sequential(
nn.Linear(2048, n_classes),
nn.Sigmoid()
)

self.AuxLogits.fc = nn.Sequential(
nn.Linear(768, n_classes),
nn.Sigmoid()
)
self.aux_weight = aux_weight

@property
def target_conv_layers(self) -> List[nn.Module]:
return [
self.Mixed_7c.branch1x1,
self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
self.Mixed_7c.branch_pool
]

@property
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
return ['x']

def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
p = self.forward(*inputs, **kwargs)['positive_class_probability']
return torch.stack([1 - p, p], dim=1)



@RPProvider.register(Inception3)
class Inception3Mo4RelProp(RelProp[Inception3]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
if RPProvider.get(self.module.fc).Y.shape[1] == 1:
R = R[:, -1:]
R = RPProvider.get(self.module.fc)(R, alpha=alpha) # B 2048
R = R.reshape_as(RPProvider.get(self.module.dropout).Y) # B 2048 1 1
R = RPProvider.get(self.module.dropout)(R, alpha=alpha) # B 2048 1 1
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha) # B 2048 8 8

R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha) # B 2048 8 8
R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha) # B 1280 8 8
R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha) # B 768 17 17

R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha) # B 768 17 17
R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha) # B 288 35 35

R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha) # B 288 35 35
R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha) # B 256 35 35
R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha) # B 192 35 35

R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha) # B 192 71 71
R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha) # B 80 73 73
R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha) # B 64 73 73

R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha) # B 64 147 147
R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha) # B 32 147 147
R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha) # B 32 149 149
R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha) # B 3 299 299
return R

+ 512
- 0
torchlap/models/tv_resnet.py View File

@@ -0,0 +1,512 @@
from typing import Dict, Iterable, Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from ..interpreting.interpretable import CamInterpretableModel
from ..interpreting.relcam.relprop import RPProvider, RelProp
from ..interpreting.relcam import modules as M


__all__ = [
"ResNet",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"wide_resnet50_2",
"wide_resnet101_2",
]


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
expansion: int = 1

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU()
self.relu2 = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride

self.add = M.Add()

def forward(self, x: Tensor) -> Tensor:
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None:
x = self.downsample(x)

out = self.add([out, x])
out = self.relu2(out)

return out


@RPProvider.register(BasicBlock)
class BasicBlockRelProp(RelProp[BasicBlock]):

def rel(self, R, alpha):
out = RPProvider.get(self.module.relu2)(R, alpha)
out, x = RPProvider.get(self.module.add)(out, alpha)

if self.module.downsample is not None:
x = RPProvider.get(self.module.downsample)(x, alpha)

out = RPProvider.get(self.module.bn2)(out, alpha)
out = RPProvider.get(self.module.conv2)(out, alpha)

out = RPProvider.get(self.module.relu)(out, alpha)
out = RPProvider.get(self.module.bn1)(out, alpha)
x1 = RPProvider.get(self.module.conv1)(out, alpha)

return x + x1


class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

expansion: int = 4

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()
self.downsample = downsample
self.stride = stride

self.add = M.Add()

def forward(self, x: Tensor) -> Tensor:
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu2(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
x = self.downsample(x)

out = self.add([out, x])
out = self.relu3(out)

return out


@RPProvider.register(Bottleneck)
class BottleneckRelProp(RelProp[Bottleneck]):

def rel(self, R, alpha):
out = RPProvider.get(self.module.relu3)(R, alpha)
out, x = RPProvider.get(self.module.add)(out, alpha)

        if self.module.downsample is not None:
x = RPProvider.get(self.module.downsample)(x, alpha)

out = RPProvider.get(self.module.bn3)(out, alpha)
out = RPProvider.get(self.module.conv3)(out, alpha)

out = RPProvider.get(self.module.relu2)(out, alpha)
out = RPProvider.get(self.module.bn2)(out, alpha)
out = RPProvider.get(self.module.conv2)(out, alpha)

out = RPProvider.get(self.module.relu)(out, alpha)
out = RPProvider.get(self.module.bn1)(out, alpha)
x1 = RPProvider.get(self.module.conv1)(out, alpha)

return x + x1


class ResNet(CamInterpretableModel):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
binary: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer

self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
f"or a 3-element tuple, got {replace_stride_with_dilation}"
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

self._binary = binary
if binary:
self.fc = nn.Sequential(
nn.Linear(512 * block.expansion, 1),
nn.Sigmoid(),
nn.Flatten(0)
)
else:
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]

def _make_layer(
self,
block: Type[Union[BasicBlock, Bottleneck]],
planes: int,
blocks: int,
stride: int = 1,
dilate: bool = False,
) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = []
layers.append(
block(
self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
)
)

return nn.Sequential(*layers)

def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x

    def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Dict[str, Tensor]:
        p = self._forward_impl(x)
        if self._binary:
            output = dict(positive_class_probability=p)
            if y is not None:
                output['loss'] = F.binary_cross_entropy(p, y)
            return output
        output = dict(categorical_probability=p.softmax(dim=1))
        if y is not None:
            output['loss'] = F.cross_entropy(p, y.long())
        return output

@property
def target_conv_layers(self) -> List[nn.Module]:
"""
Returns:
            The convolutional layers to be interpreted. The result of
            the interpretation will be the sum of the Grad-CAM maps of these layers.
"""
return [self.layer4]

def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
"""
        A method to get the probabilities assigned to all classes by the model's forward,
        with shape (B, C), where B is the batch size and C is the number of classes.

Args:
*inputs: Inputs to the model
**kwargs: Additional arguments

Returns:
Tensor of categorical probabilities
"""
if self._binary:
p = self.forward(*inputs, **kwargs)['positive_class_probability']
return torch.stack([1 - p, p], dim=1)
return self.forward(*inputs, **kwargs)['categorical_probability']

@property
def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
"""
Returns:
Input module for interpretation
"""
return ['x']


@RPProvider.register(ResNet)
class ResNetRelProp(RelProp[ResNet]):

def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
if RPProvider.get(self.module.fc).Y.ndim == 1:
R = R[:, -1]
R = RPProvider.get(self.module.fc)(R, alpha=alpha)
R = R.reshape_as(RPProvider.get(self.module.avgpool).Y)
R = RPProvider.get(self.module.avgpool)(R, alpha=alpha)

R = RPProvider.get(self.module.layer4)(R, alpha=alpha)
R = RPProvider.get(self.module.layer3)(R, alpha=alpha)
R = RPProvider.get(self.module.layer2)(R, alpha=alpha)
R = RPProvider.get(self.module.layer1)(R, alpha=alpha)

R = RPProvider.get(self.module.maxpool)(R, alpha=alpha)
R = RPProvider.get(self.module.relu)(R, alpha=alpha)
R = RPProvider.get(self.module.bn1)(R, alpha=alpha)
R = RPProvider.get(self.module.conv1)(R, alpha=alpha)
return R


def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any,
) -> ResNet:
model = ResNet(block, layers, **kwargs)
#if pretrained:
#state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
#model.load_state_dict(state_dict)
return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 4
return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["width_per_group"] = 64 * 2
return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["width_per_group"] = 64 * 2
return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)

+ 5
- 0
torchlap/modules/__init__.py View File

@@ -0,0 +1,5 @@
from .gaussianmax import GaussianMax2d
from .activated_conv import ActivatedConv2d
from .scaled_sigmoid import ScaledSigmoid
from .lap import LAP
from .adaptive_lap import AdaptiveLAP

+ 0
- 0
torchlap/modules/_common_types.py View File


Some files were not shown because too many files changed in this diff
