1234567891011121314151617181920212223242526272829303132333435363738 |
- from typing import TYPE_CHECKING
- from .batch_chooser import BatchChooser
- import numpy as np
- if TYPE_CHECKING:
- from ..data_loader import DataLoader
-
-
class SequentialBatchChooser(BatchChooser):
    """Batch chooser that hands out sample indices in order, one contiguous slice per batch.

    On the first call to ``get_next_batch_sample_indices`` the index order
    defaults to ``np.arange(N)`` where ``N`` comes from
    ``data_loader.get_number_of_samples()``. Each call then returns the next
    ``batch_size`` indices. The final batch may be shorter; once every index
    has been handed out, further calls return an empty slice until
    ``reset()`` is called.
    """

    def __init__(self, data_loader: 'DataLoader', batch_size: int):
        super().__init__(data_loader, batch_size)

        # Position in ``desired_samples_inds`` where the next batch starts.
        self.cursor = 0
        # Ordered sample indices to walk through; built lazily on the first
        # batch request (NOTE(review): presumably a subclass/caller may assign
        # a custom ordering before that first call — confirm).
        self.desired_samples_inds = None

    def reset(self):
        """Reset the base chooser, rewind to the start, and drop the cached index order."""
        # Zero-argument super(), consistent with __init__.
        super().reset()
        self.cursor = 0
        self.desired_samples_inds = None

    def get_next_batch_sample_indices(self):
        """Return the indices of the samples chosen for the next batch.

        Returns:
            A slice of ``desired_samples_inds`` with at most ``batch_size``
            entries. It is shorter for the last partial batch and empty once
            all indices have been consumed (until ``reset()``).
        """
        if self.desired_samples_inds is None:
            self.desired_samples_inds = np.arange(
                self.data_loader.get_number_of_samples())

        start = self.cursor
        # Clamp so the cursor never runs past the end of the index array.
        stop = min(len(self.desired_samples_inds), start + self.batch_size)
        self.cursor = stop
        return self.desired_samples_inds[start:stop]
-
|