You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

batch_chooser.py 1.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. """
  2. The BatchChooser
  3. """
  4. from abc import abstractmethod
  5. from typing import List, TYPE_CHECKING
  6. import numpy as np
  7. if TYPE_CHECKING:
  8. from ..data_loader import DataLoader
  9. class BatchChooser:
  10. """
  11. The BatchChooser
  12. """
  13. def __init__(self, data_loader: 'DataLoader', batch_size: int):
  14. """ Receives as input a data_loader which contains the information about the samples and the desired batch size. """
  15. self.data_loader = data_loader
  16. self.batch_size = batch_size
  17. self.current_batch_sample_indices: np.ndarray = np.asarray([], dtype=int)
  18. self.completed_iteration = False
  19. self.class_samples_indices = self.data_loader.get_class_sample_indices()
  20. self._classes_labels: List[int] = list(self.class_samples_indices.keys())
  21. def prepare_next_batch(self):
  22. """
  23. Prepares the next batch based on its strategy
  24. """
  25. if self.finished_iteration():
  26. self.reset()
  27. self.current_batch_sample_indices = \
  28. self.get_next_batch_sample_indices().astype(int)
  29. if len(self.current_batch_sample_indices) == 0:
  30. self.completed_iteration = True
  31. return
  32. @abstractmethod
  33. def get_next_batch_sample_indices(self) -> np.ndarray:
  34. pass
  35. def get_current_batch_sample_indices(self) -> np.ndarray:
  36. """ Returns a list of indices of the samples chosen for the current batch. """
  37. return self.current_batch_sample_indices
  38. def finished_iteration(self):
  39. """ Returns True if iteration is finished over all the slices of all the samples, False otherwise"""
  40. return self.completed_iteration
  41. def reset(self):
  42. """ Resets the sample iterator. """
  43. self.completed_iteration = False
  44. self.current_batch_sample_indices = np.asarray([], dtype=int)
  45. def get_current_batch_size(self) -> int:
  46. return len(self.current_batch_sample_indices)