Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mosaicdetection.py 10.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. import cv2
  5. import numpy as np
  6. from yolox.utils import adjust_box_anns
  7. import random
  8. from ..data_augment import box_candidates, random_perspective, augment_hsv
  9. from .datasets_wrapper import Dataset
  def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
      """Compute where one tile lands on the mosaic canvas and which part of it is kept.

      The canvas is ``2 * input_w`` wide and ``2 * input_h`` high; the four tiles
      meet at the mosaic centre ``(xc, yc)``. Each tile is anchored to the centre
      with the corner that faces it, and whatever would fall outside the canvas
      is cropped away.

      Args:
          mosaic_image: unused; kept so existing callers keep working.
          mosaic_index (int): 0 = top left, 1 = top right, 2 = bottom left,
              3 = bottom right.
          xc (int): x coordinate of the mosaic centre on the canvas.
          yc (int): y coordinate of the mosaic centre on the canvas.
          w (int): width of the tile image.
          h (int): height of the tile image.
          input_h (int): half of the canvas height.
          input_w (int): half of the canvas width.

      Returns:
          tuple: ``((x1, y1, x2, y2), (sx1, sy1, sx2, sy2))`` where the first
          box is the destination region on the canvas and the second box is
          the matching source region inside the tile.

      Raises:
          ValueError: if ``mosaic_index`` is not one of 0, 1, 2, 3.
      """
      if mosaic_index == 0:
          # index0 -> top left part of the canvas; keep the tile's bottom-right area
          x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
          small_coord = w - (x2 - x1), h - (y2 - y1), w, h
      elif mosaic_index == 1:
          # index1 -> top right part of the canvas; keep the tile's bottom-left area
          x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
          small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
      elif mosaic_index == 2:
          # index2 -> bottom left part of the canvas; keep the tile's top-right area
          x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
          small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
      elif mosaic_index == 3:
          # index3 -> bottom right part of the canvas; keep the tile's top-left area
          x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
          small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
      else:
          # Without this branch the return below fails with an opaque
          # UnboundLocalError; report the bad argument explicitly instead.
          raise ValueError(
              "mosaic_index must be 0, 1, 2 or 3, got {!r}".format(mosaic_index)
          )
      return (x1, y1, x2, y2), small_coord
  class MosaicDetection(Dataset):
      """Dataset wrapper that applies mosaic and, optionally, mixup augmentation.

      With mosaic enabled, ``__getitem__`` tiles four samples of the wrapped
      dataset onto one canvas, warps the canvas with ``random_perspective`` and
      may blend in one more sample through :meth:`mixup`. With mosaic disabled,
      the sample is pulled from the wrapped dataset and only ``preproc`` runs.
      """

      def __init__(
          self, dataset, img_size, mosaic=True, preproc=None,
          degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
          shear=2.0, perspective=0.0, enable_mixup=True, *args
      ):
          """
          Args:
              dataset (Dataset): wrapped dataset; this class uses its
                  ``pull_item``, ``load_anno``, ``input_dim`` and ``__len__``.
              img_size (tuple): forwarded to the base ``Dataset`` constructor.
              mosaic (bool): enable mosaic augmentation or not.
              preproc (func): called as ``preproc(img, labels, input_dim)``;
                  its result is unpacked as ``(img, labels)``.
              degrees (float): forwarded to ``random_perspective``.
              translate (float): forwarded to ``random_perspective``.
              scale (tuple): forwarded to ``random_perspective``.
              mscale (tuple): ``(low, high)`` range from which the mixup
                  resize jitter is drawn uniformly.
              shear (float): forwarded to ``random_perspective``.
              perspective (float): forwarded to ``random_perspective``.
              enable_mixup (bool): run :meth:`mixup` after the mosaic step.
              *args (tuple): extra positional arguments; not used in this class.
          """
          super().__init__(img_size, mosaic=mosaic)
          self._dataset = dataset
          self.preproc = preproc

          # Switches for the two augmentations.
          self.enable_mosaic = mosaic
          self.enable_mixup = enable_mixup

          # Parameters handed to random_perspective.
          self.degrees = degrees
          self.translate = translate
          self.scale = scale
          self.shear = shear
          self.perspective = perspective

          # Range of the mixup resize jitter.
          self.mixup_scale = mscale

      def __len__(self):
          """Return the number of samples in the wrapped dataset."""
          return len(self._dataset)

      @Dataset.resize_getitem
      def __getitem__(self, idx):
          """Return ``(img, labels, img_info, id)`` for sample ``idx``.

          In the mosaic path ``img_info`` is built from the shape of the
          preprocessed image and ``id`` is ``np.array([idx])``; in the plain
          path both come straight from the wrapped dataset's ``pull_item``.
          """
          if not self.enable_mosaic:
              # Plain path: no mosaic, just pull the sample and preprocess it.
              self._dataset._input_dim = self.input_dim
              img, label, img_info, id_ = self._dataset.pull_item(idx)
              img, label = self.preproc(img, label, self.input_dim)
              return img, label, img_info, id_

          input_dim = self._dataset.input_dim
          input_h, input_w = input_dim[0], input_dim[1]
          canvas_h, canvas_w = input_h * 2, input_w * 2

          # Mosaic centre, drawn from the middle half of the canvas (y first, then x).
          yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
          xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))

          # The requested sample plus three random companions.
          last_index = len(self._dataset) - 1
          indices = [idx]
          for _ in range(3):
              indices.append(random.randint(0, last_index))

          mosaic_labels = []
          for i_mosaic, index in enumerate(indices):
              img, raw_labels, _, _ = self._dataset.pull_item(index)

              # Resize so the tile fits into (input_h, input_w), keeping aspect ratio.
              h0, w0 = img.shape[:2]
              scale = min(1. * input_h / h0, 1. * input_w / w0)
              resized_wh = (int(w0 * scale), int(h0 * scale))
              img = cv2.resize(img, resized_wh, interpolation=cv2.INTER_LINEAR)
              h, w, c = img.shape[:3]

              if i_mosaic == 0:
                  # Allocate the grey canvas once the channel count is known.
                  mosaic_img = np.full((canvas_h, canvas_w, c), 114, dtype=np.uint8)

              # "large" is the region on the canvas, "small" the region in the tile.
              large_box, small_box = get_mosaic_coordinate(
                  mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
              )
              l_x1, l_y1, l_x2, l_y2 = large_box
              s_x1, s_y1, s_x2, s_y2 = small_box
              mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]

              # Move the boxes with the tile: scale by the resize factor, then
              # shift by the paste offset (columns 0..3, presumably x1, y1, x2, y2).
              padw, padh = l_x1 - s_x1, l_y1 - s_y1
              labels = raw_labels.copy()
              if raw_labels.size > 0:
                  for col, shift in ((0, padw), (1, padh), (2, padw), (3, padh)):
                      labels[:, col] = scale * raw_labels[:, col] + shift
              mosaic_labels.append(labels)

          if mosaic_labels:
              mosaic_labels = np.concatenate(mosaic_labels, 0)
              # Boxes are not clipped to the canvas; rows lying completely
              # outside of it are dropped instead.
              on_canvas = (
                  (mosaic_labels[:, 0] < canvas_w)
                  & (mosaic_labels[:, 2] > 0)
                  & (mosaic_labels[:, 1] < canvas_h)
                  & (mosaic_labels[:, 3] > 0)
              )
              mosaic_labels = mosaic_labels[on_canvas]

          mosaic_img, mosaic_labels = random_perspective(
              mosaic_img,
              mosaic_labels,
              degrees=self.degrees,
              translate=self.translate,
              scale=self.scale,
              shear=self.shear,
              perspective=self.perspective,
              border=[-input_h // 2, -input_w // 2],  # border to remove
          )

          # -----------------------------------------------------------------
          # CopyPaste: https://arxiv.org/abs/2012.07177
          # -----------------------------------------------------------------
          if self.enable_mixup and len(mosaic_labels) > 0:
              mosaic_img, mosaic_labels = self.mixup(
                  mosaic_img, mosaic_labels, self.input_dim
              )

          mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
          img_info = (mix_img.shape[1], mix_img.shape[0])
          return mix_img, padded_labels, img_info, np.array([idx])

      def mixup(self, origin_img, origin_labels, input_dim):
          """Blend a randomly chosen annotated sample into ``origin_img``.

          The extra sample is letterboxed to ``input_dim``, resized by a random
          jitter factor, optionally mirrored, and cropped to the size of
          ``origin_img``. If at least one of its boxes passes
          ``box_candidates``, the surviving boxes are appended to
          ``origin_labels`` and the two images are averaged 50/50.

          Args:
              origin_img: image to blend into.
              origin_labels: label rows belonging to ``origin_img``.
              input_dim: ``(height, width)`` of the letterbox canvas.

          Returns:
              tuple: ``(image, labels)``; both are the unchanged inputs when no
              box of the extra sample survives.
          """
          jit_factor = random.uniform(*self.mixup_scale)
          do_flip = random.uniform(0, 1) > 0.5

          # Keep drawing until a sample with at least one annotation turns up.
          while True:
              cp_index = random.randint(0, len(self) - 1)
              cp_labels = self._dataset.load_anno(cp_index)
              if len(cp_labels) != 0:
                  break
          img, cp_labels, _, _ = self._dataset.pull_item(cp_index)

          # Grey letterbox canvas matching the dimensionality of the sample.
          if img.ndim == 3:
              cp_img = np.full((input_dim[0], input_dim[1], 3), 114.0)
          else:
              cp_img = np.full(input_dim, 114.0)

          src_h, src_w = img.shape[0], img.shape[1]
          cp_scale_ratio = min(input_dim[0] / src_h, input_dim[1] / src_w)
          fit_h, fit_w = int(src_h * cp_scale_ratio), int(src_w * cp_scale_ratio)
          resized_img = cv2.resize(
              img, (fit_w, fit_h), interpolation=cv2.INTER_LINEAR
          ).astype(np.float32)
          cp_img[:fit_h, :fit_w] = resized_img

          # Random resize jitter; the box scale has to follow the image.
          jit_size = (
              int(cp_img.shape[1] * jit_factor),
              int(cp_img.shape[0] * jit_factor),
          )
          cp_img = cv2.resize(cp_img, jit_size)
          cp_scale_ratio *= jit_factor

          if do_flip:
              cp_img = cp_img[:, ::-1, :]

          origin_h, origin_w = cp_img.shape[:2]
          target_h, target_w = origin_img.shape[:2]

          # Pad with zeros so that a target-sized window can always be cut out.
          padded_img = np.zeros(
              (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
          )
          padded_img[:origin_h, :origin_w] = cp_img

          # Random crop offset (y is drawn before x).
          y_offset = 0
          if padded_img.shape[0] > target_h:
              y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
          x_offset = 0
          if padded_img.shape[1] > target_w:
              x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
          padded_cropped_img = padded_img[
              y_offset: y_offset + target_h, x_offset: x_offset + target_w
          ]

          cp_bboxes_origin_np = adjust_box_anns(
              cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
          )
          if do_flip:
              # Mirror the x columns and swap them so that x1 stays the left edge.
              mirrored_x = cp_bboxes_origin_np[:, 0::2][:, ::-1]
              cp_bboxes_origin_np[:, 0::2] = origin_w - mirrored_x

          # Shift the boxes into the cropped window; they are not clipped here.
          cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
          cp_bboxes_transformed_np[:, 0::2] -= x_offset
          cp_bboxes_transformed_np[:, 1::2] -= y_offset

          keep_list = box_candidates(cp_bboxes_origin_np.T, cp_bboxes_transformed_np.T, 5)
          if keep_list.sum() >= 1.0:
              labels = np.hstack(
                  (
                      cp_bboxes_transformed_np[keep_list],
                      cp_labels[keep_list, 4:5].copy(),  # class column
                      cp_labels[keep_list, 5:6].copy(),  # id column
                  )
              )
              # Remove boxes lying completely outside the target image.
              inside = (
                  (labels[:, 0] < target_w)
                  & (labels[:, 2] > 0)
                  & (labels[:, 1] < target_h)
                  & (labels[:, 3] > 0)
              )
              origin_labels = np.vstack((origin_labels, labels[inside]))

              # Equal-weight blend of both images, computed in float32.
              origin_img = origin_img.astype(np.float32)
              origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)

          return origin_img, origin_labels