You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpretation_evaluation_runner.py 4.8KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import TYPE_CHECKING
  2. import traceback
  3. from time import time
  4. from ..data.data_loader import DataLoader
  5. from .interpretable import InterpretableModel
  6. from .interpreter_maker import create_interpreter
  7. from .interpreter import Interpreter
  8. from .interpretation_dataflow import InterpretationDataFlow
  9. from .binary_interpretation_evaluator_2d import BinaryInterpretationEvaluator2D
  10. if TYPE_CHECKING:
  11. from ..configs.base_config import BaseConfig
  12. class InterpretingEvalRunner:
  13. def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
  14. self.conf = conf
  15. self.model = model
  16. def evaluate(self):
  17. interpreter = create_interpreter(
  18. self.conf.interpretation_method,
  19. self.model)
  20. for test_group_info in self.conf.samples_dir.split(','):
  21. try:
  22. print('>> Evaluating interpretations for %s' % test_group_info, flush=True)
  23. t1 = time()
  24. test_data_loader = \
  25. DataLoader(self.conf, test_group_info, 'test')
  26. evaluator = BinaryInterpretationEvaluator2D(
  27. test_data_loader.get_number_of_samples(),
  28. self.conf
  29. )
  30. self._eval_interpretations(test_data_loader, interpreter, evaluator)
  31. evaluator.print_summaries()
  32. print('Evaluating Interpretations was done in %.2f secs.' % (time() - t1,), flush=True)
  33. except Exception as e:
  34. print('Problem in %s' % test_group_info, flush=True)
  35. track = traceback.format_exc()
  36. print(track, flush=True)
  37. def _eval_interpretations(self,
  38. data_loader: DataLoader, interpreter: Interpreter,
  39. evaluator: BinaryInterpretationEvaluator2D) -> None:
  40. """ CAUTION: Resets the data_loader,
  41. Iterates over the samples (as much as and how data_loader specifies),
  42. finds the interpretations based on the specified interpretation method
  43. and saves the results in the received save_dir!"""
  44. if not isinstance(self.model, InterpretableModel):
  45. raise Exception('Model has not implemented the requirements of the InterpretableModel')
  46. # Running the model in evaluation mode
  47. self.model.eval()
  48. # initiating variables for running evaluations
  49. evaluator.reset()
  50. label_for_interpreter = self.conf.class_label_for_interpretation
  51. give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt
  52. dataflow = InterpretationDataFlow(interpreter,
  53. data_loader,
  54. self.conf.device,
  55. False,
  56. label_for_interpreter,
  57. give_gt_as_label)
  58. n_interpreted_samples = 0
  59. max_n_samples = self.conf.n_interpretation_samples
  60. with dataflow:
  61. for model_input, interpretation_output in dataflow.iterate():
  62. n_batch = model_input[
  63. self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
  64. n_samples_to_save = n_batch
  65. if max_n_samples is not None:
  66. n_samples_to_save = min(
  67. max_n_samples - n_interpreted_samples,
  68. n_batch)
  69. model_outputs = self.model(**model_input)
  70. model_preds = model_outputs[
  71. self.conf.prediction_key_in_model_output_dict].detach().cpu().numpy()
  72. if self.conf.interpretation_tag_to_evaluate:
  73. interpretations = interpretation_output[
  74. self.conf.interpretation_tag_to_evaluate
  75. ].detach()
  76. else:
  77. interpretations = list(interpretation_output.values())[0].detach()
  78. interpretations = interpretations.detach().cpu().numpy()
  79. evaluator.update_summaries(
  80. interpretations[:n_samples_to_save, 0],
  81. data_loader.get_current_batch_samples_interpretations()[:n_samples_to_save, 0].cpu().numpy(),
  82. model_preds[:n_samples_to_save],
  83. data_loader.get_current_batch_samples_labels()[:n_samples_to_save],
  84. data_loader.get_current_batch_sample_indices()[:n_samples_to_save],
  85. interpreter
  86. )
  87. del interpretation_output
  88. del model_outputs
  89. # if the number of interpreted samples has reached the limit, break
  90. n_interpreted_samples += n_batch
  91. if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
  92. break