123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
""" Defines the DataLoader class, which loads all of the data needed for a run. """
- from typing import Dict, List, Union, TYPE_CHECKING
- import random
- import numpy as np
- import torch
-
- from .batch_choosing.batch_chooser import BatchChooser
- from .content_loaders.content_loader import ContentLoader
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
-
-
class RunType:
    """ Namespace of the string tags that name the mode a data loader is used in. """

    # One tag per mode; the values are plain strings and are compared with `==`.
    TRAIN, VAL, TEST = 'train', 'val', 'test'
-
-
class DataLoader:
    """ A class for loading the whole data needed for the run!

    It wraps one content loader (which produces the raw values) and one batch
    chooser (which decides which sample indices form the current batch), caches
    everything already loaded for the current batch and, in the training mode,
    applies the configured augmentations to the loaded tensors. """

    def __init__(self, conf: 'BaseConfig', data_specification: str, run_type: str):
        """ Builds the content loader, filters the samples and builds the batch chooser.

        Args:
            conf: Configuration object. The attributes read here are `device`,
                `content_loader_cls`, `mapped_labels_to_use`,
                `different_augmentation_per_batch`, `augmentations_dict` (train only),
                `train_batch_chooser_cls` / `eval_batch_chooser_cls` and `batch_size`.
            data_specification: A string specifying the samples (e.g. address of the CTs);
                it is handed to the content loader as is.
            run_type: One of the `RunType` strings (train/val/test), specifying the mode
                the data loader will be used in.
        """

        self.conf = conf
        self._device = conf.device
        self.data_specification = data_specification
        self.run_type = run_type

        self._content_loader: ContentLoader = self.conf.content_loader_cls(self.conf, data_specification)

        self.samples_names: np.ndarray = self._content_loader.get_samples_names()
        self.samples_labels = self._content_loader.get_samples_labels()

        # Dropping the samples whose labels are not of interest (if such a filter is configured);
        # the content loader is told about the drop so it stays aligned with names/labels.
        keep_mask = self.get_samples_keep_mask()
        if keep_mask is not None:
            self._content_loader.drop_samples(np.logical_not(keep_mask))
            self.samples_names = self.samples_names[keep_mask]
            self.samples_labels = self.samples_labels[keep_mask]

        print('%d samples' % len(self.samples_names), flush=True)

        # Grouping the sample indices by label: label -> numpy array of indices.
        indices_per_class: Dict = {}
        for i in range(len(self.samples_names)):
            indices_per_class.setdefault(self.samples_labels[i], []).append(i)
        self.class_samples_indices = {
            label: np.asarray(indices) for label, indices in indices_per_class.items()}

        # for augmentations
        # Only do augmentations in the training phase
        self._different_augmentation_per_batch = conf.different_augmentation_per_batch
        self._augmentations_dict: Union[None, Dict[str, torch.nn.Module]] = (
            conf.augmentations_dict if run_type == RunType.TRAIN else None)

        # for preventing reprocessing
        self._processed_data_dictionary: Dict[str, torch.Tensor] = {}

        # for preserving same augmentations in one batch
        # (keyed by the augmentation's class, see _run_single_augmentation)
        self._transformations_seeds_dict: Dict[type, int] = {}

        chooser_cls = (self.conf.train_batch_chooser_cls if run_type == RunType.TRAIN
                       else self.conf.eval_batch_chooser_cls)
        self._batch_chooser: BatchChooser = chooser_cls(self, conf.batch_size)

    def get_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples, as reported by the content loader
        (after dropping the samples filtered out by the labels of interest). """
        return self.samples_names

    def get_samples_labels(self) -> np.ndarray:
        """ Returns the labels of the samples in a numpy array. """
        return self.samples_labels

    def get_number_of_samples(self) -> int:
        """ Returns the number of samples loaded. """
        return len(self.samples_names)

    def get_class_sample_indices(self) -> Dict:
        """ Returns a dictionary, mapping each class label to a numpy array
        of the indices of the samples belonging to it."""
        return self.class_samples_indices

    def prepare_next_batch(self) -> None:
        """ Forgets everything cached for the previous batch (loaded data and
        augmentation seeds) and asks the batch chooser to move to the next batch. """
        # Resetting information
        self._processed_data_dictionary = {}
        self._transformations_seeds_dict = {}

        self._batch_chooser.prepare_next_batch()

    def finished_iteration(self) -> bool:
        """ Returns whatever the batch chooser reports about the iteration being finished. """
        return self._batch_chooser.finished_iteration()

    def fill_placeholders(self, placeholders_dict: Dict[str, torch.Tensor]) -> None:
        """ Receives as input a dictionary of placeholders and fills them for the current batch.

        Keys already loaded for the current batch are served from the cache: the entry
        in `placeholders_dict` is replaced by the cached tensor. The remaining keys are
        requested from the content loader, copied into the given placeholders,
        augmented if an augmentation is registered for the key, and then cached.

        Args:
            placeholders_dict: Maps each required keyword to the tensor to fill.
                The dictionary is modified in place.
        """

        # Serving whatever is already processed and collecting the rest
        missed_placeholders_names: List[str] = []
        for name in placeholders_dict:
            if name in self._processed_data_dictionary:
                placeholders_dict[name] = self._processed_data_dictionary[name]
            else:
                missed_placeholders_names.append(name)

        if not missed_placeholders_names:
            return

        new_batch_info = self._content_loader.fill_placeholders(
            missed_placeholders_names, self._batch_chooser.get_current_batch_sample_indices())

        # filling all missed keys!
        with torch.no_grad():
            for name in missed_placeholders_names:

                self._fill_placeholder(placeholders_dict[name], new_batch_info[name])

                if self._augmentations_dict is not None and name in self._augmentations_dict:
                    placeholders_dict[name] = self._apply_augmentation(name, placeholders_dict[name])

                # Keeping a reference to the data, so it is not processed again in this batch
                self._processed_data_dictionary[name] = placeholders_dict[name]

    def get_current_batch_data(self, keyword: str) -> torch.Tensor:
        """ Builds placeholder and retrieves the requirement from loaders and returns it!

        Args:
            keyword (str): The keyword to load for the current batch.

        Returns:
            torch.Tensor: The information related to the keyword for the current batch.
        """
        if keyword in self._processed_data_dictionary:
            return self._processed_data_dictionary[keyword]

        placeholders_dict = {keyword: create_empty_placeholder(self._device)}
        self.fill_placeholders(placeholders_dict)
        return placeholders_dict[keyword]

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Returns the indices of the samples chosen for the current batch. """
        return self._batch_chooser.get_current_batch_sample_indices()

    def get_max_class_samples_num(self) -> int:
        """ Returns the number of samples of the most populated class (0 if there is no class). """
        return max((len(indices) for indices in self.class_samples_indices.values()), default=0)

    def get_classes_num(self) -> int:
        """ Returns the number of distinct labels among the loaded samples. """
        return len(self.class_samples_indices)

    def get_samples_keep_mask(self) -> Union[None, np.ndarray]:
        """ Returns a boolean mask over the samples, True for the samples whose label is in
        `conf.mapped_labels_to_use`, or None if no such filter is configured. """
        if self.conf.mapped_labels_to_use is None:
            return None

        labels_of_interest = set(self.conf.mapped_labels_to_use)

        # otypes is given explicitly so an empty labels array yields an empty boolean mask;
        # without it np.vectorize cannot infer the output type and raises on size-0 inputs.
        keep_mask = np.vectorize(
            lambda label: label in labels_of_interest, otypes=[bool])(self.samples_labels)

        print(f'Kept {np.sum(keep_mask)} out of {len(keep_mask)} samples after considering the labels of interest.')

        return keep_mask

    def get_current_batch_size(self) -> int:
        """ Returns the size of the current batch, as reported by the batch chooser. """
        return self._batch_chooser.get_current_batch_size()

    def get_current_batch_samples_names(self) -> np.ndarray:
        """ Returns the names of the samples chosen for the current batch. """
        return self.samples_names[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_labels(self) -> np.ndarray:
        """ Returns the labels of the samples chosen for the current batch. """
        return self.samples_labels[self._batch_chooser.get_current_batch_sample_indices()]

    def get_current_batch_samples_interpretations(self) -> torch.Tensor:
        """ Returns the data loaded under the 'interpretations' keyword for the current batch. """
        return self.get_current_batch_data('interpretations')

    @staticmethod
    def _fill_placeholder(placeholder: torch.Tensor, val: np.ndarray):
        """ Fills the torch placeholder with the given numpy value,
        If shape mismatches, resizes the placeholder so the data would fit. """

        # Resize (in place) if the shape mismatches
        if tuple(placeholder.shape) != tuple(val.shape):
            placeholder.resize_(*val.shape)

        # feeding the value (in place, so the caller's tensor object is the one filled)
        placeholder.copy_(torch.Tensor(val))

    def _run_single_augmentation(self, aug: torch.nn.Module, inp: torch.Tensor) -> torch.Tensor:
        """ Applies one augmentation module to the input, under a seed shared (within
        the current batch) by all the augmentations of the same class. """

        aug_cls = type(aug)
        if aug_cls in self._transformations_seeds_dict:
            seed = self._transformations_seeds_dict[aug_cls]
        else:
            # setting a seed
            seed = np.random.randint(2147483647)  # make a seed with numpy generator
            self._transformations_seeds_dict[aug_cls] = seed

        # NOTE: these calls reseed the process-wide `random`, numpy and torch generators,
        # so any other code drawing from them afterwards continues from this seed.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        if self._different_augmentation_per_batch:
            # The augmentation is called once per element along dim 0, results are stacked back
            return torch.stack([aug(inp[bi]) for bi in range(inp.shape[0])], dim=0)

        # The augmentation is called once on the whole batch
        return aug(inp)

    def _apply_augmentation(self, var_name: str, var_val: torch.Tensor) -> torch.Tensor:
        """ This function would be called for the variables we are sure they need augmentation
        (they are presented in augmentation dictionary), applies the specified augmentation,
        makes sure the same seed is used for the same kind of augmentation within the batch,
        and returns the augmented data.

        A `torch.nn.Sequential` is unrolled and its children are applied one after another
        (each one seeded separately); any other module is applied as a single augmentation. """

        field_transformation = self._augmentations_dict[var_name]

        if isinstance(field_transformation, torch.nn.Sequential):
            transforms = list(field_transformation.children())
        else:
            transforms = [field_transformation]

        aug_val = var_val
        for single_transform in transforms:
            aug_val = self._run_single_augmentation(single_transform, aug_val)

        return aug_val
-
-
def create_empty_placeholder(device: torch.device) -> torch.Tensor:
    """ Builds a one-element, zero-valued float32 tensor to be used as an empty placeholder.

    Args:
        device (torch.device): Device to create the placeholder on.

    Returns:
        torch.Tensor: The placeholder, with shape (1) and not tracking gradients.
    """
    placeholder_options = dict(dtype=torch.float32, device=device, requires_grad=False)
    placeholder = torch.zeros(1, **placeholder_options)
    return placeholder
|