Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mot.py 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import cv2
  2. import numpy as np
  3. from pycocotools.coco import COCO
  4. import os
  5. from ..dataloading import get_yolox_datadir
  6. from .datasets_wrapper import Dataset
  7. import sys
  8. class MOTDataset(Dataset):
  9. """
  10. COCO dataset class.
  11. """
  12. def __init__(
  13. self,
  14. data_dir=None,
  15. json_file="train_half.json",
  16. name="train",
  17. img_size=(608, 1088),
  18. preproc=None,
  19. load_weak=False,
  20. ):
  21. """
  22. COCO dataset initialization. Annotation data are read into memory by COCO API.
  23. Args:
  24. data_dir (str): dataset root directory
  25. json_file (str): COCO json file name
  26. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  27. img_size (int): target image size after pre-processing
  28. preproc: data augmentation strategy
  29. """
  30. super().__init__(img_size)
  31. if data_dir is None:
  32. data_dir = os.path.join(get_yolox_datadir(), "mot")
  33. self.data_dir = data_dir
  34. self.json_file = json_file
  35. self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
  36. self.ids = self.coco.getImgIds()
  37. self.class_ids = sorted(self.coco.getCatIds())
  38. cats = self.coco.loadCats(self.coco.getCatIds())
  39. self._classes = tuple([c["name"] for c in cats])
  40. self.annotations = self._load_coco_annotations()
  41. self.name = name
  42. self.img_size = img_size
  43. self.preproc = preproc
  44. self.load_weak = load_weak
  45. def __len__(self):
  46. return len(self.ids)
  47. def _load_coco_annotations(self):
  48. return [self.load_anno_from_ids(_ids) for _ids in self.ids]
  49. def load_anno_from_ids(self, id_):
  50. im_ann = self.coco.loadImgs(id_)[0]
  51. width = im_ann["width"]
  52. height = im_ann["height"]
  53. frame_id = im_ann["frame_id"]
  54. video_id = im_ann["video_id"]
  55. anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
  56. annotations = self.coco.loadAnns(anno_ids)
  57. objs = []
  58. for obj in annotations:
  59. x1 = obj["bbox"][0]
  60. y1 = obj["bbox"][1]
  61. x2 = x1 + obj["bbox"][2]
  62. y2 = y1 + obj["bbox"][3]
  63. if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
  64. obj["clean_bbox"] = [x1, y1, x2, y2]
  65. objs.append(obj)
  66. num_objs = len(objs)
  67. res = np.zeros((num_objs, 6))
  68. for ix, obj in enumerate(objs):
  69. cls = self.class_ids.index(obj["category_id"])
  70. res[ix, 0:4] = obj["clean_bbox"]
  71. res[ix, 4] = cls
  72. res[ix, 5] = obj["track_id"]
  73. file_name = im_ann["file_name"] if "file_name" in im_ann else "{:012}".format(id_) + ".jpg"
  74. img_info = (height, width, frame_id, video_id, file_name)
  75. del im_ann, annotations
  76. return (res, img_info, file_name)
  77. def load_anno(self, index):
  78. return self.annotations[index][0]
  79. def pull_item(self, index):
  80. id_ = self.ids[index]
  81. res, img_info, file_name = self.annotations[index]
  82. # load image and preprocess
  83. img_file = os.path.join(
  84. self.data_dir, self.name, file_name
  85. )
  86. head_tail = os.path.split(img_file)
  87. # label_path = os.path.join(head_tail[0], head_tail[1].replace('.jpg','.txt'))
  88. # sys.stderr.write('original shape' + str(res.shape) + '\n values \n' + str(res))
  89. if self.load_weak:
  90. weak_label_path = os.path.join(head_tail[0], head_tail[1] + '_weak_yolox-x.txt')
  91. # load weak labels from weak_label_path
  92. width = img_info[1]
  93. height = img_info[0]
  94. labels = np.loadtxt(weak_label_path)
  95. # print('weak loaded', labels[:3, 2:])
  96. labels[:, 2] = labels[:, 2] * width
  97. labels[:, 4] = labels[:, 4] * width
  98. labels[:, 3] = labels[:, 3] * height
  99. labels[:, 5] = labels[:, 5] * height
  100. labels[:, 4] += labels[:, 2]
  101. labels[:, 5] += labels[:, 3]
  102. # print('weak', labels[:3, 2:])
  103. res = np.zeros_like(labels)
  104. res[:, 0:4] = labels[:, -4:]
  105. res[:, 5] = labels[:, 1]
  106. # sys.stderr.write('weak shape ' + str(res.shape) + '\n values \n' + str(res))
  107. # all are from class one
  108. # res[:, 4] = labels[:, 0]
  109. img = cv2.imread(img_file)
  110. if img is None:
  111. print('img_file is None', img_file)
  112. assert img is not None
  113. return img, res.copy(), img_info, np.array([id_])
  114. @Dataset.resize_getitem
  115. def __getitem__(self, index):
  116. """
  117. One image / label pair for the given index is picked up and pre-processed.
  118. Args:
  119. index (int): data index
  120. Returns:
  121. img (numpy.ndarray): pre-processed image
  122. padded_labels (torch.Tensor): pre-processed label data.
  123. The shape is :math:`[max_labels, 5]`.
  124. each label consists of [class, xc, yc, w, h]:
  125. class (float): class index.
  126. xc, yc (float) : center of bbox whose values range from 0 to 1.
  127. w, h (float) : size of bbox whose values range from 0 to 1.
  128. info_img : tuple of h, w, nh, nw, dx, dy.
  129. h, w (int): original shape of the image
  130. nh, nw (int): shape of the resized image without padding
  131. dx, dy (int): pad size
  132. img_id (int): same as the input index. Used for evaluation.
  133. """
  134. img, target, img_info, img_id = self.pull_item(index)
  135. if self.preproc is not None:
  136. img, target = self.preproc(img, target, self.input_dim)
  137. return img, target, img_info, img_id