12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- """
- The BatchChooser
- """
- from abc import abstractmethod
- from typing import List, TYPE_CHECKING
- import numpy as np
- if TYPE_CHECKING:
- from ..data_loader import DataLoader
-
-
-
class BatchChooser:
    """
    Base class for batch-selection strategies.

    A BatchChooser walks over the samples exposed by a ``DataLoader`` and, on
    each call to ``prepare_next_batch``, selects the indices of the samples
    that make up the next batch. Concrete strategies implement
    ``get_next_batch_sample_indices``; this base class keeps track of the
    current batch and of whether the iteration over the data is exhausted.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker is not enforced at instantiation time; the base
    implementation raises ``NotImplementedError`` when called instead.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        """
        Store the data loader and batch size, and cache the class labels.

        Args:
            data_loader: source of the samples; must provide
                ``get_class_sample_indices()`` returning a mapping whose keys
                are the class labels.
            batch_size: desired number of samples per batch. It is stored for
                use by subclasses; this base class does not enforce it.
        """
        self.data_loader = data_loader
        self.batch_size = batch_size
        # Indices chosen for the most recently prepared batch (empty before the
        # first batch, after a reset, and once the iteration is exhausted).
        self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int)
        # Becomes True when the strategy yields an empty batch.
        self.completed_iteration = False

        # presumably {class_label: indices of that class's samples} — only the
        # keys are used here; verify the value type against DataLoader.
        self.class_samples_indices = self.data_loader.get_class_sample_indices()
        self._classes_labels: List[int] = list(self.class_samples_indices.keys())

    def prepare_next_batch(self) -> None:
        """
        Select the samples for the next batch using the subclass strategy.

        If the previous call exhausted the iteration, the chooser is reset
        first so that a new pass starts. When the strategy returns no indices,
        the current batch is left empty and the iteration is marked completed.
        """
        if self.finished_iteration():
            self.reset()

        # np.asarray lets strategies return any array-like (list, tuple or
        # ndarray); astype(int) always produces a fresh integer index array.
        next_indices = np.asarray(self.get_next_batch_sample_indices())
        self.current_batch_sample_indices = next_indices.astype(int)

        if len(self.current_batch_sample_indices) == 0:
            self.completed_iteration = True

    @abstractmethod
    def get_next_batch_sample_indices(self) -> np.ndarray:
        """
        Return the indices of the samples that form the next batch.

        Subclasses must override this. Returning an empty sequence signals
        that the current iteration is finished.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_next_batch_sample_indices()"
        )

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Return the integer indices of the samples chosen for the current batch. """
        return self.current_batch_sample_indices

    def finished_iteration(self) -> bool:
        """ Return True once the strategy has produced an empty batch, False otherwise. """
        return self.completed_iteration

    def reset(self) -> None:
        """
        Clear the completion flag and the current batch.

        Subclasses that keep their own iteration state should override this
        and call ``super().reset()``.
        """
        self.completed_iteration = False
        self.current_batch_sample_indices = np.asarray([], dtype=int)

    def get_current_batch_size(self) -> int:
        """ Return the number of samples in the current batch. """
        return len(self.current_batch_sample_indices)
|