""" Runs flow of data, from loading to forwarding the model. """

from typing import Iterator, Union, TypeVar, Generic, Dict

import torch

from ..data.data_loader import DataLoader, create_empty_placeholder
from ..data.dataloader_context import DataloaderContext
from ..models.model import Model


TReturn = TypeVar('TReturn')


class DataFlow(Generic[TReturn]):
    """ Runs flow of data, from loading to forwarding the model.

    The object is a context manager: the placeholders that feed the model
    are created on entering the ``with`` block and released on leaving it.
    ``iterate`` may only be consumed inside that block::

        with dataflow:
            for model_output in dataflow.iterate():
                pass
    """

    def __init__(self, model: 'Model', dataloader: 'DataLoader',
                 device: torch.device, print_debug_info: bool = False):
        """ Stores the collaborators; no placeholders are built yet.

        :param model: callable invoked with the placeholders as keyword
            arguments; must also provide
            ``get_forward_required_kws_kwargs_and_defaults``.
        :param dataloader: source of batches, driven by ``iterate``.
        :param device: device handed to ``create_empty_placeholder``.
        :param print_debug_info: when true, the sample names of every
            prepared batch are printed.
        """
        self._model = model
        self._device = device
        self._dataloader = dataloader
        # None while outside of a `with` block; `iterate` relies on this.
        self._placeholders: Union[None, Dict[str, torch.Tensor]] = None
        self._print_debug_info = print_debug_info

    def iterate(self) -> Iterator[TReturn]:
        """ Iterates on data and forwards the model.

        This is a generator, so nothing (including the usage check below)
        runs until the first item is requested.

        :return: an iterator yielding the model output of each batch.
        :raises RuntimeError: if consumed outside of a ``with`` block.
        """
        if self._placeholders is None:
            raise RuntimeError(
                "Please use `with` statement before calling iterate method:\n" +
                "with dataflow:\n" +
                "    for model_output in dataflow.iterate():\n" +
                "        pass\n")

        # Publish the dataloader on the shared context object.
        DataloaderContext.instance.dataloader = self._dataloader

        while True:
            self._dataloader.prepare_next_batch()

            if self._print_debug_info:
                print('> Next Batch:')
                print('\t' + '\n\t'.join(self._dataloader.get_current_batch_samples_names()))

            # check if the iteration is done
            if self._dataloader.finished_iteration():
                break

            # Filling the placeholders for running the model
            self._dataloader.fill_placeholders(self._placeholders)

            # Running the model
            yield self._final_stage(self._placeholders)

    def __enter__(self) -> 'DataFlow[TReturn]':
        """ Builds the placeholders and returns this object.

        Returning ``self`` makes ``with DataFlow(...) as flow:`` usable.
        The dataloader is not reset here.
        """
        self._placeholders = self._build_placeholders()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """ Drops the references to the placeholder tensors.

        Returns ``None``, so an exception raised inside the ``with`` block
        propagates to the caller.
        """
        placeholders = self._placeholders
        if placeholders is not None:
            # Clear the entries in place so that any other holder of this
            # dictionary stops referencing the tensors as well.
            for name in placeholders:
                placeholders[name] = None
        self._placeholders = None

    def _final_stage(self, model_input: Dict[str, torch.Tensor]) -> TReturn:
        """ Calls the model with the placeholders as keyword arguments. """
        return self._model(**model_input)

    def _build_placeholders(self) -> Dict[str, torch.Tensor]:
        """ Creates placeholders for feeding the model.

        One empty placeholder is created on the configured device for every
        name reported by the model's
        ``get_forward_required_kws_kwargs_and_defaults``: first the forward
        arguments, then the additional keyword arguments.

        NOTE(review): an earlier version of this docstring said arguments
        defaulting to ``None`` are labels that are skipped outside of the
        train phase. No such filtering happens here (this method has no
        phase information); every reported name gets a placeholder.

        :return: a dictionary mapping argument names to placeholders.
        """
        # The default values are reported by the model but are not needed
        # here, because no argument is filtered out.
        forward_args, _default_values, extra_kwargs = \
            self._model.get_forward_required_kws_kwargs_and_defaults()

        return {
            argname: create_empty_placeholder(self._device)
            for names in (forward_args, extra_kwargs)
            for argname in names
        }