Meta Byte Track

mot.py 4.4KB

import cv2
import numpy as np
from pycocotools.coco import COCO
import os

from ..dataloading import get_yolox_datadir
from .datasets_wrapper import Dataset


class MOTDataset(Dataset):
    """
    COCO dataset class.
    """

    def __init__(
        self,
        data_dir=None,
        json_file="train_half.json",
        name="train",
        img_size=(608, 1088),
        preproc=None,
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by the COCO API.
        Args:
            data_dir (str): dataset root directory
            json_file (str): COCO json file name
            name (str): name of the image sub-directory (e.g. 'train' or 'test')
            img_size (tuple[int, int]): target image size after pre-processing
            preproc: data augmentation strategy
        """
        super().__init__(img_size)
        if data_dir is None:
            data_dir = os.path.join(get_yolox_datadir(), "mot")
        self.data_dir = data_dir
        self.json_file = json_file

        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        self.ids = self.coco.getImgIds()
        self.class_ids = sorted(self.coco.getCatIds())
        cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple([c["name"] for c in cats])
        self.annotations = self._load_coco_annotations()
        self.name = name
        self.img_size = img_size
        self.preproc = preproc

    def __len__(self):
        return len(self.ids)

    def _load_coco_annotations(self):
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        frame_id = im_ann["frame_id"]
        video_id = im_ann["video_id"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)
        objs = []
        for obj in annotations:
            x1 = obj["bbox"][0]
            y1 = obj["bbox"][1]
            x2 = x1 + obj["bbox"][2]
            y2 = y1 + obj["bbox"][3]
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        num_objs = len(objs)
        # each row: [x1, y1, x2, y2, class index, track id]
        res = np.zeros((num_objs, 6))
        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls
            res[ix, 5] = obj["track_id"]

        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )
        img_info = (height, width, frame_id, video_id, file_name)

        del im_ann, annotations

        return (res, img_info, file_name)

    def load_anno(self, index):
        return self.annotations[index][0]

    def pull_item(self, index):
        id_ = self.ids[index]
        res, img_info, file_name = self.annotations[index]

        # load image and preprocess
        img_file = os.path.join(self.data_dir, self.name, file_name)
        img = cv2.imread(img_file)
        assert img is not None, "failed to read image: {}".format(img_file)

        return img, res.copy(), img_info, np.array([id_])

    @Dataset.resize_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.
        Args:
            index (int): data index
        Returns:
            img (numpy.ndarray): pre-processed image
            padded_labels (torch.Tensor): pre-processed label data.
                The shape is :math:`[max_labels, 5]`.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float): center of bbox whose values range from 0 to 1.
                    w, h (float): size of bbox whose values range from 0 to 1.
            info_img : tuple of h, w, nh, nw, dx, dy.
                h, w (int): original shape of the image
                nh, nw (int): shape of the resized image without padding
                dx, dy (int): pad size
            img_id (int): same as the input index. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
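
For reference, the label row that load_anno_from_ids builds can be reproduced by hand. The sketch below is not part of mot.py; it uses a made-up COCO-style annotation entry (all values hypothetical) and mirrors the [x, y, w, h] to corner-format conversion and the class/track-id columns used above.

import numpy as np

# Hypothetical COCO-style annotation, with the extra track_id field that the
# MOT-to-COCO conversion adds; all numbers are made up for illustration.
obj = {
    "bbox": [100.0, 50.0, 40.0, 80.0],   # COCO format: [x, y, w, h]
    "area": 3200.0,
    "category_id": 1,
    "track_id": 7,
}
class_ids = [1]                           # sorted category ids, as in self.class_ids

x1, y1 = obj["bbox"][0], obj["bbox"][1]
x2, y2 = x1 + obj["bbox"][2], y1 + obj["bbox"][3]

row = np.zeros(6)
row[0:4] = [x1, y1, x2, y2]               # corner-format box in pixels
row[4] = class_ids.index(obj["category_id"])
row[5] = obj["track_id"]
print(row)  # [100.  50. 140. 130.   0.   7.]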
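
Below is a minimal usage sketch, also not part of mot.py. The import path yolox.data.datasets.mot, the data directory datasets/mot, and the on-disk layout are assumptions about a typical ByteTrack/YOLOX checkout; because of the relative imports, the class has to be loaded through its parent package rather than by running this file directly.

from yolox.data.datasets.mot import MOTDataset  # import path is an assumption

# Assumed layout:
#   datasets/mot/annotations/train_half.json  (COCO json with frame_id, video_id, track_id)
#   datasets/mot/train/<file_name>            (image files referenced by the json)
dataset = MOTDataset(
    data_dir="datasets/mot",      # hypothetical path
    json_file="train_half.json",
    name="train",                 # sub-directory holding the images
    img_size=(608, 1088),
    preproc=None,                 # no augmentation: raw image and pixel-space labels
)

print(len(dataset))                          # number of annotated frames
img, target, img_info, img_id = dataset[0]
height, width, frame_id, video_id, file_name = img_info
print(img.shape)                             # (height, width, 3), BGR as read by cv2
print(target.shape)                          # (num_objects, 6)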