from typing import TYPE_CHECKING

import numpy as np

from .batch_chooser import BatchChooser

if TYPE_CHECKING:
    from ..data_loader import DataLoader


class ClassBalancedShuffledSequentialBatchChooser(BatchChooser):
    """Batch chooser that draws a (near-)equal share of every batch from each class.

    Each class keeps its own shuffled array of sample indices and a cursor
    into it. Batches are filled by walking every class's array sequentially;
    once a class's array is exhausted it is reshuffled and the walk restarts
    from the beginning.

    The attributes ``class_samples_indices`` (label -> array of sample
    indices), ``_classes_labels`` (ordered labels) and ``batch_size`` are
    read here but not defined here; they are presumably set by
    ``BatchChooser.__init__`` -- confirm against the base class.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        """Shuffle a private copy of every class's indices and reset cursors.

        Args:
            data_loader: Forwarded unchanged to the base-class constructor.
            batch_size: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(data_loader, batch_size)

        # Position (in class order) of the class that receives the next
        # "extra" sample when batch_size is not divisible by the class count.
        self._cursor = 0

        # Replace every index array by a shuffled copy so that the in-place
        # shuffles done by this chooser never touch the arrays it was given.
        # NOTE(review): only the arrays are copied; the dict object itself is
        # whatever the base class installed -- confirm it is not shared.
        for label, indices in self.class_samples_indices.items():
            shuffled = np.copy(indices)
            np.random.shuffle(shuffled)
            self.class_samples_indices[label] = shuffled

        # One read position per class, indexed in the order of _classes_labels.
        self.classes_cursors = np.zeros(
            len(self.class_samples_indices), dtype=int)

    def get_next_batch_sample_indices(self):
        """Return the sample indices chosen for the next batch.

        Returns:
            A numpy array holding the concatenation, in class order, of the
            indices drawn from each class. The per-class counts come from
            ``_get_n_samples_per_class``.
        """
        n_samples_per_class = self._get_n_samples_per_class()

        return np.concatenate(
            [
                self._draw_class_samples(ci, n_samples_per_class[ci])
                for ci in range(len(self._classes_labels))
            ],
            axis=0,
        )

    def _draw_class_samples(self, class_index, n_needed):
        """Take the next ``n_needed`` indices of one class, reshuffling on wrap.

        Args:
            class_index: Position of the class in ``_classes_labels``.
            n_needed: Number of indices to draw for this class.

        Returns:
            A new numpy array with ``n_needed`` indices, or an empty array
            when the class has no samples at all. When the draw wraps past
            the end of the class, the array is reshuffled before continuing,
            so an index may appear more than once in the result.
        """
        class_samples = self.class_samples_indices[
            self._classes_labels[class_index]]
        n_available = len(class_samples)

        # An empty class can never satisfy a request; contribute nothing
        # instead of looping forever.
        if n_available == 0:
            return np.copy(class_samples)

        cursor = int(self.classes_cursors[class_index])
        n_missing = int(n_needed)
        chunks = []

        # Keep drawing until the quota is met. A class smaller than its quota
        # is walked (and reshuffled) several times rather than yielding a
        # short batch.
        while n_missing > 0:
            if cursor >= n_available:
                np.random.shuffle(class_samples)
                cursor = 0
            taken = np.copy(class_samples[cursor:cursor + n_missing])
            chunks.append(taken)
            cursor += len(taken)
            n_missing -= len(taken)

        # Class used up exactly: reshuffle now so the next batch starts on a
        # fresh permutation.
        if cursor == n_available:
            np.random.shuffle(class_samples)
            cursor = 0

        self.classes_cursors[class_index] = cursor

        if not chunks:
            # Nothing was requested; return an empty array of matching dtype.
            return np.copy(class_samples[:0])
        return np.concatenate(chunks, axis=0)

    def _get_n_samples_per_class(self):
        """Split ``batch_size`` across the classes as evenly as possible.

        Every class gets ``batch_size // n_classes`` samples. The remainder is
        handed out one sample each to consecutive classes starting at
        ``self._cursor`` and wrapping around, and ``self._cursor`` is advanced
        so the next call continues where this one stopped.

        Returns:
            An integer numpy array with one count per class; its sum equals
            ``batch_size``.
        """
        # The two lengths below are assumed to be equal (one label per entry
        # of class_samples_indices) -- the base class is not visible here.
        n_classes = len(self.class_samples_indices)
        n_labels = len(self._classes_labels)

        base_count = int(self.batch_size // n_classes)
        n_samples_per_class = np.full(n_classes, base_count, dtype=int)

        remaining_samples_cnt = int(self.batch_size - base_count * n_classes)
        if remaining_samples_cnt:
            # The remainder is smaller than the class count, so the rotated
            # positions are distinct and each gets exactly one extra sample.
            extra_positions = (
                self._cursor + np.arange(remaining_samples_cnt)) % n_labels
            n_samples_per_class[extra_positions] += 1
            self._cursor = int(
                (self._cursor + remaining_samples_cnt) % n_labels)

        return n_samples_per_class