You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

content_loader.py 4.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. """
  2. A class for loading one content type needed for the run.
  3. """
  4. from abc import abstractmethod, ABC
  5. from typing import TYPE_CHECKING, Dict, Callable, Union, List
  6. import numpy as np
  7. if TYPE_CHECKING:
  8. from .. import BaseConfig
  9. LoadFunction = Callable[[np.ndarray, Union[None, np.ndarray]], np.ndarray]
  10. class ContentLoader(ABC):
  11. """ A class for loading one content type needed for the run. """
  12. def __init__(self, conf: 'BaseConfig', data_specification: str):
  13. """ Receives as input conf which is a dictionary containing configurations.
  14. This dictionary can also be used to pass fixed addresses for easing the usage!
  15. prefix_name is the str which all the variables that must be filled with this class have
  16. this prefix in their names so it would be clear that this class should fill it.
  17. data_specification is the string that specifies data and where it is! e.g. train, test, val"""
  18. self.conf = conf
  19. self.data_specification = data_specification
  20. self._fill_func_by_var_name: Union[None, Dict[str, LoadFunction]] = None
  21. @abstractmethod
  22. def get_samples_names(self):
  23. """ Returns a list containing names of all the samples of the content loader,
  24. each sample must owns a unique ID, and this function returns all this IDs.
  25. The order of the list must always be the same during one run.
  26. For example, this function can return an ID column of a table for TableLoader
  27. or the dir of images as ID for ImageLoader"""
  28. @abstractmethod
  29. def get_samples_labels(self):
  30. """ Returns list of labels of the whole samples.
  31. The order of the list must always be the same during one run."""
  32. @abstractmethod
  33. def get_placeholder_name_to_fill_function_dict(self) -> Dict[str, LoadFunction]:
  34. """ Returns a dictionary of the placeholders' names (the ones this content loader supports)
  35. to the functions used for filling them. The functions must receive as input
  36. batch_samples_inds and batch_samples_elements_inds which defines the current batch,
  37. and return an array per placeholder name according to the receives batch information.
  38. IMPORTANT: Better to use a fixed prefix in the names of the placeholders to become clear which content loader
  39. they belong to! Some sort of having a mark :))!"""
  40. def fill_placeholders(self,
  41. keys: List[str],
  42. samples_inds: np.ndarray)\
  43. -> Dict[str, Union[None, np.ndarray]]:
  44. """ Receives as input placeholders which is a dictionary of the placeholders'
  45. names to a torch tensor for filling it to feed the model with and
  46. samples_inds and samples_elements_inds, which contain
  47. information about the current batch. Fills placeholders based on the
  48. function dictionary received in get_placeholder_name_to_fill_function_dict."""
  49. if self._fill_func_by_var_name is None:
  50. self._fill_func_by_var_name = self.get_placeholder_name_to_fill_function_dict()
  51. # Filling all the placeholders in the received dictionary!
  52. placeholders = dict()
  53. for placeholder_name in keys:
  54. if placeholder_name in self._fill_func_by_var_name:
  55. placeholder_v = self._fill_func_by_var_name[placeholder_name](
  56. samples_inds)
  57. if placeholder_v is None:
  58. raise Exception('None value for key %s' % placeholder_name)
  59. placeholders[placeholder_name] = placeholder_v
  60. else:
  61. raise Exception(f'Unknown key for content loader: {placeholder_name}')
  62. return placeholders
  63. def get_batch_true_interpretations(self, samples_inds: np.ndarray,
  64. samples_elements_inds: Union[None, np.ndarray, List[np.ndarray]])\
  65. -> np.ndarray:
  66. """
  67. Receives the indices of samples and their elements (if elemented) in the current batch,
  68. returns the true interpretations as expected (Bounding box or a boolean mask)
  69. """
  70. raise NotImplementedError('If you want to evaluate interpretations, you must implement this function in your content loader.')
  71. def drop_samples(self, drop_mask: np.ndarray) -> None:
  72. """
  73. Receives a boolean drop mask and drops the samples whose mask are true,
  74. From now on the indices of samples are based on the new samples set (after elimination)
  75. """
  76. raise NotImplementedError('If you want to filter samples by mapped labels, you should implement this function in your content loader.')