""" A class for loading one content type needed for the run. """ from abc import abstractmethod, ABC from typing import TYPE_CHECKING, Dict, Callable, Union, List import numpy as np if TYPE_CHECKING: from .. import BaseConfig LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray] class ContentLoader(ABC): """ A class for loading one content type needed for the run. """ def __init__(self, conf: 'BaseConfig', data_specification: str): """ Receives as input conf which is a dictionary containing configurations. This dictionary can also be used to pass fixed addresses for easing the usage! prefix_name is the str which all the variables that must be filled with this class have this prefix in their names so it would be clear that this class should fill it. data_specification is the string that specifies data and where it is! e.g. train, test, val""" self.conf = conf self.data_specification = data_specification self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None @abstractmethod def get_samples_names(self): """ Returns a list containing names of all the samples of the content loader, each sample must owns a unique ID, and this function returns all this IDs. The order of the list must always be the same during one run. For example, this function can return an ID column of a table for TableLoader or the dir of images as ID for ImageLoader""" @abstractmethod def get_samples_labels(self): """ Returns list of labels of the whole samples. The order of the list must always be the same during one run.""" @abstractmethod def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]: """ Returns a dictionary of the placeholders' names (the ones this content loader supports) to the functions used for filling them. The functions must receive as input batch_samples_inds and batch_samples_elements_inds which defines the current batch, and return an array per placeholder name according to the receives batch information. IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader they belong to! Some sort of having a mark :))!""" def fill_placeholders(self, keys: List[str], samples_inds: np.ndarray)\ -> Dict[str, Union[None, np.ndarray]]: """ Receives as input placeholders which is a dictionary of the placeholders' names to a torch tensor for filling it to feed the model with and samples_inds and samples_elements_inds, which contain information about the current batch. Fills placeholders based on the function dictionary received in get_placeholder_name_to_fill_function_dict.""" if self._fill_func_by_var_name is None: self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict() # Filling all the placeholders in the received dictionary! placeholders = dict() for placeholder_name in keys: if placeholder_name in self._fill_func_by_var_name: placeholder_v = self._fill_func_by_var_name[placeholder_name]( samples_inds) if placeholder_v is None: raise Exception('None value for key %s' % placeholder_name) placeholders[placeholder_name] = placeholder_v else: raise Exception(f'Unknown key for content loader: {placeholder_name}') return placeholders def get_batch_true_interpretations(self, samples_inds: np.ndarray, samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\ -> np.ndarray: """ Receives the indices of samples and their elements (if elemented) in the current batch, returns the true interpretations as expected (Bounding box or a boolean mask) """ raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.') def drop_samples(self, drop_mask: np.ndarray) -> None: """ Receives a boolean drop mask and drops the samples whose mask are true, From now on the indices of samples are based on the new samples set (after elimination) """ raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.')