
dataflow.py

  1. """
  2. Runs flow of data, from loading to forwarding the model.
  3. """
  4. from typing import Iterator, Union, TypeVar, Generic, Dict
  5. import torch
  6. from ..data.data_loader import DataLoader, create_empty_placeholder
  7. from ..data.dataloader_context import DataloaderContext
  8. from ..models.model import Model
  9. TReturn = TypeVar('TReturn')
  10. class DataFlow(Generic[TReturn]):
  11. """
  12. Runs flow of data, from loading to forwarding the model.
  13. """
  14. def __init__(self, model: 'Model', dataloader: 'DataLoader',
  15. device: torch.device, print_debug_info:bool = False):
  16. self._model = model
  17. self._device = device
  18. self._dataloader = dataloader
  19. self._placeholders: Union[None, Dict[str, torch.Tensor]] = None
  20. self._print_debug_info = print_debug_info
  21. def iterate(self) -> Iterator[TReturn]:
  22. """
  23. Iterates on data and forwards the model
  24. """
  25. if self._placeholders is None:
  26. raise Exception("Please use `with` statement before calling iterate method:\n" +
  27. "with dataflow:\n" +
  28. " for model_output in dataflow.iterate():\n" +
  29. " pass\n")
  30. DataloaderContext.instance.dataloader = self._dataloader
  31. while True:
  32. self._dataloader.prepare_next_batch()
  33. if self._print_debug_info:
  34. print('> Next Batch:')
  35. print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names()))
  36. # check if the iteration is done
  37. if self._dataloader.finished_iteration():
  38. break
  39. # Filling the placeholders for running the model
  40. self._dataloader.fill_placeholders(self._placeholders)
  41. # Running the model
  42. yield self._final_stage(self._placeholders)
  43. def __enter__(self):
  44. # self._dataloader.reset()
  45. self._placeholders = self._build_placeholders()
  46. def __exit__(self, exc_type, exc_value, traceback):
  47. for name in self._placeholders:
  48. self._placeholders[name] = None
  49. self._placeholders = None
  50. def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn:
  51. return self._model(**model_input)
  52. def _build_placeholders(self) -> Dict[str, torch.Tensor]:
  53. """ Creates placeholders for feeding the model.
  54. Returns a dictionary containing placeholders.
  55. The placeholders are created based on the names
  56. of the input variables of the model's forward method.
  57. The variables initialized as None are assumed to be
  58. labels used for training phase only, and won't be
  59. added in phases other than train."""
  60. model_forward_args, args_default_values, additional_kwargs = self._model.get_forward_required_kws_kwargs_and_defaults()
  61. # Checking which args have None as default
  62. args_default_values = ['Mandatory' for _ in
  63. range(len(model_forward_args) - len(args_default_values))] + \
  64. args_default_values
  65. optional_args = [x is None for x in args_default_values]
  66. placeholders = dict()
  67. for argname in model_forward_args:
  68. placeholders[argname] = create_empty_placeholder(self._device)
  69. for argname in additional_kwargs:
  70. placeholders[argname] = create_empty_placeholder(self._device)
  71. return placeholders
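
Usage sketch (an assumption, not part of dataflow.py): the snippet below shows how a DataFlow instance is meant to be driven, following the `with` / `iterate()` protocol enforced above. `MyModel` and `MyDataLoader` are hypothetical stand-ins for concrete Model and DataLoader subclasses from the package.

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyModel().to(device)        # hypothetical Model subclass
    dataloader = MyDataLoader(...)      # hypothetical DataLoader subclass

    dataflow = DataFlow(model, dataloader, device, print_debug_info=True)

    # __enter__ builds the placeholders; __exit__ releases them.
    with dataflow:
        for model_output in dataflow.iterate():
            ...  # consume the model output for the current batch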