""" A class for loading the whole data needed for the run! """

from typing import Dict, List, Union, TYPE_CHECKING
import random

import numpy as np
import torch

from .batch_choosing.batch_chooser import BatchChooser
from .content_loaders.content_loader import ContentLoader

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class RunType:
    """String constants naming the phase a DataLoader is used for."""
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'


class DataLoader():
    """ A class for loading the whole data needed for the run! """

    def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: str):
        """Loads the sample names/labels and prepares augmentation and batching state.

        Args:
            conf: The run configuration object. It provides the device, the content loader
                and batch chooser classes, the batch size and the augmentation settings.
            data_specification: A string specifying the samples, e.g. the address of the CTs.
            run_type: One of RunType.TRAIN / RunType.VAL / RunType.TEST. It selects the
                batch chooser. Augmentations are only enabled for RunType.TRAIN.
        """
        self.conf = conf
        self._device = conf.device
        self.data_specification = data_specification
        self.run_type = run_type

        self._content_loader: ContentLoader = \
            self.conf.content_loader_cls(self.conf, data_specification)
        self.samples_names: np.ndarray = self._content_loader.get_samples_names()
        self.samples_labels = self._content_loader.get_samples_labels()

        # Optionally drop the samples whose labels are not among the labels of interest,
        # keeping the content loader in sync with the names/labels held here.
        keep_mask = self.get_samples_keep_mask()
        if keep_mask is not None:
            self._content_loader.drop_samples(np.logical_not(keep_mask))
            self.samples_names = self.samples_names[keep_mask]
            self.samples_labels = self.samples_labels[keep_mask]

        print('%d samples' % len(self.samples_names), flush=True)

        # Group sample indices by label; labels appear in first-seen order.
        grouped_indices = {}
        for sample_index, label in enumerate(self.samples_labels):
            grouped_indices.setdefault(label, []).append(sample_index)
        self.class_samples_indices = {
            label: np.asarray(indices) for label, indices in grouped_indices.items()}

        # For augmentations: they are only applied in the training phase.
        self._different_augmentation_per_batch = conf.different_augmentation_per_batch
        self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = None
        if run_type == RunType.TRAIN:
            self._augmentations_dict = conf.augmentations_dict

        # Per-batch cache of already filled (and augmented) data, to prevent reprocessing.
        self._processed_data_dictionary: Dict[str, torch.Tensor] = dict()
        # Seed per augmentation *type*, so every field of one batch gets the same
        # random augmentation parameters.
        self._transformations_seeds_dict: Dict[type, int] = dict()

        # Created last because the batch chooser receives this (fully initialised) loader.
        self._batch_chooser: BatchChooser
        if run_type == RunType.TRAIN:
            self._batch_chooser = self.conf.train_batch_chooser_cls(self, conf.batch_size)
        else:
            self._batch_chooser = self.conf.eval_batch_chooser_cls(self, conf.batch_size)

    def get_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples based on the first content loader.
        IMPORTANT: These all must be the same for all of the content loaders."""
        return self.samples_names

    def get_samples_labels(self):
        """ Returns the labels of the samples in a numpy array. """
        return self.samples_labels

    def get_number_of_samples(self):
        """ Returns the number of samples loaded. """
        return len(self.samples_names)

    def get_class_sample_indices(self):
        """ Returns a dictionary, containing lists of samples indices belonging to each class label."""
        return self.class_samples_indices

    def prepare_next_batch(self):
        """Clears the per-batch caches and asks the batch chooser to pick the next batch."""
        self._processed_data_dictionary = dict()
        self._transformations_seeds_dict = dict()
        self._batch_chooser.prepare_next_batch()

    def finished_iteration(self) -> bool:
        """Returns whether the batch chooser has finished the current iteration."""
        return self._batch_chooser.finished_iteration()

    def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None:
        """ Receives as input a dictionary of placeholders and fills them using the content loader.

        Keys that were already processed for the current batch are served from the cache.
        The remaining keys are loaded, copied into their placeholders, augmented when an
        augmentation is configured for them, and cached.

        NOTE: an entry of ``placeholders_dict`` may be *replaced* by another tensor (a
        cached one, or the augmentation output), so callers must read results back from
        the dictionary rather than keep a reference to the tensor they passed in.
        """
        missed_placeholders_names: List[str] = []
        for name in placeholders_dict:
            if name in self._processed_data_dictionary:
                # Re-assigning an existing key does not resize the dict, so this is safe.
                placeholders_dict[name] = self._processed_data_dictionary[name]
            else:
                missed_placeholders_names.append(name)

        if not missed_placeholders_names:
            return

        new_batch_info = self._content_loader.fill_placeholders(
            missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices())

        with torch.no_grad():
            for name in missed_placeholders_names:
                self._fill_placeholder(placeholders_dict[name], new_batch_info[name])
                if self._augmentations_dict is not None and name in self._augmentations_dict:
                    placeholders_dict[name] = self._apply_augmentation(name, placeholders_dict[name])
                # Keeping a copy of the data for later requests within this batch.
                self._processed_data_dictionary[name] = placeholders_dict[name]

    def get_current_batch_data(self, keyword: str) -> torch.Tensor:
        """ Builds a placeholder, retrieves the requested data from the loader and returns it.

        Args:
            keyword (str): The keyword to load for the current batch.

        Returns:
            torch.Tensor: The information related to the keyword for the current batch.
        """
        if keyword in self._processed_data_dictionary:
            return self._processed_data_dictionary[keyword]
        placeholders_dict = {keyword: create_empty_placeholder(self._device)}
        self.fill_placeholders(placeholders_dict)
        return placeholders_dict[keyword]

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Returns a list of indices of the samples chosen for the current batch. """
        return self._batch_chooser.get_current_batch_sample_indices()

    def get_max_class_samples_num(self):
        """Returns the size of the largest class (0 when no samples are loaded)."""
        return max((len(indices) for indices in self.class_samples_indices.values()), default=0)

    def get_classes_num(self):
        """Returns the number of distinct labels among the loaded samples."""
        return len(self.class_samples_indices)

    def get_samples_keep_mask(self) -> Union[None, np.ndarray]:
        """Returns a boolean mask marking the samples whose label is in
        ``conf.mapped_labels_to_use``, or None when that setting is None (keep all)."""
        if self.conf.mapped_labels_to_use is None:
            return None
        labels_of_interest = set(self.conf.mapped_labels_to_use)
        # A plain comprehension (unlike np.vectorize without otypes) also handles zero samples.
        keep_mask = np.array(
            [label in labels_of_interest for label in self.samples_labels], dtype=bool)
        print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.')
        return keep_mask

    def get_current_batch_size(self) -> int:
        """Returns the number of samples in the current batch, as reported by the batch chooser."""
        return self._batch_chooser.get_current_batch_size()

    def get_current_batch_samples_names(self) -> np.ndarray:
        """Returns the names of the samples in the current batch."""
        return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_labels(self) -> np.ndarray:
        """Returns the labels of the samples in the current batch."""
        return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_interpretations(self) -> torch.Tensor:
        """Returns the 'interpretations' data of the current batch."""
        return self.get_current_batch_data('interpretations')

    @staticmethod
    def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray):
        """ Fills the torch placeholder in place with the given value (a numpy array or any
        array-like). If the shape mismatches, the placeholder is resized so the data fits.
        The value is cast to the placeholder's dtype and moved to its device by ``copy_``."""
        # as_tensor keeps the source dtype (no forced float32 round-trip) and also accepts
        # numpy scalars, which the legacy torch.Tensor constructor rejects.
        source = torch.as_tensor(val)
        if placeholder.shape != source.shape:
            # Passing the shape as one argument also works for 0-d values.
            placeholder.resize_(source.shape)
        placeholder.copy_(source)

    def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor:
        """ Applies the augmentation configured for ``var_name`` and returns the augmented data.

        Only called for variables that are present in the augmentation dictionary. A
        ``torch.nn.Sequential`` is applied child by child. All augmentations of the same
        type within one batch are run with the same seed, so every field of the batch
        receives the same random augmentation.

        NOTE: this re-seeds the global ``random``, ``numpy`` and ``torch`` generators.
        """

        def run_single_aug(aug, inp):
            aug_type = type(aug)
            seed = self._transformations_seeds_dict.get(aug_type)
            if seed is None:
                # First use of this augmentation type in the batch: draw a fresh seed.
                seed = np.random.randint(2147483647)
                self._transformations_seeds_dict[aug_type] = seed
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if self._different_augmentation_per_batch:
                # Apply to each batch element separately, then re-stack along dim 0.
                return torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0)
            return aug(inp)

        field_transformation = self._augmentations_dict[var_name]
        if not isinstance(field_transformation, torch.nn.Sequential):
            return run_single_aug(field_transformation, var_val)

        aug_val = var_val
        for single_transform in field_transformation.children():
            aug_val = run_single_aug(single_transform, aug_val)
        return aug_val


def create_empty_placeholder(device: torch.device) -> torch.Tensor:
    """ Create empty placeholder with shape (1) in the given device.

    Args:
        device (torch.device): Device to create the placeholder on.

    Returns:
        torch.Tensor: Empty placeholder.
    """
    return torch.zeros(1, dtype=torch.float32, device=device, requires_grad=False)