Meta Byte Track

samplers.py 3.3KB

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import torch
import torch.distributed as dist
from torch.utils.data.sampler import BatchSampler as torchBatchSampler
from torch.utils.data.sampler import Sampler

import itertools
from typing import Optional


class YoloBatchSampler(torchBatchSampler):
    """
    This batch sampler will generate mini-batches of (dim, index) tuples from another sampler.
    It works just like the :class:`torch.utils.data.sampler.BatchSampler`,
    but it will prepend a dimension, whilst ensuring it stays the same across one mini-batch.
    """

    def __init__(self, *args, input_dimension=None, mosaic=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_dim = input_dimension
        self.new_input_dim = None
        self.mosaic = mosaic

    def __iter__(self):
        self.__set_input_dim()
        for batch in super().__iter__():
            yield [(self.input_dim, idx, self.mosaic) for idx in batch]
            # Re-check after every batch, so a newly requested input dimension
            # takes effect from the next mini-batch onwards.
            self.__set_input_dim()

    def __set_input_dim(self):
        """ This function changes the input dimension of the dataset when a new one is pending. """
        if self.new_input_dim is not None:
            self.input_dim = (self.new_input_dim[0], self.new_input_dim[1])
            self.new_input_dim = None
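
# Illustrative sketch (not part of the original file): the batches yielded above
# are lists of (input_dim, index, mosaic) tuples, and a trainer can request a new
# resolution by assigning to `new_input_dim`; it takes effect on the next batch.
# The sampler, batch size and image sizes below are assumptions made for the example.
#
#   batch_sampler = YoloBatchSampler(
#       sampler=torch.utils.data.RandomSampler(range(1000)),
#       batch_size=8,
#       drop_last=False,
#       input_dimension=(640, 640),
#   )
#   it = iter(batch_sampler)
#   first = next(it)    # [((640, 640), idx, True), ...] with 8 tuples
#   batch_sampler.new_input_dim = (512, 512)
#   second = next(it)   # subsequent batches carry (512, 512)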


class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False).
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = 0,
        rank=0,
        world_size=1,
    ):
        """
        Args:
            size (int): the total number of samples in the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        self._seed = int(seed)

        if dist.is_available() and dist.is_initialized():
            self._rank = dist.get_rank()
            self._world_size = dist.get_world_size()
        else:
            self._rank = rank
            self._world_size = world_size

    def __iter__(self):
        # Each rank consumes a strided slice of the shared infinite index stream,
        # so different processes see disjoint indices.
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        return self._size // self._world_size
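
For context, here is a minimal sketch of how the two samplers can be wired into a standard torch.utils.data.DataLoader. It is not taken from the repository: ToyDataset is a stand-in for a dataset whose __getitem__ accepts the (input_dim, index, mosaic) tuples, and the batch size, image size, seed, and num_workers=0 are assumptions made for the example; ByteTrack/YOLOX supply their own dataset and loader around these classes.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in dataset: its item key is the (input_dim, index, mosaic) tuple
    produced by YoloBatchSampler."""

    def __len__(self):
        return 100

    def __getitem__(self, key):
        input_dim, index, mosaic = key
        # A real dataset would load an image here and resize it to `input_dim`.
        return torch.zeros(3, *input_dim), index


dataset = ToyDataset()
sampler = InfiniteSampler(len(dataset), shuffle=True, seed=42)
batch_sampler = YoloBatchSampler(
    sampler=sampler,
    batch_size=4,
    drop_last=False,
    input_dimension=(608, 608),
    mosaic=False,
)
# num_workers=0 so that assigning batch_sampler.new_input_dim in this process
# would be seen by the sampler that actually builds the batches.
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)

for step, (images, indices) in enumerate(loader):
    # The index stream is infinite, so the loop must be stopped explicitly.
    if step == 10:
        break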