You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpretation_dataflow.py 1.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from typing import Dict, Union, Tuple
  2. import torch
  3. from .interpreter import Interpreter
  4. from ..data.dataflow import DataFlow
  5. from ..data.data_loader import DataLoader
  6. class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]):
  7. def __init__(self, interpreter: Interpreter,
  8. dataloader: DataLoader,
  9. device: torch.device,
  10. print_debug_info: bool,
  11. label_for_interpreter: Union[int, None],
  12. give_gt_as_label: bool):
  13. super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info)
  14. self._interpreter = interpreter
  15. self._label_for_interpreter = label_for_interpreter
  16. self._give_gt_as_label = give_gt_as_label
  17. self._sample_labels = dataloader.get_samples_labels()
  18. @staticmethod
  19. def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  20. if output is None:
  21. return None
  22. return {
  23. name: o.detach() for name, o in output.items()
  24. }
  25. def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
  26. if self._give_gt_as_label:
  27. return model_input, self._detach_output(self._interpreter.interpret(
  28. torch.Tensor(self._sample_labels[
  29. self._dataloader.get_current_batch_sample_indices()]),
  30. **model_input))
  31. return model_input, self._detach_output(self._interpreter.interpret(
  32. self._label_for_interpreter, **model_input
  33. ))