from typing import TYPE_CHECKING
import traceback
from time import time

from ..data.data_loader import DataLoader
from .interpretable import InterpretableModel
from .interpreter_maker import create_interpreter
from .interpreter import Interpreter
from .interpretation_dataflow import InterpretationDataFlow
from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class InterpretingEvalRunner:

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        self.conf = conf
        self.model = model

    def evaluate(self):
        interpreter = create_interpreter(
            self.conf.interpretation_method, self.model)

        # Each comma-separated entry in samples_dir is evaluated as a separate test group.
        for test_group_info in self.conf.samples_dir.split(','):
            try:
                print('>> Evaluating interpretations for %s' % test_group_info, flush=True)
                t1 = time()

                test_data_loader = \
                    DataLoader(self.conf, test_group_info, 'test')

                evaluator = BinaryInterpretationEvaluator2D(
                    test_data_loader.get_number_of_samples(),
                    self.conf
                )

                self._eval_interpretations(test_data_loader, interpreter, evaluator)

                evaluator.print_summaries()

                print('Evaluating interpretations was done in %.2f secs.' %
                      (time() - t1,), flush=True)
            except Exception:
                print('Problem in %s' % test_group_info, flush=True)
                track = traceback.format_exc()
                print(track, flush=True)

    def _eval_interpretations(
            self, data_loader: DataLoader, interpreter: Interpreter,
            evaluator: BinaryInterpretationEvaluator2D) -> None:
        """ CAUTION: Resets the data_loader.
        Iterates over the samples (as many, and in whatever order, the data_loader
        specifies), computes the interpretations with the specified interpretation
        method, and updates the evaluator's summaries with the results."""

        if not isinstance(self.model, InterpretableModel):
            raise Exception('The model does not implement the requirements of InterpretableModel')

        # Run the model in evaluation mode
        self.model.eval()

        # Initialize the evaluation state
        evaluator.reset()

        # Interpret w.r.t. an explicit class label if one is configured; otherwise
        # use either the ground-truth labels or the model's predictions, depending
        # on interpret_predictions_vs_gt.
        label_for_interpreter = self.conf.class_label_for_interpretation
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        n_interpreted_samples = 0
        max_n_samples = self.conf.n_interpretation_samples

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():
                n_batch = model_input[
                    self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]

                # Cap the number of samples handed to the evaluator so the
                # configured limit is respected.
                n_samples_to_save = n_batch
                if max_n_samples is not None:
                    n_samples_to_save = min(
                        max_n_samples - n_interpreted_samples, n_batch)

                model_outputs = self.model(**model_input)
                model_preds = model_outputs[
                    self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy()

                # Pick the requested interpretation tag, or fall back to the first
                # available one.
                if self.conf.interpretation_tag_to_evaluate:
                    interpretations = interpretation_output[
                        self.conf.interpretation_tag_to_evaluate].detach()
                else:
                    interpretations = list(interpretation_output.values())[0].detach()
                interpretations = interpretations.cpu().numpy()

                evaluator.update_summaries(
                    interpretations[:n_samples_to_save, 0],
                    data_loader.get_current_batch_samples_interpretations()[
                        :n_samples_to_save, 0].cpu().numpy(),
                    model_preds[:n_samples_to_save],
                    data_loader.get_current_batch_samples_labels()[:n_samples_to_save],
                    data_loader.get_current_batch_sample_indices()[:n_samples_to_save],
                    interpreter
                )

                # Free references to the batch tensors before the next iteration.
                del interpretation_output
                del model_outputs

                # If the number of interpreted samples has reached the limit, stop.
                n_interpreted_samples += n_batch
                if max_n_samples is not None and \
                        n_interpreted_samples >= max_n_samples:
                    break
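

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The concrete classes in this sketch are
# assumptions, not part of this module: a project supplies its own BaseConfig
# subclass and InterpretableModel implementation, then hands both to the
# runner and calls `evaluate()`.
#
#     conf = MyProjectConfig(...)               # hypothetical BaseConfig subclass;
#                                               # samples_dir is a comma-separated list of test groups
#     model = MyProjectModel(conf)              # hypothetical InterpretableModel implementation
#     runner = InterpretingEvalRunner(conf, model)
#     runner.evaluate()                         # prints an interpretation summary per test group
# ---------------------------------------------------------------------------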