You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dataflow.py 1.6KB

1 year ago
12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from typing import Dict, Union, Tuple
  2. import torch
  3. from . import Interpreter
  4. from ..data.dataflow import DataFlow
  5. from ..models.model import ModelIO
  6. from ..data.data_loader import DataLoader
  class InterpretationDataFlow(DataFlow[Tuple[ModelIO, Dict[str, torch.Tensor]]]):
      """Data flow whose final stage runs an interpreter on the model input.

      The final stage returns the model input together with the interpreter's
      outputs, each detached via ``Tensor.detach()``. The label passed to the
      interpreter is either the fixed ``label_for_interpreter`` or the
      per-sample labels (presumably ground truth) of the current batch.
      """

      def __init__(
              self,
              interpreter: Interpreter,
              dataloader: DataLoader,
              phase: str,
              device: torch.device,
              dtype: torch.dtype,
              print_debug_info: bool,
              label_for_interpreter: Union[int, None],
              give_gt_as_label: bool):
          """Set up the base data flow and remember the interpretation settings.

          Args:
              interpreter: Object whose ``interpret`` method is called in the
                  final stage; its ``_model`` is forwarded to the base class.
              dataloader: Forwarded to the base class; also queried once for
                  the labels of all samples.
              phase: Forwarded to the base class.
              device: Forwarded to the base class.
              dtype: Forwarded to the base class.
              print_debug_info: Forwarded to the base class by keyword.
              label_for_interpreter: Label passed to the interpreter when
                  ``give_gt_as_label`` is false; may be ``None``.
              give_gt_as_label: When true, the labels looked up for the
                  current batch are passed instead of ``label_for_interpreter``.
          """
          super().__init__(
              interpreter._model,
              dataloader,
              phase,
              device,
              dtype,
              print_debug_info=print_debug_info,
          )
          self._interpreter = interpreter
          self._label_for_interpreter = label_for_interpreter
          self._give_gt_as_label = give_gt_as_label
          # Fetched once up front; indexed per batch in ``_final_stage``.
          self._sample_labels = dataloader.get_samples_labels()

      @staticmethod
      def _detach_output(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          """Return a new dict mapping each name to the detached tensor."""
          detached: Dict[str, torch.Tensor] = {}
          for key, tensor in output.items():
              detached[key] = tensor.detach()
          return detached

      def _final_stage(self, model_input: ModelIO) -> Tuple[ModelIO, Dict[str, torch.Tensor]]:
          """Interpret ``model_input`` and pair it with the detached result.

          Returns:
              A tuple of the unchanged ``model_input`` and the interpreter's
              output dict with every tensor detached.
          """
          if self._give_gt_as_label:
              # NOTE(review): ``self._dataloader`` is not assigned in this class;
              # presumably the base ``DataFlow.__init__`` stores it - confirm.
              batch_indices = self._dataloader.get_current_batch_sample_indices()
              label = torch.Tensor(self._sample_labels[batch_indices])
          else:
              label = self._label_for_interpreter

          interpretation = self._interpreter.interpret(label, **model_input)
          return model_input, self._detach_output(interpretation)