Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mot.py 5.3KB

Lines 1–158
  1. import cv2
  2. import numpy as np
  3. from pycocotools.coco import COCO
  4. import os
  5. from ..dataloading import get_yolox_datadir
  6. from .datasets_wrapper import Dataset
  7. class MOTDataset(Dataset):
  8. """
  9. COCO dataset class.
  10. """
  11. def __init__(
  12. self,
  13. data_dir=None,
  14. json_file="train_half.json",
  15. name="train",
  16. img_size=(608, 1088),
  17. preproc=None,
  18. load_weak=False,
  19. ):
  20. """
  21. COCO dataset initialization. Annotation data are read into memory by COCO API.
  22. Args:
  23. data_dir (str): dataset root directory
  24. json_file (str): COCO json file name
  25. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  26. img_size (int): target image size after pre-processing
  27. preproc: data augmentation strategy
  28. """
  29. super().__init__(img_size)
  30. if data_dir is None:
  31. data_dir = os.path.join(get_yolox_datadir(), "mot")
  32. self.data_dir = data_dir
  33. self.json_file = json_file
  34. self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
  35. self.ids = self.coco.getImgIds()
  36. self.class_ids = sorted(self.coco.getCatIds())
  37. cats = self.coco.loadCats(self.coco.getCatIds())
  38. self._classes = tuple([c["name"] for c in cats])
  39. self.annotations = self._load_coco_annotations()
  40. self.name = name
  41. self.img_size = img_size
  42. self.preproc = preproc
  43. self.load_weak = load_weak
  44. def __len__(self):
  45. return len(self.ids)
  46. def _load_coco_annotations(self):
  47. return [self.load_anno_from_ids(_ids) for _ids in self.ids]
  48. def load_anno_from_ids(self, id_):
  49. im_ann = self.coco.loadImgs(id_)[0]
  50. width = im_ann["width"]
  51. height = im_ann["height"]
  52. frame_id = im_ann["frame_id"]
  53. video_id = im_ann["video_id"]
  54. anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
  55. annotations = self.coco.loadAnns(anno_ids)
  56. objs = []
  57. for obj in annotations:
  58. x1 = obj["bbox"][0]
  59. y1 = obj["bbox"][1]
  60. x2 = x1 + obj["bbox"][2]
  61. y2 = y1 + obj["bbox"][3]
  62. if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
  63. obj["clean_bbox"] = [x1, y1, x2, y2]
  64. objs.append(obj)
  65. num_objs = len(objs)
  66. res = np.zeros((num_objs, 6))
  67. for ix, obj in enumerate(objs):
  68. cls = self.class_ids.index(obj["category_id"])
  69. res[ix, 0:4] = obj["clean_bbox"]
  70. res[ix, 4] = cls
  71. res[ix, 5] = obj["track_id"]
  72. file_name = im_ann["file_name"] if "file_name" in im_ann else "{:012}".format(id_) + ".jpg"
  73. img_info = (height, width, frame_id, video_id, file_name)
  74. del im_ann, annotations
  75. return (res, img_info, file_name)
  76. def load_anno(self, index):
  77. return self.annotations[index][0]
  78. def pull_item(self, index):
  79. id_ = self.ids[index]
  80. res, img_info, file_name = self.annotations[index]
  81. # load image and preprocess
  82. img_file = os.path.join(
  83. self.data_dir, self.name, file_name
  84. )
  85. head_tail = os.path.split(img_file)
  86. # label_path = os.path.join(head_tail[0], head_tail[1].replace('.jpg','.txt'))
  87. if self.load_weak:
  88. weak_label_path = os.path.join(head_tail[0], head_tail[1] + '_weak_yolox-x.txt')
  89. # load weak labels from weak_label_path
  90. width = img_info[1]
  91. height = img_info[0]
  92. labels = np.loadtxt(weak_label_path)
  93. res = np.ones_like(labels)
  94. labels[2, :] *= width
  95. labels[4, :] *= width
  96. labels[3, :] *= height
  97. labels[5, :] *= height
  98. labels[4, :] += labels[2, :]
  99. labels[5, :] += labels[3, :]
  100. res[:, 0:4] = labels[:, -4:]
  101. res[:, 5] = labels[:, 1]
  102. # all are from class one
  103. # res[:, 4] = labels[:, 0]
  104. img = cv2.imread(img_file)
  105. if img is None:
  106. print('img_file is None', img_file)
  107. assert img is not None
  108. return img, res.copy(), img_info, np.array([id_])
  109. @Dataset.resize_getitem
  110. def __getitem__(self, index):
  111. """
  112. One image / label pair for the given index is picked up and pre-processed.
  113. Args:
  114. index (int): data index
  115. Returns:
  116. img (numpy.ndarray): pre-processed image
  117. padded_labels (torch.Tensor): pre-processed label data.
  118. The shape is :math:`[max_labels, 5]`.
  119. each label consists of [class, xc, yc, w, h]:
  120. class (float): class index.
  121. xc, yc (float) : center of bbox whose values range from 0 to 1.
  122. w, h (float) : size of bbox whose values range from 0 to 1.
  123. info_img : tuple of h, w, nh, nw, dx, dy.
  124. h, w (int): original shape of the image
  125. nh, nw (int): shape of the resized image without padding
  126. dx, dy (int): pad size
  127. img_id (int): same as the input index. Used for evaluation.
  128. """
  129. img, target, img_info, img_id = self.pull_item(index)
  130. if self.preproc is not None:
  131. img, target = self.preproc(img, target, self.input_dim)
  132. return img, target, img_info, img_id