You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sequential.py 1.1KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import TYPE_CHECKING
  2. from .batch_chooser import BatchChooser
  3. import numpy as np
  4. if TYPE_CHECKING:
  5. from ..data_loader import DataLoader
  6. class SequentialBatchChooser(BatchChooser):
  7. def __init__(self, data_loader: 'DataLoader', batch_size: int):
  8. super().__init__(data_loader, batch_size)
  9. self.cursor = 0
  10. self.desired_samples_inds = None
  11. def reset(self):
  12. super(SequentialBatchChooser, self).reset()
  13. self.cursor = 0
  14. self.desired_samples_inds = None
  15. def get_next_batch_sample_indices(self):
  16. """ Returns a list containing the indices of the samples chosen for the next batch."""
  17. if self.desired_samples_inds is None:
  18. self.desired_samples_inds = \
  19. np.arange(self.data_loader.get_number_of_samples())
  20. next_cursor = min(
  21. len(self.desired_samples_inds),
  22. self.cursor + self.batch_size)
  23. next_sample_inds = self.desired_samples_inds[
  24. self.cursor:next_cursor]
  25. self.cursor = next_cursor
  26. return next_sample_inds