from typing import TYPE_CHECKING
import traceback
from time import time

from ..data.data_loader import DataLoader
from .interpretable import InterpretableModel
from .interpreter_maker import create_interpreter
from .interpreter import Interpreter
from .interpretation_dataflow import InterpretationDataFlow
from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D
if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class InterpretingEvalRunner:
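    """Evaluates the interpretations of an InterpretableModel over the test groups
    listed in the config.

    Illustrative usage (the config and model constructors below are hypothetical;
    only the runner's own API is shown):

        conf = SomeConfig(...)                # a concrete BaseConfig subclass
        model = SomeInterpretableModel(...)   # implements InterpretableModel
        InterpretingEvalRunner(conf, model).evaluate()
    """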

    def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
        self.conf = conf
        self.model = model

    def evaluate(self):
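        """Builds one interpreter for the configured interpretation method and reuses it
        to evaluate the interpretations of every comma-separated test group listed in
        self.conf.samples_dir, printing an evaluation summary per group."""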

        interpreter = create_interpreter(
            self.conf.interpretation_method,
            self.model)

        for test_group_info in self.conf.samples_dir.split(','):

            try:

                print('>> Evaluating interpretations for %s' % test_group_info, flush=True)

                t1 = time()

                test_data_loader = \
                    DataLoader(self.conf, test_group_info, 'test')

                evaluator = BinaryInterpretationEvaluator2D(
                    test_data_loader.get_number_of_samples(),
                    self.conf
                )

                self._eval_interpretations(test_data_loader, interpreter, evaluator)

                evaluator.print_summaries()

                print('Evaluating Interpretations was done in %.2f secs.' % (time() - t1,), flush=True)

            except Exception:
                print('Problem in %s' % test_group_info, flush=True)
                track = traceback.format_exc()
                print(track, flush=True)

    def _eval_interpretations(self,
                              data_loader: DataLoader, interpreter: Interpreter,
                              evaluator: BinaryInterpretationEvaluator2D) -> None:
        """ CAUTION: Resets the data_loader.
        Iterates over the samples (as many as, and in the order, the data_loader specifies),
        computes their interpretations with the given interpreter based on the configured
        interpretation method, and accumulates the evaluation summaries in the given
        evaluator."""

        if not isinstance(self.model, InterpretableModel):
            raise TypeError('Model has not implemented the requirements of the InterpretableModel')

        # Running the model in evaluation mode
        self.model.eval()

        # initiating variables for running evaluations
        evaluator.reset()

        label_for_interpreter = self.conf.class_label_for_interpretation
        give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt

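        # The dataflow runs the interpreter alongside the model and yields, per batch,
        # the dict of model inputs together with the interpreter's outputs.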
        dataflow = InterpretationDataFlow(interpreter,
                                          data_loader,
                                          self.conf.device,
                                          False,
                                          label_for_interpreter,
                                          give_gt_as_label)

        n_interpreted_samples = 0
        max_n_samples = self.conf.n_interpretation_samples

        with dataflow:
            for model_input, interpretation_output in dataflow.iterate():

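                # The batch size is read from the first input tensor the model declares
                # for interpretation; at most the remaining sample budget is passed on
                # to the evaluator.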
                n_batch = model_input[
                    self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
                n_samples_to_save = n_batch
                if max_n_samples is not None:
                    n_samples_to_save = min(
                        max_n_samples - n_interpreted_samples,
                        n_batch)

                model_outputs = self.model(**model_input)
                model_preds = model_outputs[
                    self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy()

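                # Pick the interpretation map to evaluate: the tag named in the config,
                # or the interpreter's first output when no tag is specified.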
                if self.conf.interpretation_tag_to_evaluate:
                    interpretations = interpretation_output[
                        self.conf.interpretation_tag_to_evaluate]
                else:
                    interpretations = list(interpretation_output.values())[0]

                interpretations = interpretations.detach().cpu().numpy()

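                # Accumulate this batch's summaries from the interpretation maps, the
                # ground-truth interpretations, the model's predictions, the labels,
                # and the sample indices.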
                evaluator.update_summaries(
                    interpretations[:n_samples_to_save, 0],
                    data_loader.get_current_batch_samples_interpretations()[:n_samples_to_save, 0].cpu().numpy(),
                    model_preds[:n_samples_to_save],
                    data_loader.get_current_batch_samples_labels()[:n_samples_to_save],
                    data_loader.get_current_batch_sample_indices()[:n_samples_to_save],
                    interpreter
                )

                del interpretation_output
                del model_outputs

                # if the number of interpreted samples has reached the limit, break
                n_interpreted_samples += n_batch
                if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
                    break