|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- """
- A class for loading one content type needed for the run.
- """
- from abc import abstractmethod, ABC
- from typing import TYPE_CHECKING, Dict, Callable, Union, List
-
- import numpy as np
-
- if TYPE_CHECKING:
- from .. import BaseConfig
-
# Signature that placeholder fill functions are annotated with.
# NOTE(review): ContentLoader.fill_placeholders calls these functions with a
# single positional argument (the sample indices); the second parameter
# declared here looks stale -- confirm against the concrete loaders before
# tightening the alias.
LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]


class ContentLoader(ABC):
    """Abstract base class for loading one content type needed for the run.

    Subclasses provide the sample names, the sample labels and a mapping from
    placeholder names to the functions that fill them; ``fill_placeholders``
    uses that mapping to build the arrays for one batch.
    """

    def __init__(self, conf: 'BaseConfig', data_specification: str):
        """Store the configuration and the data specification.

        Args:
            conf: Configuration of the run. It can also be used to pass fixed
                addresses to the loader for easing the usage.
            data_specification: String that specifies the data and where it
                is, e.g. train, test, val.
        """
        self.conf = conf
        self.data_specification = data_specification
        # Built lazily on the first fill_placeholders call, then reused.
        self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None

    @abstractmethod
    def get_samples_names(self):
        """Return a list containing the names of all samples of the loader.

        Each sample must own a unique ID and this method returns all of these
        IDs. The order of the list must always be the same during one run.
        For example, this can be the ID column of a table for a table loader,
        or the image directories for an image loader.
        """

    @abstractmethod
    def get_samples_labels(self):
        """Return the list of labels of all the samples.

        The order of the list must always be the same during one run.
        """

    @abstractmethod
    def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]:
        """Return the mapping from placeholder names to their fill functions.

        The keys are the placeholder names this content loader supports. Each
        function receives the information defining the current batch and
        returns the array for its placeholder.

        IMPORTANT: prefer a fixed prefix in the placeholder names so it is
        clear which content loader they belong to.
        """

    def fill_placeholders(self,
                          keys: List[str],
                          samples_inds: np.ndarray)\
            -> Dict[str, Union[None, np.ndarray]]:
        """Build the values of the requested placeholders for one batch.

        Args:
            keys: Names of the placeholders to fill.
            samples_inds: Indices of the samples in the current batch; passed
                unchanged to every fill function.

        Returns:
            A dictionary from each requested placeholder name to the value
            its fill function returned. Values are never None here, because
            a None result raises instead.

        Raises:
            KeyError: If a requested name is not in the dictionary returned
                by ``get_placeholder_name_to_fill_function_dict``.
            ValueError: If a fill function returns None.
        """
        if self._fill_func_by_var_name is None:
            self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict()
        fill_funcs = self._fill_func_by_var_name

        placeholders = {}
        for placeholder_name in keys:
            fill_func = fill_funcs.get(placeholder_name)
            if fill_func is None:
                raise KeyError(f'Unknown key for content loader: {placeholder_name}')

            placeholder_v = fill_func(samples_inds)
            if placeholder_v is None:
                raise ValueError(f'None value for key {placeholder_name}')

            placeholders[placeholder_name] = placeholder_v

        return placeholders

    def get_batch_true_interpretations(self, samples_inds: np.ndarray,
                                       samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\
            -> np.ndarray:
        """Return the true interpretations for the current batch.

        Receives the indices of the samples and of their elements (if
        elemented) in the current batch; implementations are expected to
        return the true interpretations (bounding box or a boolean mask).

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError(
            'If you want to evaluate interpretations, '
            'you must implement this function in your content loader.')

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Drop the samples whose entry in the boolean ``drop_mask`` is true.

        After dropping, sample indices are based on the new (reduced) set of
        samples.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError(
            'If you want to filter samples by mapped labels, '
            'you should implement this function in your content loader.')
|