Meta Byte Track

qdtrack.py 5.9KB

import numpy as np
from mmdet.core import bbox2result
from mmdet.models import TwoStageDetector

from qdtrack.core import imshow_tracks, restore_result, track2result

from ..builder import MODELS, build_tracker
from tracker import BYTETracker


@MODELS.register_module()
class QDTrack(TwoStageDetector):

    def __init__(self, tracker=None, freeze_detector=False, *args, **kwargs):
        self.prepare_cfg(kwargs)
        super().__init__(*args, **kwargs)
        self.tracker_cfg = tracker
        self.freeze_detector = freeze_detector
        if self.freeze_detector:
            self._freeze_detector()
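
    # Freezing the detector puts the backbone, neck, RPN head and bbox head in
    # eval mode and stops their gradients, so only the remaining (tracking)
    # parameters are updated during training.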
    def _freeze_detector(self):
        self.detector = [
            self.backbone, self.neck, self.rpn_head, self.roi_head.bbox_head
        ]
        for model in self.detector:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
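
    # Copies the embedding-head train config (train_cfg['embed']) into
    # roi_head['track_train_cfg'] before the parent constructor builds the heads.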
    def prepare_cfg(self, kwargs):
        if kwargs.get('train_cfg', False):
            kwargs['roi_head']['track_train_cfg'] = kwargs['train_cfg'].get(
                'embed', None)
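
    # The association step is delegated to ByteTrack; the original quasi-dense
    # tracker built from self.tracker_cfg is kept below for reference.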
    def init_tracker(self):
        # self.tracker = build_tracker(self.tracker_cfg)
        self.tracker = BYTETracker()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_match_indices,
                      ref_img,
                      ref_img_metas,
                      ref_gt_bboxes,
                      ref_gt_labels,
                      ref_gt_match_indices,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      ref_gt_bboxes_ignore=None,
                      ref_gt_masks=None,
                      **kwargs):
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            gt_bboxes,
            gt_labels=None,
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        ref_x = self.extract_feat(ref_img)
        ref_proposals = self.rpn_head.simple_test_rpn(ref_x, ref_img_metas)

        roi_losses = self.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_match_indices, ref_x, ref_img_metas, ref_proposals,
            ref_gt_bboxes, ref_gt_labels, gt_bboxes_ignore, gt_masks,
            ref_gt_bboxes_ignore, **kwargs)
        losses.update(roi_losses)

        return losses
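
    # Per-frame test: run detection and track-feature extraction, then associate
    # detections with existing tracks; the tracker is re-initialized whenever a
    # new sequence starts (frame_id == 0).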
    def simple_test(self, img, img_metas, rescale=False):
        # TODO: inherit from a base tracker
        assert self.roi_head.with_track, 'Track head must be implemented.'
        frame_id = img_metas[0].get('frame_id', -1)
        if frame_id == 0:
            self.init_tracker()

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels, track_feats = self.roi_head.simple_test(
            x, img_metas, proposal_list, rescale)
        bboxes, labels, ids = self.tracker.update(det_bboxes, det_labels,
                                                  frame_id, track_feats)
        # if track_feats is not None:
        #     bboxes, labels, ids = self.tracker.match(
        #         bboxes=det_bboxes,
        #         labels=det_labels,
        #         track_feats=track_feats,
        #         frame_id=frame_id)

        bbox_result = bbox2result(det_bboxes, det_labels,
                                  self.roi_head.bbox_head.num_classes)

        if track_feats is not None:
            track_result = track2result(bboxes, labels, ids,
                                        self.roi_head.bbox_head.num_classes)
        else:
            track_result = [
                np.zeros((0, 6), dtype=np.float32)
                for i in range(self.roi_head.bbox_head.num_classes)
            ]
        return dict(bbox_results=bbox_result, track_results=track_result)

    def show_result(self,
                    img,
                    result,
                    thickness=1,
                    font_scale=0.5,
                    show=False,
                    out_file=None,
                    wait_time=0,
                    backend='cv2',
                    **kwargs):
        """Visualize tracking results.

        Args:
            img (str | ndarray): Filename of loaded image.
            result (dict): Tracking result.
                The value of key 'track_results' is ndarray with shape (n, 6)
                in [id, tl_x, tl_y, br_x, br_y, score] format.
                The value of key 'bbox_results' is ndarray with shape (n, 5)
                in [tl_x, tl_y, br_x, br_y, score] format.
            thickness (int, optional): Thickness of lines. Defaults to 1.
            font_scale (float, optional): Font scale of texts. Defaults
                to 0.5.
            show (bool, optional): Whether to show the visualization on the
                fly. Defaults to False.
            out_file (str | None, optional): Output filename. Defaults to None.
            wait_time (int, optional): Display delay in milliseconds when
                `show` is True. Defaults to 0.
            backend (str, optional): Backend to draw the bounding boxes,
                options are `cv2` and `plt`. Defaults to 'cv2'.

        Returns:
            ndarray: Visualized image.
        """
        assert isinstance(result, dict)
        track_result = result.get('track_results', None)
        bboxes, labels, ids = restore_result(track_result, return_ids=True)
        img = imshow_tracks(
            img,
            bboxes,
            labels,
            ids,
            classes=self.CLASSES,
            thickness=thickness,
            font_scale=font_scale,
            show=show,
            out_file=out_file,
            wait_time=wait_time,
            backend=backend)
        return img
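
Below is a minimal sketch (not part of this repository) of how the dict returned by QDTrack.simple_test could be consumed. It assumes only the per-class list layout produced by bbox2result/track2result and the row formats documented in show_result; the helper name print_tracks, the fake_result data, and the 'pedestrian' class name are illustrative.

import numpy as np

def print_tracks(result, class_names):
    # result['track_results'] has one entry per class; each entry is an
    # ndarray of shape (n, 6) in [id, tl_x, tl_y, br_x, br_y, score] format.
    for cls_id, per_cls in enumerate(result['track_results']):
        for obj_id, x1, y1, x2, y2, score in per_cls:
            print(f'{class_names[cls_id]} #{int(obj_id)}: '
                  f'({x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f}) score={score:.2f}')

# Hypothetical single-frame result with one class and one tracked object.
fake_result = dict(
    bbox_results=[np.array([[10., 20., 50., 80., 0.9]], dtype=np.float32)],
    track_results=[np.array([[1., 10., 20., 50., 80., 0.9]], dtype=np.float32)],
)
print_tracks(fake_result, class_names=['pedestrian'])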