|
123456789101112131415161718192021222324252627282930313233343536373839404142 |
- from typing import Dict, Union, Tuple
-
- import torch
-
- from .interpreter import Interpreter
- from ..data.dataflow import DataFlow
- from ..data.data_loader import DataLoader
-
-
class InterpretationDataFlow(DataFlow[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]):
    """DataFlow that runs an ``Interpreter`` on every batch from a ``DataLoader``.

    Each item produced by ``_final_stage`` is a ``(model_input, interpretation)``
    pair, where ``interpretation`` is the dict returned by
    ``Interpreter.interpret`` with every tensor detached. If the interpreter
    returns ``None``, ``interpretation`` is ``None`` as well.

    The label handed to the interpreter is either the ground-truth labels of
    the samples in the current batch (``give_gt_as_label=True``) or the fixed
    ``label_for_interpreter`` value.
    """

    def __init__(self, interpreter: Interpreter,
                 dataloader: DataLoader,
                 device: torch.device,
                 print_debug_info: bool,
                 label_for_interpreter: Union[int, None],
                 give_gt_as_label: bool):
        """
        Args:
            interpreter: Object whose ``interpret`` method is called per batch.
                Its ``_model`` attribute is passed to the base ``DataFlow``.
            dataloader: Source of batches. It is also queried once, here, for
                the labels of all samples via ``get_samples_labels()``.
            device: Forwarded to ``DataFlow``.
            print_debug_info: Forwarded to ``DataFlow``.
            label_for_interpreter: Fixed label passed to ``interpret`` when
                ``give_gt_as_label`` is False. May be ``None``. The interpreter
                decides what ``None`` means (presumably "choose a target
                yourself" — confirm against the Interpreter implementation).
            give_gt_as_label: If True, pass the ground-truth labels of the
                current batch instead of ``label_for_interpreter``.
        """
        super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info)
        self._interpreter = interpreter
        self._label_for_interpreter = label_for_interpreter
        self._give_gt_as_label = give_gt_as_label
        # Fetched eagerly, even when give_gt_as_label is False. It is indexed
        # with the current batch's sample indices in _final_stage.
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
    def _detach_output(
            output: Union[Dict[str, torch.Tensor], None],
    ) -> Union[Dict[str, torch.Tensor], None]:
        """Return a new dict with every tensor of ``output`` detached from autograd.

        ``None`` is passed through unchanged. Interpreters may return ``None``,
        so the annotation reflects that.
        """
        if output is None:
            return None
        return {name: tensor.detach() for name, tensor in output.items()}

    def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Interpret one batch and return ``(model_input, detached_interpretation)``.

        NOTE: the second element is ``None`` when the interpreter returns ``None``
        (see ``_detach_output``). The annotation is kept identical to the class's
        generic parameter so the override stays compatible with ``DataFlow``.
        """
        if self._give_gt_as_label:
            # Ground-truth labels of exactly the samples in the current batch.
            # The legacy torch.Tensor(...) constructor yields the default
            # floating dtype on the default device. It is kept as-is because
            # existing interpreters receive labels in that form.
            label = torch.Tensor(self._sample_labels[
                self._dataloader.get_current_batch_sample_indices()])
        else:
            label = self._label_for_interpreter

        interpretation = self._interpreter.interpret(label, **model_input)
        return model_input, self._detach_output(interpretation)
|