from typing import Dict, Optional, Tuple

import torch

from . import Interpreter
from ..data.dataflow import DataFlow
from ..models.model import ModelIO
from ..data.data_loader import DataLoader


class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]):
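    """A dataflow that feeds batches through an interpreter and yields, for
    each batch, the model input together with the interpreter's output,
    detached from the autograd graph.

    Depending on ``give_gt_as_label``, the interpretation target is either
    the ground-truth label of each sample in the batch or a single fixed
    label (``label_for_interpreter``) shared by all samples.
    """
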
    def __init__(self, interpreter: Interpreter,
                 dataloader: DataLoader,
                 phase: str,
                 device: torch.device,
                 dtype: torch.dtype,
                 print_debug_info: bool,
                 label_for_interpreter: Optional[int],
                 give_gt_as_label: bool):
        super().__init__(interpreter._model, dataloader, phase, device,
                         dtype, print_debug_info=print_debug_info)
        self._interpreter = interpreter
        self._label_for_interpreter = label_for_interpreter
        self._give_gt_as_label = give_gt_as_label
        # Cache every sample's label up front so the ground-truth labels of a
        # batch can be looked up by sample index in _final_stage.
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
    def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Detach every tensor so the interpretation results can outlive the
        # autograd graph of the batch that produced them.
        return {name: o.detach() for name, o in output.items()}

    def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]:
        if self._give_gt_as_label:
            # Interpret with respect to the ground-truth labels of the current
            # batch. Note that torch.Tensor(...) yields a tensor of the
            # default (float) dtype.
            batch_indices = self._dataloader.get_current_batch_sample_indices()
            labels = torch.Tensor(self._sample_labels[batch_indices])
            return model_input, self._detach_output(
                self._interpreter.interpret(labels, **model_input))

        # Otherwise interpret with respect to a fixed label index shared by
        # all samples (possibly None).
        return model_input, self._detach_output(self._interpreter.interpret(
            self._label_for_interpreter, **model_input))
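
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The surrounding
# package provides the concrete Interpreter and DataLoader; ``my_interpreter``
# and ``my_dataloader`` below are hypothetical placeholders, and how the flow
# is iterated depends on the base DataFlow API.
#
#     flow = InterpretationDataFlow(
#         interpreter=my_interpreter,
#         dataloader=my_dataloader,
#         phase='test',
#         device=torch.device('cuda'),
#         dtype=torch.float32,
#         print_debug_info=False,
#         label_for_interpreter=None,
#         give_gt_as_label=True)
#
# Each step of the flow then produces a ``(model_input, interpretations)``
# pair, where ``interpretations`` maps output names to detached tensors.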