123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- from typing import Dict, TYPE_CHECKING
- from os import makedirs, path
- import traceback
- from time import time
- import warnings
- import imageio
-
- import torch
- import numpy as np
-
- from ..data.data_loader import DataLoader
- from .interpretable import InterpretableModel
- from .interpreter_maker import create_interpreter
- from .interpreter import Interpreter
- from .interpretation_dataflow import InterpretationDataFlow
- from .utils import overlay_interpretation
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
-
-
class InterpretingRunner:
    """Runs the configured interpretation method over one or more sample groups
    and writes the resulting interpretations (raw arrays and/or PNG overlays)
    into a per-group report directory."""

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        """
        :param conf: experiment configuration; supplies the interpretation
            settings, the sample groups, the device and the output-path helpers.
        :param model: the model whose decisions are interpreted.
        """
        self.conf = conf
        self.model = model

    def interpret(self):
        """Interprets every sample group listed (comma separated) in
        ``conf.samples_dir``.

        For each group, the full config is written to ``conf_info.txt`` inside
        the group's report directory, then the group's test samples are
        interpreted and saved there. Groups are processed independently: a
        failure in one group is printed together with its traceback and the
        runner continues with the next group.
        """
        interpreter = create_interpreter(
            self.conf.interpretation_method,
            self.model)

        for test_group_info in self.conf.samples_dir.split(','):
            try:
                print('>> Finding interpretations for %s' % test_group_info, flush=True)

                t1 = time()

                labels_to_use = self.conf.mapped_labels_to_use
                if labels_to_use is None:
                    labels_to_use = 'All'
                else:
                    labels_to_use = ','.join(str(x) for x in labels_to_use)

                # The sub-directory name encodes the method and its settings so that
                # runs with different settings do not overwrite each other.
                report_dir = self.conf.get_sample_group_specific_report_dir(
                    test_group_info,
                    extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}')

                makedirs(report_dir, exist_ok=True)
                # Writing the whole config; `with` closes the file even if the write fails.
                with open(report_dir + '/conf_info.txt', 'w') as conf_file:
                    conf_file.write(str(self.conf) + '\n')

                test_data_loader = DataLoader(self.conf, test_group_info, 'test')
                self._interpret(report_dir, test_data_loader, interpreter)

                print('Interpretations were saved in %s' % report_dir)
                print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True)

            except Exception:
                # Deliberate best-effort boundary: report the failing group and move on.
                print('Problem in %s' % test_group_info, flush=True)
                print(traceback.format_exc(), flush=True)

    def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None:
        """Iterates over the samples (as many and in the order ``data_loader``
        provides), computes their interpretations with ``interpreter`` and saves
        the results in ``save_dir``.

        When ``conf.n_interpretation_samples`` is not None, at most that many
        samples are saved and iteration stops once the limit is reached.

        :param save_dir: directory the interpretations are written to (created if missing).
        :param data_loader: source of the samples to interpret.
        :param interpreter: the interpretation method to apply.
        :raises TypeError: if the model does not implement InterpretableModel.
        """
        # Fail fast, before touching the file system.
        if not isinstance(self.model, InterpretableModel):
            raise TypeError('Model has not implemented the requirements of the InterpretableModel')

        makedirs(save_dir, exist_ok=True)

        # Running the model in evaluation mode
        self.model.eval()

        label_for_interpreter = self.conf.class_label_for_interpretation
        # Ground-truth labels are passed to the interpreter only when no fixed
        # class label is configured and `interpret_predictions_vs_gt` is off.
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        n_interpreted_samples = 0
        max_n_samples = self.conf.n_interpretation_samples

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():
                # Only the first placeholder to be interpreted is saved/overlaid.
                batch_input = model_input[self.model.ordered_placeholder_names_to_be_interpreted[0]]
                n_batch = batch_input.shape[0]

                n_samples_to_save = n_batch
                if max_n_samples is not None:
                    n_samples_to_save = min(max_n_samples - n_interpreted_samples, n_batch)

                self._save_interpretations_of_batch(
                    batch_input, interpretation_output, save_dir, data_loader,
                    n_samples_to_save, interpreter)
                # Drop our reference to the batch's interpretations before the next batch.
                del interpretation_output

                # if the number of interpreted samples has reached the limit, break
                n_interpreted_samples += n_batch
                if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
                    break

    def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor],
                                       save_dir: str, data_loader: DataLoader,
                                       n_samples_to_save: int, interpreter: Interpreter) -> None:
        """
        Saves the interpretations of the first ``n_samples_to_save`` samples of
        the current batch under ``save_dir``, in files named after the samples.
        The behaviour can be overridden in children if extra outputs are required.

        For every saved sample and every interpreter output ``name`` it writes
        ``np.save(<sample_path>_<name>)`` (NumPy appends ``.npy``) unless
        ``conf.skip_raw`` is set, and ``<sample_path>_<name>_overlay.png`` unless
        ``conf.skip_overlay`` is set.

        :param model_input: the interpreted input batch; it is averaged over
            dim 1 to obtain grayscale images for the overlays.
        :param interpreter_output: interpretation tensors keyed by name, indexed
            by sample along dim 0.
        :param save_dir: root directory of the saved interpretations.
        :param data_loader: provides the names of the samples of the current batch.
        :param n_samples_to_save: how many samples, from the start of the batch, to save.
        :param interpreter: used for dynamic thresholding when ``conf.dynamic_threshold`` is set.
        :return: None
        """
        n_to_save = min(len(model_input), n_samples_to_save)
        if n_to_save <= 0:
            # Nothing to save: do not create directories for unsaved samples.
            return

        batch_samples_names = data_loader.get_current_batch_samples_names()

        # '../' is stripped so sample names cannot point outside save_dir.
        save_paths = [
            self.conf.get_save_dir_for_sample(save_dir, batch_samples_names[bi].replace('../', ''))
            for bi in range(n_to_save)]
        for sample_path in save_paths:
            makedirs(path.dirname(sample_path), exist_ok=True)

        # Only the saved prefix of the batch is moved to the CPU.
        outputs = {name: output[:n_to_save].cpu().numpy() for name, output in interpreter_output.items()}
        # Make inputs grayscale by averaging over dim 1.
        gray_inputs = model_input[:n_to_save].mean(dim=1).cpu().numpy()

        if self.conf.skip_raw:
            warnings.warn('Skipping raw interpretations')
        if self.conf.skip_overlay:
            warnings.warn('Skipping overlayed interpretations')

        for bi in range(n_to_save):
            for name, batch_output in outputs.items():
                sample_output = batch_output[bi]
                filename = f'{save_paths[bi]}_{name}'

                if self.conf.dynamic_threshold:
                    sample_output = interpreter.dynamic_threshold(sample_output)

                if not self.conf.skip_raw:
                    np.save(filename, sample_output)

                if not self.conf.skip_overlay:
                    # At most the first 3 entries along dim 0 are used for the overlay.
                    if sample_output.shape[0] > 3:
                        sample_output = sample_output[:3]
                    overlayed = overlay_interpretation(gray_inputs[bi][np.newaxis, ...], sample_output, self.conf)
                    imageio.imwrite(f'{filename}_overlay.png', overlayed)
|