
data_loader.py

  1. """ A class for loading the whole data needed for the run! """
  2. from typing import Dict, List, Union, TYPE_CHECKING
  3. import random
  4. import numpy as np
  5. import torch
  6. from .batch_choosing.batch_chooser import BatchChooser
  7. from .content_loaders.content_loader import ContentLoader
  8. if TYPE_CHECKING:
  9. from ..configs.base_config import BaseConfig
  10. class RunType:
  11. TRAIN = 'train'
  12. VAL = 'val'
  13. TEST = 'test'


class DataLoader:
    """ A class for loading all the data needed for the run. """

    def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: str):
        """ conf is the run configuration object,
        data_specification is a string specifying the samples, e.g. the address of the CTs,
        and run_type is one of the strings train/val/test, specifying the mode
        in which the data loader will be used.
        (See the usage sketch at the end of this module.) """
        self.conf = conf
        self._device = conf.device
        self.data_specification = data_specification
        self.run_type = run_type

        self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification)
        self.samples_names: np.ndarray = self._content_loader.get_samples_names()
        self.samples_labels = self._content_loader.get_samples_labels()

        keep_mask = self.get_samples_keep_mask()
        if keep_mask is not None:
            drop_mask = np.logical_not(keep_mask)
            self._content_loader.drop_samples(drop_mask)
            self.samples_names = self.samples_names[keep_mask]
            self.samples_labels = self.samples_labels[keep_mask]

        self.class_samples_indices = dict()
        print('%d samples' % len(self.samples_names), flush=True)
        for i in range(len(self.samples_names)):
            if self.samples_labels[i] not in self.class_samples_indices:
                self.class_samples_indices[self.samples_labels[i]] = []
            self.class_samples_indices[self.samples_labels[i]].append(i)
        for c in self.class_samples_indices.keys():
            self.class_samples_indices[c] = np.asarray(self.class_samples_indices[c])

        # For augmentations: only applied in the training phase.
        self._different_augmentation_per_batch = conf.different_augmentation_per_batch
        self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None
        if run_type == RunType.TRAIN:
            self._augmentations_dict = conf.augmentations_dict

        # For preventing reprocessing.
        self._processed_data_dictionary: Dict[str, torch.Tensor] = dict()

        # For preserving the same augmentations within one batch
        # (keyed by the transformation's type, as assigned in _apply_augmentation).
        self._transformations_seeds_dict: Dict[type, Union[int, np.ndarray]] = dict()

        if run_type == RunType.TRAIN:
            self._batch_chooser: BatchChooser = self.conf.train_batch_chooser_cls(self, conf.batch_size)
        else:
            self._batch_chooser: BatchChooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size)

    def get_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples, based on the first content loader.
        IMPORTANT: these must be identical across all content loaders. """
        return self.samples_names

    def get_samples_labels(self):
        """ Returns the labels of the samples in a numpy array. """
        return self.samples_labels

    def get_number_of_samples(self):
        """ Returns the number of loaded samples. """
        return len(self.samples_names)

    def get_class_sample_indices(self):
        """ Returns a dictionary mapping each class label to the list of indices
        of the samples belonging to it. """
        return self.class_samples_indices

    def prepare_next_batch(self):
        # Reset the per-batch caches.
        self._processed_data_dictionary = dict()
        self._transformations_seeds_dict = dict()
        self._batch_chooser.prepare_next_batch()

    def finished_iteration(self) -> bool:
        return self._batch_chooser.finished_iteration()

    def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None:
        """ Receives a dictionary of placeholders as input and fills them using the content loader. """
        missed_placeholders_names: List[str]
        if len(self._processed_data_dictionary) == 0:
            missed_placeholders_names = list(placeholders_dict.keys())
        else:
            missed_placeholders_names = []
            for k, v in placeholders_dict.items():
                if k in self._processed_data_dictionary:
                    placeholders_dict[k] = self._processed_data_dictionary[k]
                else:
                    missed_placeholders_names.append(k)

        if len(missed_placeholders_names) == 0:
            return

        new_batch_info = self._content_loader.fill_placeholders(
            missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices())

        # Filling all the missed placeholders.
        with torch.no_grad():
            for k in missed_placeholders_names:
                self._fill_placeholder(placeholders_dict[k], new_batch_info[k])
                if self._augmentations_dict is not None and k in self._augmentations_dict:
                    placeholders_dict[k] = self._apply_augmentation(k, placeholders_dict[k])
                # Keeping a copy of the data.
                self._processed_data_dictionary[k] = placeholders_dict[k]
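
    # Example of the intended calling pattern (a sketch with a hypothetical
    # keyword 'x'; the dict is filled in place, so after the call each tensor
    # holds the current batch's data):
    #
    #     placeholders = {'x': create_empty_placeholder(conf.device)}
    #     loader.fill_placeholders(placeholders)
    #     x_batch = placeholders['x']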

    def get_current_batch_data(self, keyword: str) -> torch.Tensor:
        """ Builds a placeholder, fills it through the loaders, and returns it.
        Args:
            keyword (str): The keyword to load for the current batch.
        Returns:
            torch.Tensor: The information related to the keyword for the current batch.
        """
        if keyword in self._processed_data_dictionary:
            return self._processed_data_dictionary[keyword]
        else:
            placeholders_dict = {keyword: create_empty_placeholder(self._device)}
            self.fill_placeholders(placeholders_dict)
            return placeholders_dict[keyword]

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Returns the indices of the samples chosen for the current batch. """
        return self._batch_chooser.get_current_batch_sample_indices()

    def get_max_class_samples_num(self):
        """ Returns the size of the largest class. """
        return max([len(x) for x in self.class_samples_indices.values()])

    def get_classes_num(self):
        """ Returns the number of distinct class labels. """
        return len(self.class_samples_indices)

    def get_samples_keep_mask(self) -> Union[None, np.ndarray]:
        """ Returns a boolean mask marking the samples whose labels are of interest,
        or None if no label filtering is configured. """
        if self.conf.mapped_labels_to_use is None:
            return None

        labels_to_keep = set(self.conf.mapped_labels_to_use)
        keep_mask = np.vectorize(lambda x: x in labels_to_keep)(self.samples_labels)
        print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.')
        return keep_mask

    def get_current_batch_size(self) -> int:
        """ Returns the number of samples in the current batch. """
        return self._batch_chooser.get_current_batch_size()

    def get_current_batch_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples in the current batch. """
        return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_labels(self) -> np.ndarray:
        """ Returns the labels of the samples in the current batch. """
        return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_interpretations(self) -> np.ndarray:
        """ Returns the interpretations of the samples in the current batch. """
        return self.get_current_batch_data('interpretations')

    @staticmethod
    def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray):
        """ Fills the torch placeholder with the given numpy value.
        If the shapes mismatch, resizes the placeholder so the data fits. """
        # Resize if the shape mismatches.
        if list(placeholder.shape) != list(val.shape):
            placeholder.resize_(*tuple(val.shape))
        # Feed the value.
        placeholder.copy_(torch.Tensor(val))
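
    # A minimal sketch of the resize-and-copy pattern used above: the placeholder
    # is allocated once with shape (1,) and then grows in place to fit each batch.
    #
    #     placeholder = torch.zeros(1)
    #     val = np.ones((2, 3), dtype=np.float32)
    #     placeholder.resize_(*val.shape)
    #     placeholder.copy_(torch.Tensor(val))  # placeholder now has shape (2, 3)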

    def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor:
        """ Called only for the variables that are known to need augmentation
        (the ones present in the augmentations dictionary). Applies the specified
        augmentation, makes sure all augmentations are the same on the same batch
        elements, and returns the augmented data. """

        def run_single_aug(aug, inp):
            if type(aug) in self._transformations_seeds_dict:
                seed = self._transformations_seeds_dict[type(aug)]
            else:
                # Make a seed with the numpy generator and remember it, so the
                # same transformation type is re-seeded identically within the batch.
                seed = np.random.randint(2147483647)
                self._transformations_seeds_dict[type(aug)] = seed
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if self._different_augmentation_per_batch:
                ret_val = torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0)
            else:
                ret_val = aug(inp)
            return ret_val

        field_transformation = self._augmentations_dict[var_name]
        if not isinstance(field_transformation, torch.nn.Sequential):
            return run_single_aug(field_transformation, var_val)
        else:
            aug_val = var_val
            for single_transform in field_transformation.children():
                aug_val = run_single_aug(single_transform, aug_val)
            return aug_val
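
# Note on augmentation determinism: re-seeding `random`, NumPy, and PyTorch with
# the same stored seed before each application of a transform makes random
# transforms repeat their choices. A minimal sketch of the idea (assumes a
# transform driven by the torch RNG, e.g. from torchvision, which this module
# does not import):
#
#     img = torch.rand(1, 8, 8)
#     aug = torchvision.transforms.RandomRotation(30)
#     torch.manual_seed(42)
#     out_a = aug(img)
#     torch.manual_seed(42)
#     out_b = aug(img)  # identical to out_a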


def create_empty_placeholder(device: torch.device) -> torch.Tensor:
    """ Creates an empty placeholder of shape (1,) on the given device.
    Args:
        device (torch.device): Device to create the placeholder on.
    Returns:
        torch.Tensor: Empty placeholder.
    """
    return torch.zeros(1, dtype=torch.float32,
                       device=device,
                       requires_grad=False)
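

# A minimal end-to-end usage sketch (hypothetical: `MyConfig` stands in for a
# concrete BaseConfig subclass wired with a content loader and batch choosers;
# only the DataLoader API below comes from this module):
#
#     conf = MyConfig()
#     loader = DataLoader(conf, data_specification='/path/to/data', run_type=RunType.TRAIN)
#     while not loader.finished_iteration():
#         loader.prepare_next_batch()
#         batch = loader.get_current_batch_data('x')
#         labels = loader.get_current_batch_samples_labels()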