You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

class_balanced_shuffled_sequential.py 3.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import TYPE_CHECKING
  2. from .batch_chooser import BatchChooser
  3. import numpy as np
  4. if TYPE_CHECKING:
  5. from ..data_loader import DataLoader
  6. class ClassBalancedShuffledSequentialBatchChooser(BatchChooser):
  7. def __init__(self, data_loader: 'DataLoader', batch_size: int):
  8. super().__init__(data_loader, batch_size)
  9. self._cursor = 0
  10. # just copying to prevent any dependency
  11. for k, v in self.class_samples_indices.items():
  12. self.class_samples_indices[k] = np.copy(v)
  13. np.random.shuffle(self.class_samples_indices[k])
  14. self.classes_cursors = np.zeros(len(self.class_samples_indices), dtype=int)
  15. def get_next_batch_sample_indices(self):
  16. """ Returns a list containing the indices of the samples chosen for the next batch."""
  17. n_samples_per_class = self._get_n_samples_per_class()
  18. def sample_for_each_class(class_index):
  19. class_samples = self.class_samples_indices[
  20. self._classes_labels[class_index]]
  21. if self.classes_cursors[class_index] + \
  22. n_samples_per_class[class_index] <= \
  23. len(class_samples):
  24. ret_val = np.copy(class_samples[
  25. self.classes_cursors[class_index]:
  26. self.classes_cursors[class_index] + n_samples_per_class[class_index]
  27. ])
  28. self.classes_cursors[class_index] += n_samples_per_class[class_index]
  29. else:
  30. ret_val = np.copy(class_samples)[
  31. self.classes_cursors[class_index]:]
  32. np.random.shuffle(class_samples)
  33. self.classes_cursors[class_index] = \
  34. n_samples_per_class[class_index] - len(ret_val)
  35. ret_val = np.concatenate((
  36. ret_val,
  37. np.copy(class_samples
  38. [: n_samples_per_class[class_index] - len(ret_val)]
  39. )), axis=0)
  40. if self.classes_cursors[class_index] == len(class_samples):
  41. np.random.shuffle(class_samples)
  42. self.classes_cursors[class_index] = 0
  43. return ret_val
  44. chosen_sample_indices = np.concatenate(tuple([
  45. sample_for_each_class(ci) for ci in range(len(self._classes_labels))
  46. ]), axis=0)
  47. return chosen_sample_indices
  48. def _get_n_samples_per_class(self):
  49. # determining the number of samples per class
  50. n_samples_per_class = np.full(len(self.class_samples_indices),
  51. int(self.batch_size // len(self.class_samples_indices)))
  52. remaining_samles_cnt = self.batch_size - np.sum(n_samples_per_class)
  53. if remaining_samles_cnt:
  54. if self._cursor + remaining_samles_cnt >= len(self._classes_labels):
  55. n_samples_per_class[self._cursor: len(self._classes_labels)] += 1
  56. remaining_samles_cnt -= (len(self._classes_labels) - self._cursor)
  57. self._cursor = 0
  58. if remaining_samles_cnt > 0:
  59. n_samples_per_class[self._cursor: self._cursor + remaining_samles_cnt] += 1
  60. self._cursor += remaining_samles_cnt
  61. return n_samples_per_class