Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets_wrapper.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
  5. from torch.utils.data.dataset import Dataset as torchDataset
  6. import bisect
  7. from functools import wraps
  8. class ConcatDataset(torchConcatDataset):
  9. def __init__(self, datasets):
  10. super(ConcatDataset, self).__init__(datasets)
  11. if hasattr(self.datasets[0], "input_dim"):
  12. self._input_dim = self.datasets[0].input_dim
  13. self.input_dim = self.datasets[0].input_dim
  14. def pull_item(self, idx):
  15. if idx < 0:
  16. if -idx > len(self):
  17. raise ValueError(
  18. "absolute value of index should not exceed dataset length"
  19. )
  20. idx = len(self) + idx
  21. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  22. if dataset_idx == 0:
  23. sample_idx = idx
  24. else:
  25. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  26. return self.datasets[dataset_idx].pull_item(sample_idx)
  27. class MixConcatDataset(torchConcatDataset):
  28. def __init__(self, datasets):
  29. super(MixConcatDataset, self).__init__(datasets)
  30. if hasattr(self.datasets[0], "input_dim"):
  31. self._input_dim = self.datasets[0].input_dim
  32. self.input_dim = self.datasets[0].input_dim
  33. def __getitem__(self, index):
  34. if not isinstance(index, int):
  35. idx = index[1]
  36. if idx < 0:
  37. if -idx > len(self):
  38. raise ValueError(
  39. "absolute value of index should not exceed dataset length"
  40. )
  41. idx = len(self) + idx
  42. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  43. if dataset_idx == 0:
  44. sample_idx = idx
  45. else:
  46. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  47. if not isinstance(index, int):
  48. index = (index[0], sample_idx, index[2])
  49. return self.datasets[dataset_idx][index]
  50. class Dataset(torchDataset):
  51. """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
  52. that enables on the fly resizing of the ``input_dim``.
  53. Args:
  54. input_dimension (tuple): (width,height) tuple with default dimensions of the network
  55. """
  56. def __init__(self, input_dimension, mosaic=True):
  57. super().__init__()
  58. self.__input_dim = input_dimension[:2]
  59. self.enable_mosaic = mosaic
  60. @property
  61. def input_dim(self):
  62. """
  63. Dimension that can be used by transforms to set the correct image size, etc.
  64. This allows transforms to have a single source of truth
  65. for the input dimension of the network.
  66. Return:
  67. list: Tuple containing the current width,height
  68. """
  69. if hasattr(self, "_input_dim"):
  70. return self._input_dim
  71. return self.__input_dim
  72. @staticmethod
  73. def resize_getitem(getitem_fn):
  74. """
  75. Decorator method that needs to be used around the ``__getitem__`` method. |br|
  76. This decorator enables the on the fly resizing of
  77. the ``input_dim`` with our :class:`~lightnet.data.DataLoader` class.
  78. Example:
  79. >>> class CustomSet(ln.data.Dataset):
  80. ... def __len__(self):
  81. ... return 10
  82. ... @ln.data.Dataset.resize_getitem
  83. ... def __getitem__(self, index):
  84. ... # Should return (image, anno) but here we return input_dim
  85. ... return self.input_dim
  86. >>> data = CustomSet((200,200))
  87. >>> data[0]
  88. (200, 200)
  89. >>> data[(480,320), 0]
  90. (480, 320)
  91. """
  92. @wraps(getitem_fn)
  93. def wrapper(self, index):
  94. if not isinstance(index, int):
  95. has_dim = True
  96. self._input_dim = index[0]
  97. self.enable_mosaic = index[2]
  98. index = index[1]
  99. else:
  100. has_dim = False
  101. ret_val = getitem_fn(self, index)
  102. if has_dim:
  103. del self._input_dim
  104. return ret_val
  105. return wrapper