Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tracker.py 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. """
  2. Copyright (c) https://github.com/xingyizhou/CenterTrack
  3. Modified by Peize Sun, Rufeng Zhang
  4. """
  5. # coding: utf-8
  6. import torch
  7. from scipy.optimize import linear_sum_assignment
  8. from util import box_ops
  9. import copy
  10. class Tracker(object):
  11. def __init__(self, score_thresh, max_age=32):
  12. self.score_thresh = score_thresh
  13. self.low_thresh = 0.2
  14. self.high_thresh = score_thresh + 0.1
  15. self.max_age = max_age
  16. self.id_count = 0
  17. self.tracks_dict = dict()
  18. self.tracks = list()
  19. self.unmatched_tracks = list()
  20. self.reset_all()
  21. def reset_all(self):
  22. self.id_count = 0
  23. self.tracks_dict = dict()
  24. self.tracks = list()
  25. self.unmatched_tracks = list()
  26. def init_track(self, results):
  27. scores = results["scores"]
  28. classes = results["labels"]
  29. bboxes = results["boxes"] # x1y1x2y2
  30. ret = list()
  31. ret_dict = dict()
  32. for idx in range(scores.shape[0]):
  33. if scores[idx] >= self.score_thresh:
  34. self.id_count += 1
  35. obj = dict()
  36. obj["score"] = float(scores[idx])
  37. obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
  38. obj["tracking_id"] = self.id_count
  39. obj['active'] = 1
  40. obj['age'] = 1
  41. ret.append(obj)
  42. ret_dict[idx] = obj
  43. self.tracks = ret
  44. self.tracks_dict = ret_dict
  45. return copy.deepcopy(ret)
  46. def step(self, output_results):
  47. scores = output_results["scores"]
  48. bboxes = output_results["boxes"] # x1y1x2y2
  49. track_bboxes = output_results["track_boxes"] if "track_boxes" in output_results else None # x1y1x2y2
  50. results = list()
  51. results_dict = dict()
  52. results_second = list()
  53. tracks = list()
  54. for idx in range(scores.shape[0]):
  55. if idx in self.tracks_dict and track_bboxes is not None:
  56. self.tracks_dict[idx]["bbox"] = track_bboxes[idx, :].cpu().numpy().tolist()
  57. if scores[idx] >= self.score_thresh:
  58. obj = dict()
  59. obj["score"] = float(scores[idx])
  60. obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
  61. results.append(obj)
  62. results_dict[idx] = obj
  63. elif scores[idx] >= self.low_thresh:
  64. second_obj = dict()
  65. second_obj["score"] = float(scores[idx])
  66. second_obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
  67. results_second.append(second_obj)
  68. results_dict[idx] = second_obj
  69. tracks = [v for v in self.tracks_dict.values()] + self.unmatched_tracks
  70. # for trackss in tracks:
  71. # print(trackss.keys())
  72. N = len(results)
  73. M = len(tracks)
  74. ret = list()
  75. unmatched_tracks = [t for t in range(M)]
  76. unmatched_dets = [d for d in range(N)]
  77. if N > 0 and M > 0:
  78. det_box = torch.stack([torch.tensor(obj['bbox']) for obj in results], dim=0) # N x 4
  79. track_box = torch.stack([torch.tensor(obj['bbox']) for obj in tracks], dim=0) # M x 4
  80. cost_bbox = 1.0 - box_ops.generalized_box_iou(det_box, track_box) # N x M
  81. matched_indices = linear_sum_assignment(cost_bbox)
  82. unmatched_dets = [d for d in range(N) if not (d in matched_indices[0])]
  83. unmatched_tracks = [d for d in range(M) if not (d in matched_indices[1])]
  84. matches = [[],[]]
  85. for (m0, m1) in zip(matched_indices[0], matched_indices[1]):
  86. if cost_bbox[m0, m1] > 1.2:
  87. unmatched_dets.append(m0)
  88. unmatched_tracks.append(m1)
  89. else:
  90. matches[0].append(m0)
  91. matches[1].append(m1)
  92. for (m0, m1) in zip(matches[0], matches[1]):
  93. track = results[m0]
  94. track['tracking_id'] = tracks[m1]['tracking_id']
  95. track['age'] = 1
  96. track['active'] = 1
  97. ret.append(track)
  98. # second association
  99. N_second = len(results_second)
  100. unmatched_tracks_obj = list()
  101. for i in unmatched_tracks:
  102. #print(tracks[i].keys())
  103. track = tracks[i]
  104. if track['active'] == 1:
  105. unmatched_tracks_obj.append(track)
  106. M_second = len(unmatched_tracks_obj)
  107. unmatched_tracks_second = [t for t in range(M_second)]
  108. if N_second > 0 and M_second > 0:
  109. det_box_second = torch.stack([torch.tensor(obj['bbox']) for obj in results_second], dim=0) # N_second x 4
  110. track_box_second = torch.stack([torch.tensor(obj['bbox']) for obj in unmatched_tracks_obj], dim=0) # M_second x 4
  111. cost_bbox_second = 1.0 - box_ops.generalized_box_iou(det_box_second, track_box_second) # N_second x M_second
  112. matched_indices_second = linear_sum_assignment(cost_bbox_second)
  113. unmatched_tracks_second = [d for d in range(M_second) if not (d in matched_indices_second[1])]
  114. matches_second = [[],[]]
  115. for (m0, m1) in zip(matched_indices_second[0], matched_indices_second[1]):
  116. if cost_bbox_second[m0, m1] > 0.8:
  117. unmatched_tracks_second.append(m1)
  118. else:
  119. matches_second[0].append(m0)
  120. matches_second[1].append(m1)
  121. for (m0, m1) in zip(matches_second[0], matches_second[1]):
  122. track = results_second[m0]
  123. track['tracking_id'] = unmatched_tracks_obj[m1]['tracking_id']
  124. track['age'] = 1
  125. track['active'] = 1
  126. ret.append(track)
  127. for i in unmatched_dets:
  128. trackd = results[i]
  129. if trackd["score"] >= self.high_thresh:
  130. self.id_count += 1
  131. trackd['tracking_id'] = self.id_count
  132. trackd['age'] = 1
  133. trackd['active'] = 1
  134. ret.append(trackd)
  135. # ------------------------------------------------------ #
  136. ret_unmatched_tracks = []
  137. for j in unmatched_tracks:
  138. track = tracks[j]
  139. if track['active'] == 0 and track['age'] < self.max_age:
  140. track['age'] += 1
  141. track['active'] = 0
  142. ret.append(track)
  143. ret_unmatched_tracks.append(track)
  144. for i in unmatched_tracks_second:
  145. track = unmatched_tracks_obj[i]
  146. if track['age'] < self.max_age:
  147. track['age'] += 1
  148. track['active'] = 0
  149. ret.append(track)
  150. ret_unmatched_tracks.append(track)
  151. # for i in unmatched_tracks:
  152. # track = tracks[i]
  153. # if track['age'] < self.max_age:
  154. # track['age'] += 1
  155. # track['active'] = 0
  156. # ret.append(track)
  157. # ret_unmatched_tracks.append(track)
  158. #print(len(ret_unmatched_tracks))
  159. self.tracks = ret
  160. self.tracks_dict = {red_ind:red for red_ind, red in results_dict.items() if 'tracking_id' in red}
  161. self.unmatched_tracks = ret_unmatched_tracks
  162. return copy.deepcopy(ret)