Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

dataloading.py 6.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. import torch
  5. from torch.utils.data.dataloader import DataLoader as torchDataLoader
  6. from torch.utils.data.dataloader import default_collate
  7. from yolox import statics
  8. import os
  9. import random
  10. from .samplers import YoloBatchSampler
  11. def get_yolox_datadir():
  12. """
  13. get dataset dir of YOLOX. If environment variable named `YOLOX_DATADIR` is set,
  14. this function will return value of the environment variable. Otherwise, use data
  15. """
  16. yolox_datadir = os.getenv("YOLOX_DATADIR", None)
  17. if yolox_datadir is None:
  18. import yolox
  19. yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
  20. yolox_datadir = os.path.join(yolox_path, statics.DATA_PATH)
  21. return yolox_datadir
  22. class DataLoader(torchDataLoader):
  23. """
  24. Lightnet dataloader that enables on the fly resizing of the images.
  25. See :class:`torch.utils.data.DataLoader` for more information on the arguments.
  26. Check more on the following website:
  27. https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
  28. Note:
  29. This dataloader only works with :class:`lightnet.data.Dataset` based datasets.
  30. Example:
  31. >>> class CustomSet(ln.data.Dataset):
  32. ... def __len__(self):
  33. ... return 4
  34. ... @ln.data.Dataset.resize_getitem
  35. ... def __getitem__(self, index):
  36. ... # Should return (image, anno) but here we return (input_dim,)
  37. ... return (self.input_dim,)
  38. >>> dl = ln.data.DataLoader(
  39. ... CustomSet((200,200)),
  40. ... batch_size = 2,
  41. ... collate_fn = ln.data.list_collate # We want the data to be grouped as a list
  42. ... )
  43. >>> dl.dataset.input_dim # Default input_dim
  44. (200, 200)
  45. >>> for d in dl:
  46. ... d
  47. [[(200, 200), (200, 200)]]
  48. [[(200, 200), (200, 200)]]
  49. >>> dl.change_input_dim(320, random_range=None)
  50. (320, 320)
  51. >>> for d in dl:
  52. ... d
  53. [[(320, 320), (320, 320)]]
  54. [[(320, 320), (320, 320)]]
  55. >>> dl.change_input_dim((480, 320), random_range=None)
  56. (480, 320)
  57. >>> for d in dl:
  58. ... d
  59. [[(480, 320), (480, 320)]]
  60. [[(480, 320), (480, 320)]]
  61. """
  62. def __init__(self, *args, **kwargs):
  63. super().__init__(*args, **kwargs)
  64. self.__initialized = False
  65. shuffle = False
  66. batch_sampler = None
  67. if len(args) > 5:
  68. shuffle = args[2]
  69. sampler = args[3]
  70. batch_sampler = args[4]
  71. elif len(args) > 4:
  72. shuffle = args[2]
  73. sampler = args[3]
  74. if "batch_sampler" in kwargs:
  75. batch_sampler = kwargs["batch_sampler"]
  76. elif len(args) > 3:
  77. shuffle = args[2]
  78. if "sampler" in kwargs:
  79. sampler = kwargs["sampler"]
  80. if "batch_sampler" in kwargs:
  81. batch_sampler = kwargs["batch_sampler"]
  82. else:
  83. if "shuffle" in kwargs:
  84. shuffle = kwargs["shuffle"]
  85. if "sampler" in kwargs:
  86. sampler = kwargs["sampler"]
  87. if "batch_sampler" in kwargs:
  88. batch_sampler = kwargs["batch_sampler"]
  89. # Use custom BatchSampler
  90. if batch_sampler is None:
  91. if sampler is None:
  92. if shuffle:
  93. sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
  94. # sampler = torch.utils.data.DistributedSampler(self.dataset)
  95. else:
  96. sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
  97. batch_sampler = YoloBatchSampler(
  98. sampler,
  99. self.batch_size,
  100. self.drop_last,
  101. input_dimension=self.dataset.input_dim,
  102. )
  103. # batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations =
  104. self.batch_sampler = batch_sampler
  105. self.__initialized = True
  106. def close_mosaic(self):
  107. self.batch_sampler.mosaic = False
  108. def change_input_dim(self, multiple=32, random_range=(10, 19)):
  109. """This function will compute a new size and update it on the next mini_batch.
  110. Args:
  111. multiple (int or tuple, optional): values to multiply the randomly generated range by.
  112. Default **32**
  113. random_range (tuple, optional): This (min, max) tuple sets the range
  114. for the randomisation; Default **(10, 19)**
  115. Return:
  116. tuple: width, height tuple with new dimension
  117. Note:
  118. The new size is generated as follows: |br|
  119. First we compute a random integer inside ``[random_range]``.
  120. We then multiply that number with the ``multiple`` argument,
  121. which gives our final new input size. |br|
  122. If ``multiple`` is an integer we generate a square size. If you give a tuple
  123. of **(width, height)**, the size is computed
  124. as :math:`rng * multiple[0], rng * multiple[1]`.
  125. Note:
  126. You can set the ``random_range`` argument to **None** to set
  127. an exact size of multiply. |br|
  128. See the example above for how this works.
  129. """
  130. if random_range is None:
  131. size = 1
  132. else:
  133. size = random.randint(*random_range)
  134. if isinstance(multiple, int):
  135. size = (size * multiple, size * multiple)
  136. else:
  137. size = (size * multiple[0], size * multiple[1])
  138. self.batch_sampler.new_input_dim = size
  139. return size
  140. def list_collate(batch):
  141. """
  142. Function that collates lists or tuples together into one list (of lists/tuples).
  143. Use this as the collate function in a Dataloader, if you want to have a list of
  144. items as an output, as opposed to tensors (eg. Brambox.boxes).
  145. """
  146. items = list(zip(*batch))
  147. for i in range(len(items)):
  148. if isinstance(items[i][0], (list, tuple)):
  149. items[i] = list(items[i])
  150. else:
  151. items[i] = default_collate(items[i])
  152. return items