|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- """
- Runs flow of data, from loading to forwarding the model.
- """
- from typing import Iterator, Union, TypeVar, Generic, Dict
-
- import torch
-
- from ..data.data_loader import DataLoader, create_empty_placeholder
- from ..data.dataloader_context import DataloaderContext
- from ..models.model import Model
-
-
TReturn = TypeVar('TReturn')


class DataFlow(Generic[TReturn]):
    """
    Runs flow of data, from loading to forwarding the model.

    Intended usage::

        with dataflow:
            for model_output in dataflow.iterate():
                ...

    Entering the context builds one placeholder per model input; leaving it
    drops them. ``iterate`` may only be consumed while inside the context.
    """

    def __init__(self, model: 'Model', dataloader: 'DataLoader',
                 device: torch.device, print_debug_info: bool = False):
        """
        Args:
            model: model whose forward-method inputs define the placeholders
                and which is called (``model(**placeholders)``) per batch.
            dataloader: source of batches; it fills the placeholders in place.
            device: passed to ``create_empty_placeholder`` for every input.
            print_debug_info: when True, print the sample names of each batch.
        """
        self._model = model
        self._device = device
        self._dataloader = dataloader
        # None outside the ``with`` block; a name -> placeholder dict inside it.
        self._placeholders: Union[None, Dict[str, torch.Tensor]] = None
        self._print_debug_info = print_debug_info

    def iterate(self) -> Iterator[TReturn]:
        """
        Iterates on data and forwards the model.

        Yields:
            The result of ``_final_stage`` (by default the model output) for
            each batch prepared by the dataloader.

        Raises:
            RuntimeError: if used outside a ``with dataflow:`` block. Because
                this is a generator, it is raised on the first ``next()``,
                not when ``iterate()`` is called.
        """
        if self._placeholders is None:
            # RuntimeError subclasses Exception, so existing handlers still match.
            raise RuntimeError("Please use `with` statement before calling iterate method:\n" +
                               "with dataflow:\n" +
                               " for model_output in dataflow.iterate():\n" +
                               " pass\n")

        # Publish the active dataloader on DataloaderContext.instance
        # (presumably a shared singleton); it is not reset when iteration ends.
        DataloaderContext.instance.dataloader = self._dataloader

        while True:
            self._dataloader.prepare_next_batch()

            # Check for the end first so that no spurious (empty) batch is
            # reported by the debug output after the data is exhausted.
            if self._dataloader.finished_iteration():
                break

            if self._print_debug_info:
                print('> Next Batch:')
                print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names()))

            # Filling the placeholders for running the model
            self._dataloader.fill_placeholders(self._placeholders)

            # Running the model
            yield self._final_stage(self._placeholders)

    def __enter__(self) -> 'DataFlow[TReturn]':
        """Builds fresh placeholders and returns ``self`` (so ``with ... as df`` works)."""
        self._placeholders = self._build_placeholders()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Releases the placeholders. Returns None, so exceptions raised in the
        ``with`` body propagate. Safe to call when no placeholders exist."""
        if self._placeholders is not None:
            # Clear the entries as well as the reference, presumably because
            # the dict may still be referenced elsewhere (e.g. by the dataloader
            # it was passed to).
            for name in self._placeholders:
                self._placeholders[name] = None
        self._placeholders = None

    def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn:
        """Calls the model with the placeholders as keyword arguments and
        returns its output (presumably an override point for subclasses)."""
        return self._model(**model_input)

    def _build_placeholders(self) -> Dict[str, torch.Tensor]:
        """Creates placeholders for feeding the model.

        One ``create_empty_placeholder(self._device)`` is created for every
        argument name of the model's forward method and for every additional
        keyword argument the model reports (forward args first). A name that
        appears in both lists yields a single entry.

        NOTE(review): argument default values are not consulted, so arguments
        defaulting to None (e.g. training-only labels) also get a placeholder.
        Skipping them outside the train phase is not implemented here.

        Returns:
            A dict mapping input name to its placeholder.
        """
        forward_args, _defaults, additional_kwargs = \
            self._model.get_forward_required_kws_kwargs_and_defaults()

        return {name: create_empty_placeholder(self._device)
                for name in (*forward_args, *additional_kwargs)}
-
|