from typing import TYPE_CHECKING

import numpy as np

from .batch_chooser import BatchChooser

if TYPE_CHECKING:
    # Imported only for the type annotation below; guarded so no runtime
    # import of the data loader module happens here.
    from ..data_loader import DataLoader


class SequentialBatchChooser(BatchChooser):
    """Batch chooser that hands out sample indices in their natural order.

    Indices ``0 .. N-1`` are produced, where ``N`` comes from
    ``data_loader.get_number_of_samples()``. They are returned in
    consecutive slices of at most ``batch_size`` elements (assuming a
    positive ``batch_size``). After the last index has been handed out,
    every further call returns an empty slice until ``reset`` is called.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        super().__init__(data_loader, batch_size)
        # Index (into the cached index array) where the next batch starts.
        self.cursor = 0
        # Cached np.arange over all samples; None until the first batch
        # request builds it.
        self.desired_samples_inds = None

    def reset(self):
        """Rewind to the first sample and discard the cached index array.

        The base-class reset runs first. The index array is rebuilt on the
        next batch request, so a changed sample count is picked up.
        """
        super().reset()
        self.cursor = 0
        self.desired_samples_inds = None

    def get_next_batch_sample_indices(self):
        """Return the indices of the samples chosen for the next batch.

        On the first call after construction or ``reset``, an index array
        ``np.arange(N)`` is built and cached. Each call returns the slice
        starting at the current cursor, clamped to the end of that array,
        and advances the cursor past it.

        Returns:
            A NumPy array slice (a view of the cached index array). It is
            empty once all samples have been handed out.
        """
        # NOTE(review): self.data_loader / self.batch_size are presumably
        # set by BatchChooser.__init__ — not visible here.
        indices = self.desired_samples_inds
        if indices is None:
            sample_count = self.data_loader.get_number_of_samples()
            indices = np.arange(sample_count)
            self.desired_samples_inds = indices

        start = self.cursor
        # Clamp so the final batch may be shorter than batch_size.
        stop = min(len(indices), start + self.batch_size)
        self.cursor = stop
        return indices[start:stop]