from typing import Dict, Optional, Tuple, Union

import torch

from .interpreter import Interpreter
from ..data.dataflow import DataFlow
from ..data.data_loader import DataLoader


class InterpretationDataFlow(
        DataFlow[Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]]):
    """Data flow that pairs each input batch with the interpreter's output for it.

    For every batch, ``_final_stage`` calls ``interpreter.interpret`` with a
    label and the batch inputs, and yields ``(model_input, interpretation)``.
    Every tensor in ``interpretation`` is detached from the autograd graph.

    The label given to the interpreter is one of two things:

    * the fixed ``label_for_interpreter``, or
    * when ``give_gt_as_label`` is True, the ground-truth labels of the
      samples in the current batch.
    """

    def __init__(self, interpreter: Interpreter, dataloader: DataLoader, device: torch.device,
                 print_debug_info: bool, label_for_interpreter: Union[int, None],
                 give_gt_as_label: bool):
        """
        Args:
            interpreter: Object whose ``interpret`` method is run on every
                batch. Its ``_model`` attribute is handed to the base
                ``DataFlow`` as the model.
            dataloader: Batch source, forwarded to the base ``DataFlow``. It
                is queried once here for all sample labels, and once per batch
                (in ``_final_stage``) for the current batch's sample indices.
            device: Forwarded to the base ``DataFlow``.
            print_debug_info: Forwarded to the base ``DataFlow``.
            label_for_interpreter: Fixed label passed to the interpreter when
                ``give_gt_as_label`` is False. May be None.
            give_gt_as_label: If True, pass the current batch's ground-truth
                labels to the interpreter instead of ``label_for_interpreter``.
        """
        # NOTE(review): this reads the interpreter's private ``_model``
        # attribute. A public accessor on Interpreter would be cleaner.
        super().__init__(interpreter._model, dataloader, device, print_debug_info=print_debug_info)
        self._interpreter = interpreter
        self._label_for_interpreter = label_for_interpreter
        self._give_gt_as_label = give_gt_as_label
        # Labels for all samples. _final_stage indexes them with the current
        # batch's sample indices when give_gt_as_label is True.
        self._sample_labels = dataloader.get_samples_labels()

    @staticmethod
    def _detach_output(
            output: Optional[Dict[str, torch.Tensor]]) -> Optional[Dict[str, torch.Tensor]]:
        """Return a new dict mapping each name in ``output`` to its detached tensor.

        Returns None if ``output`` is None.
        """
        if output is None:
            return None
        return {name: tensor.detach() for name, tensor in output.items()}

    def _final_stage(self, model_input: Dict[str, torch.Tensor]
                     ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
        """Run the interpreter on one batch.

        Args:
            model_input: Batch inputs. They are forwarded unchanged, as
                keyword arguments, to ``interpreter.interpret``.

        Returns:
            ``(model_input, interpretation)``. ``interpretation`` is the
            interpreter's output with every tensor detached, or None if the
            interpreter returned None.
        """
        if self._give_gt_as_label:
            # ``self._dataloader`` is presumably set by DataFlow.__init__.
            batch_indices = self._dataloader.get_current_batch_sample_indices()
            # NOTE(review): the legacy ``torch.Tensor(...)`` constructor builds
            # a tensor of the default tensor type (float32 on CPU unless changed
            # globally). Integer class labels therefore arrive as floats and are
            # not moved to ``device``. If the interpreter expects integer indices
            # on the model's device, consider
            # ``torch.as_tensor(..., dtype=torch.long, device=...)`` — confirm
            # against Interpreter.interpret.
            label = torch.Tensor(self._sample_labels[batch_indices])
        else:
            label = self._label_for_interpreter

        interpretation = self._interpreter.interpret(label, **model_input)
        return model_input, self._detach_output(interpretation)