""" The BatchChooser """ from abc import abstractmethod from typing import List, TYPE_CHECKING import numpy as np if TYPE_CHECKING: from ..data_loader import DataLoader class BatchChooser: """ The BatchChooser """ def __init__(self, data_loader: 'DataLoader', batch_size: int): """ Receives as input a data_loader which contains the information about the samples and the desired batch size. """ self.data_loader = data_loader self.batch_size = batch_size self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int) self.completed_iteration = False self.class_samples_indices = self.data_loader.get_class_sample_indices() self._classes_labels: List[int] = list(self.class_samples_indices.keys()) def prepare_next_batch(self): """ Prepares the next batch based on its strategy """ if self.finished_iteration(): self.reset() self.current_batch_sample_indices = \ self.get_next_batch_sample_indices().astype(int) if len(self.current_batch_sample_indices) == 0: self.completed_iteration = True return @abstractmethod def get_next_batch_sample_indices(self) -> np.ndarray: pass def get_current_batch_sample_indices(self) -> np.ndarray: """ Returns a list of indices of the samples chosen for the current batch. """ return self.current_batch_sample_indices def finished_iteration(self): """ Returns True if iteration is finished over all the slices of all the samples, False otherwise""" return self.completed_iteration def reset(self): """ Resets the sample iterator. """ self.completed_iteration = False self.current_batch_sample_indices = np.asarray([], dtype=int) def get_current_batch_size(self) -> int: return len(self.current_batch_sample_indices)