You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpreting_runner.py 6.7KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from typing import Dict, TYPE_CHECKING
  2. from os import makedirs, path
  3. import traceback
  4. from time import time
  5. import warnings
  6. import imageio
  7. import torch
  8. import numpy as np
  9. from ..data.data_loader import DataLoader
  10. from .interpretable import InterpretableModel
  11. from .interpreter_maker import create_interpreter
  12. from .interpreter import Interpreter
  13. from .interpretation_dataflow import InterpretationDataFlow
  14. from .utils import overlay_interpretation
  15. if TYPE_CHECKING:
  16. from ..configs.base_config import BaseConfig
  17. class InterpretingRunner:
  18. def __init__(self, conf: 'BaseConfig', model: InterpretableModel):
  19. self.conf = conf
  20. self.model = model
  21. def interpret(self):
  22. interpreter = create_interpreter(
  23. self.conf.interpretation_method,
  24. self.model)
  25. for test_group_info in self.conf.samples_dir.split(','):
  26. try:
  27. print('>> Finding interpretations for %s' % test_group_info, flush=True)
  28. t1 = time()
  29. labels_to_use = self.conf.mapped_labels_to_use
  30. if labels_to_use is None:
  31. labels_to_use = 'All'
  32. else:
  33. labels_to_use = ','.join([str(x) for x in self.conf.mapped_labels_to_use])
  34. report_dir = self.conf.get_sample_group_specific_report_dir(
  35. test_group_info,
  36. extra_subdir=f'{self.conf.interpretation_method}-C,{labels_to_use}-cut,{self.conf.cut_threshold}-glob,{self.conf.global_threshold}')
  37. makedirs(report_dir, exist_ok=True)
  38. # writing the whole config
  39. f = open(report_dir + '/conf_info.txt', 'w')
  40. f.write(str(self.conf) + '\n')
  41. f.close()
  42. test_data_loader = \
  43. DataLoader(self.conf, test_group_info, 'test')
  44. self._interpret(report_dir, test_data_loader, interpreter)
  45. print('Interpretations were saved in %s' % report_dir)
  46. print('Interpreting was done in %.2f secs.' % (time() - t1,), flush=True)
  47. except Exception as e:
  48. print('Problem in %s' % test_group_info, flush=True)
  49. track = traceback.format_exc()
  50. print(track, flush=True)
  51. def _interpret(self, save_dir: str, data_loader: DataLoader, interpreter: Interpreter) -> None:
  52. """ Iterates over the samples (as much as and how data_loader specifies),
  53. finds the interpretations based on the specified interpretation method
  54. and saves the results in the received save_dir!"""
  55. makedirs(save_dir, exist_ok=True)
  56. if not isinstance(self.model, InterpretableModel):
  57. raise Exception('Model has not implemented the requirements of the InterpretableModel')
  58. # Running the model in evaluation mode
  59. self.model.eval()
  60. # initiating variables for running evaluations
  61. label_for_interpreter = self.conf.class_label_for_interpretation
  62. give_gt_as_label = label_for_interpreter is None and not self.conf.interpret_predictions_vs_gt
  63. dataflow = InterpretationDataFlow(interpreter,
  64. data_loader,
  65. self.conf.device,
  66. False,
  67. label_for_interpreter,
  68. give_gt_as_label)
  69. n_interpreted_samples = 0
  70. max_n_samples = self.conf.n_interpretation_samples
  71. with dataflow:
  72. for model_input, interpretation_output in dataflow.iterate():
  73. n_batch = model_input[
  74. self.model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
  75. n_samples_to_save = n_batch
  76. if max_n_samples is not None:
  77. n_samples_to_save = min(
  78. max_n_samples - n_interpreted_samples,
  79. n_batch)
  80. self._save_interpretations_of_batch(
  81. model_input[self.model.ordered_placeholder_names_to_be_interpreted[0]],
  82. interpretation_output, save_dir, data_loader, n_samples_to_save,
  83. interpreter)
  84. del interpretation_output
  85. # if the number of interpreted samples has reached the limit, break
  86. n_interpreted_samples += n_batch
  87. if max_n_samples is not None and n_interpreted_samples >= max_n_samples:
  88. break
  89. def _save_interpretations_of_batch(self, model_input: torch.Tensor, interpreter_output: Dict[str, torch.Tensor],
  90. save_dir: str, data_loader: DataLoader,
  91. n_samples_to_save: int, interpreter: Interpreter) -> None:
  92. """
  93. Receives the output of the interpreter and saves the interpretations in the received directory
  94. in a file named as the sample name.
  95. The behaviour can be overwritten in children if extra stuff are required.
  96. :param interpreter_output:
  97. :param save_dir:
  98. :return: None
  99. """
  100. batch_samples_names = data_loader.get_current_batch_samples_names()
  101. save_dirs = [self.conf.get_save_dir_for_sample(
  102. save_dir, batch_samples_names[bi].replace('../', ''))
  103. for bi in range(len(batch_samples_names))]
  104. for sd in save_dirs:
  105. makedirs(path.dirname(sd), exist_ok=True)
  106. interpreter_output = {name: output.cpu().numpy() for name, output in interpreter_output.items()}
  107. """ Make inputs grayscale """
  108. model_input = model_input.mean(dim=1).cpu().numpy()
  109. def save(bis):
  110. for bi in bis:
  111. for name, output in interpreter_output.items():
  112. output = output[bi]
  113. filename = f'{save_dirs[bi]}_{name}'
  114. if self.conf.dynamic_threshold:
  115. output = interpreter.dynamic_threshold(output)
  116. if not self.conf.skip_raw:
  117. np.save(filename, output)
  118. else:
  119. warnings.warn('Skipping raw interpretations')
  120. if not self.conf.skip_overlay:
  121. if output.shape[0] > 3:
  122. output = output[:3]
  123. overlayed = overlay_interpretation(model_input[bi][np.newaxis, ...], output, self.conf)
  124. imageio.imwrite(f'{filename}_overlay.png', overlayed)
  125. else:
  126. warnings.warn('Skipping overlayed interpretations')
  127. save(np.arange(min(len(model_input), n_samples_to_save)))