1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- from typing import TYPE_CHECKING
- from .batch_chooser import BatchChooser
- import numpy as np
- if TYPE_CHECKING:
- from ..data_loader import DataLoader
-
-
class ClassBalancedShuffledSequentialBatchChooser(BatchChooser):
    """Batch chooser that draws (almost) the same number of samples per class.

    Each class keeps its own shuffled array of sample indices and a cursor
    into it.  Batches are built by reading each class's array sequentially
    from its cursor; whenever a class's array is exhausted it is reshuffled
    in place and reading restarts from the beginning.

    When ``batch_size`` is not divisible by the number of classes, the
    left-over slots are handed out one per class in round-robin order, and
    the round-robin position (``self._cursor``) is remembered between calls.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        super().__init__(data_loader, batch_size)

        # Round-robin position used to distribute the batch_size remainder
        # among the classes (see _get_n_samples_per_class).
        self._cursor = 0

        # Replace every per-class index array with a private shuffled copy so
        # that the in-place reshuffles done while sampling never touch the
        # arrays that were originally stored.
        # NOTE(review): class_samples_indices is presumably populated by the
        # base-class constructor -- confirm in BatchChooser.
        for k, v in self.class_samples_indices.items():
            self.class_samples_indices[k] = np.copy(v)
            np.random.shuffle(self.class_samples_indices[k])

        # One read position per class, all starting at the beginning.
        self.classes_cursors = np.zeros(
            len(self.class_samples_indices), dtype=int)

    def get_next_batch_sample_indices(self):
        """Return an array with the sample indices chosen for the next batch.

        The result is the concatenation, in ``self._classes_labels`` order,
        of the indices drawn for each class.  Per-class cursors (and the
        shuffled order of any class whose array was exhausted) are updated
        as a side effect.
        """
        n_samples_per_class = self._get_n_samples_per_class()

        chosen_sample_indices = np.concatenate(tuple(
            self._take_from_class(class_index, n_samples_per_class[class_index])
            for class_index in range(len(self._classes_labels))
        ), axis=0)

        return chosen_sample_indices

    def _take_from_class(self, class_index, n_needed):
        """Draw ``n_needed`` indices of one class, reshuffling on exhaustion.

        Unlike a single wrap-around, this keeps drawing across as many
        reshuffled passes as required, so a class with fewer samples than
        ``n_needed`` still fills its whole quota (with repeats) instead of
        producing a short batch and leaving the cursor past the end.
        """
        class_samples = self.class_samples_indices[
            self._classes_labels[class_index]]
        n_available = len(class_samples)
        n_needed = int(n_needed)

        # Nothing to draw, or nothing to draw from: return an empty slice
        # (same dtype as the class array) and leave the cursor untouched.
        if n_needed <= 0 or n_available == 0:
            return np.copy(class_samples[:0])

        cursor = int(self.classes_cursors[class_index])
        parts = []

        while n_needed > 0:
            stop = min(cursor + n_needed, n_available)
            # Copy before any reshuffle: a plain slice would be a view that
            # the in-place shuffle below would scramble.
            parts.append(np.copy(class_samples[cursor:stop]))
            n_needed -= stop - cursor
            cursor = stop

            if cursor == n_available:
                # The whole class has been consumed; start a new pass.
                np.random.shuffle(class_samples)
                cursor = 0

        self.classes_cursors[class_index] = cursor
        return np.concatenate(parts, axis=0)

    def _get_n_samples_per_class(self):
        """Return an int array with the number of samples to draw per class.

        Every class gets ``batch_size // n_classes`` samples; the remaining
        ``batch_size % n_classes`` slots go to consecutive classes starting
        at ``self._cursor`` (wrapping around), and ``self._cursor`` is moved
        past the last class that received an extra slot.
        """
        n_classes = len(self.class_samples_indices)
        n_samples_per_class = np.full(
            n_classes, int(self.batch_size // n_classes))

        # Plain int so it can be used directly as a slice bound.
        remaining_samples_cnt = int(
            self.batch_size - np.sum(n_samples_per_class))

        if remaining_samples_cnt:
            n_labels = len(self._classes_labels)

            # The remainder is smaller than the number of classes, so the
            # round-robin window wraps around at most once.
            if self._cursor + remaining_samples_cnt >= n_labels:
                n_samples_per_class[self._cursor:n_labels] += 1
                remaining_samples_cnt -= n_labels - self._cursor
                self._cursor = 0

            if remaining_samples_cnt > 0:
                n_samples_per_class[
                    self._cursor:self._cursor + remaining_samples_cnt] += 1
                self._cursor += remaining_samples_cnt

        return n_samples_per_class
-
|