"""
The BatchChooser: base class for strategies that choose the sample indices of each batch.
"""
from abc import abstractmethod
from typing import List, TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
    from ..data_loader import DataLoader
class BatchChooser:
    """
    Base class for strategies that choose which samples form each batch.

    Subclasses implement :meth:`get_next_batch_sample_indices`. This class
    stores the indices chosen for the current batch and a flag telling whether
    the current iteration has been completed; an empty batch returned by the
    strategy marks the iteration as completed.

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod``
    is not enforced at instantiation time. ``prepare_next_batch`` therefore
    checks explicitly for a strategy that returned nothing.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        """
        Store the data loader and the desired batch size.

        :param data_loader: object providing ``get_class_sample_indices()``;
            its result is kept in ``class_samples_indices`` and its keys are
            kept as the list of class labels.
        :param batch_size: the desired batch size. It is only stored here;
            this base class does not use it itself.
        """
        self.data_loader = data_loader
        self.batch_size = batch_size
        self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int)
        self.completed_iteration = False
        self.class_samples_indices = self.data_loader.get_class_sample_indices()
        self._classes_labels: List[int] = list(self.class_samples_indices.keys())

    def prepare_next_batch(self):
        """
        Prepare the next batch based on the chooser's strategy.

        If the previous iteration was completed, the chooser is reset first.
        The indices returned by the strategy are converted to an integer
        array; an empty result marks the iteration as completed.

        :raises NotImplementedError: if ``get_next_batch_sample_indices``
            returned ``None``, i.e. the strategy has not been implemented.
        """
        if self.finished_iteration():
            self.reset()

        next_indices = self.get_next_batch_sample_indices()
        if next_indices is None:
            # Fail with an explicit message instead of an AttributeError on
            # ``None.astype`` when the strategy is missing.
            raise NotImplementedError(
                f"{type(self).__name__}.get_next_batch_sample_indices() returned None; "
                "subclasses must override it and return the indices of the next batch."
            )

        # np.asarray lets a strategy return a plain sequence as well as an
        # array; astype(int) still produces a fresh integer array.
        self.current_batch_sample_indices = np.asarray(next_indices).astype(int)

        if len(self.current_batch_sample_indices) == 0:
            self.completed_iteration = True

    @abstractmethod
    def get_next_batch_sample_indices(self) -> np.ndarray:
        """
        Return the indices of the samples for the next batch.

        To be overridden by subclasses. Returning an empty array signals that
        the iteration is finished. This base implementation returns ``None``.
        """

    def get_current_batch_sample_indices(self) -> np.ndarray:
        """ Returns the array of indices of the samples chosen for the current batch. """
        return self.current_batch_sample_indices

    def finished_iteration(self):
        """ Returns True if the current iteration has been marked as completed, False otherwise. """
        return self.completed_iteration

    def reset(self):
        """ Resets the chooser: clears the completed flag and empties the current batch. """
        self.completed_iteration = False
        self.current_batch_sample_indices = np.asarray([], dtype=int)

    def get_current_batch_size(self) -> int:
        """ Returns the number of samples in the current batch. """
        return len(self.current_batch_sample_indices)