from typing import Dict, TYPE_CHECKING
from os import makedirs, path
import traceback
from time import time
import warnings

import imageio
import torch
import numpy as np

from ..data.data_loader import DataLoader
from .interpretable import InterpretableModel
from .interpreter_maker import create_interpreter
from .interpreter import Interpreter
from .interpretation_dataflow import InterpretationDataFlow
from .utils import overlay_interpretation

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class InterpretingRunner:
    """Runs the configured interpretation method over every sample group
    listed in ``conf.samples_dir`` and saves the results (raw ``.npy`` arrays
    and/or overlay PNGs) under a per-group report directory."""

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        self.conf = conf
        self.model = model

    def interpret(self):
        """Interpret every comma-separated sample group in ``conf.samples_dir``.

        For each group a report directory is derived from the interpretation
        method, the labels in use and the cut/global thresholds; the full
        config is dumped to ``conf_info.txt`` there and the interpretations
        are written by :meth:`_interpret`. A failure in one group is printed
        (with its traceback) and the loop moves on to the next group.
        """
        interpreter = create_interpreter(
            self.conf.interpretation_method, self.model)

        for test_group_info in self.conf.samples_dir.split(','):
            try:
                print('>> Finding interpretations for %s' % test_group_info, flush=True)
                t1 = time()

                labels_to_use = self.conf.mapped_labels_to_use
                if labels_to_use is None:
                    labels_to_use = 'All'
                else:
                    labels_to_use = ','.join(str(x) for x in labels_to_use)

                report_dir = self.conf.get_sample_group_specific_report_dir(
                    test_group_info,
                    extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}')
                makedirs(report_dir, exist_ok=True)

                # Dump the whole config next to the results; `with` guarantees
                # the file handle is closed even if the write fails.
                with open(report_dir + '/conf_info.txt', 'w') as conf_file:
                    conf_file.write(str(self.conf) + '\n')

                test_data_loader = DataLoader(self.conf, test_group_info, 'test')
                self._interpret(report_dir, test_data_loader, interpreter)

                print('Interpretations were saved in %s' % report_dir)
                print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True)
            except Exception:
                # Deliberate best-effort: report the failing group and continue
                # with the remaining ones instead of aborting the whole run.
                print('Problem in %s' % test_group_info, flush=True)
                print(traceback.format_exc(), flush=True)

    def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None:
        """Iterate over the samples provided by ``data_loader``, interpret them
        with ``interpreter`` and save the results under ``save_dir``.

        At most ``conf.n_interpretation_samples`` samples are saved; ``None``
        means no limit.

        :raises Exception: if the model is not an ``InterpretableModel``.
        """
        makedirs(save_dir, exist_ok=True)

        if not isinstance(self.model, InterpretableModel):
            raise Exception('Model has not implemented the requirements of the InterpretableModel')

        # Running the model in evaluation mode
        self.model.eval()

        # When no fixed class label is configured and predictions are not being
        # compared against the ground truth, the ground truth is used as the label.
        label_for_interpreter = self.conf.class_label_for_interpretation
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

        # NOTE(review): the meaning of the positional `False` argument is defined
        # by InterpretationDataFlow and is not visible here — confirm there.
        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        n_interpreted_samples = 0
        max_n_samples = self.conf.n_interpretation_samples

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():
                # Only the first placeholder's input is used for the overlays.
                interpreted_input = model_input[
                    self.model.ordered_placeholder_names_to_be_interpreted[0]]
                n_batch = interpreted_input.shape[0]

                n_samples_to_save = n_batch
                if max_n_samples is not None:
                    n_samples_to_save = min(
                        max_n_samples - n_interpreted_samples,
                        n_batch)

                self._save_interpretations_of_batch(
                    interpreted_input,
                    interpretation_output, save_dir, data_loader, n_samples_to_save,
                    interpreter)

                del interpretation_output

                # Stop once the requested number of samples has been reached.
                n_interpreted_samples += n_batch
                if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
                    break

    def _save_interpretations_of_batch(self, model_input: torch.Tensor,
                                       interpreter_output: Dict[str, torch.Tensor],
                                       save_dir: str, data_loader: DataLoader,
                                       n_samples_to_save: int, interpreter: Interpreter) -> None:
        """Save the interpreter's outputs for the first ``n_samples_to_save``
        samples of the current batch.

        For every sample and every output name, the result is written to a
        path derived from the sample name via ``conf.get_save_dir_for_sample``,
        suffixed with ``_<name>``:

        * unless ``conf.skip_raw``: the raw array via ``np.save`` (``.npy``);
        * unless ``conf.skip_overlay``: ``<path>_overlay.png`` produced by
          ``overlay_interpretation``.

        If ``conf.dynamic_threshold`` is set, ``interpreter.dynamic_threshold``
        is applied before saving. Children may override this to save extras.
        """
        # Only the samples that will actually be saved get a directory, so a
        # sample limit reached mid-batch leaves no empty directories behind.
        n_to_save = min(len(model_input), n_samples_to_save)

        batch_samples_names = data_loader.get_current_batch_samples_names()
        # '../' is stripped from sample names so they cannot climb out of save_dir.
        save_dirs = [
            self.conf.get_save_dir_for_sample(
                save_dir, batch_samples_names[bi].replace('../', ''))
            for bi in range(n_to_save)]
        for sd in save_dirs:
            makedirs(path.dirname(sd), exist_ok=True)

        interpreter_output = {name: output.cpu().numpy()
                              for name, output in interpreter_output.items()}

        # Make inputs grayscale by averaging over dim 1
        # (presumably the channel axis — TODO confirm the input layout).
        model_input = model_input.mean(dim=1).cpu().numpy()

        for bi in range(n_to_save):
            for name, batch_output in interpreter_output.items():
                output = batch_output[bi]
                filename = f'{save_dirs[bi]}_{name}'

                if self.conf.dynamic_threshold:
                    output = interpreter.dynamic_threshold(output)

                if not self.conf.skip_raw:
                    np.save(filename, output)
                else:
                    warnings.warn('Skipping raw interpretations')

                if not self.conf.skip_overlay:
                    # Keep at most the first 3 entries along axis 0 for the
                    # overlay (presumably channels — confirm against
                    # overlay_interpretation).
                    if output.shape[0] > 3:
                        output = output[:3]
                    overlayed = overlay_interpretation(
                        model_input[bi][np.newaxis, ...], output, self.conf)
                    imageio.imwrite(f'{filename}_overlay.png', overlayed)
                else:
                    warnings.warn('Skipping overlayed interpretations')