# Meta ByteTrack — tracker.py (8.7 KB)
  1. import numpy as np
  2. from sklearn.utils.linear_assignment_ import linear_assignment
  3. # from numba import jit
  4. import copy
  5. class Tracker(object):
  6. def __init__(self, opt):
  7. self.opt = opt
  8. self.reset()
  9. def init_track(self, results):
  10. for item in results:
  11. if item['score'] > self.opt.new_thresh:
  12. self.id_count += 1
  13. # active and age are never used in the paper
  14. item['active'] = 1
  15. item['age'] = 1
  16. item['tracking_id'] = self.id_count
  17. if not ('ct' in item):
  18. bbox = item['bbox']
  19. item['ct'] = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
  20. self.tracks.append(item)
  21. def reset(self):
  22. self.id_count = 0
  23. self.tracks = []
  24. def step(self, results_with_low, public_det=None):
  25. results = [item for item in results_with_low if item['score'] >= self.opt.track_thresh]
  26. # first association
  27. N = len(results)
  28. M = len(self.tracks)
  29. dets = np.array(
  30. [det['ct'] + det['tracking'] for det in results], np.float32) # N x 2
  31. track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
  32. (track['bbox'][3] - track['bbox'][1])) \
  33. for track in self.tracks], np.float32) # M
  34. track_cat = np.array([track['class'] for track in self.tracks], np.int32) # M
  35. item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
  36. (item['bbox'][3] - item['bbox'][1])) \
  37. for item in results], np.float32) # N
  38. item_cat = np.array([item['class'] for item in results], np.int32) # N
  39. tracks = np.array(
  40. [pre_det['ct'] for pre_det in self.tracks], np.float32) # M x 2
  41. dist = (((tracks.reshape(1, -1, 2) - \
  42. dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
  43. invalid = ((dist > track_size.reshape(1, M)) + \
  44. (dist > item_size.reshape(N, 1)) + \
  45. (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
  46. dist = dist + invalid * 1e18
  47. if self.opt.hungarian:
  48. assert not self.opt.hungarian, 'we only verify centertrack with greedy_assignment'
  49. item_score = np.array([item['score'] for item in results], np.float32) # N
  50. dist[dist > 1e18] = 1e18
  51. matched_indices = linear_assignment(dist)
  52. else:
  53. matched_indices = greedy_assignment(copy.deepcopy(dist))
  54. unmatched_dets = [d for d in range(dets.shape[0]) \
  55. if not (d in matched_indices[:, 0])]
  56. unmatched_tracks = [d for d in range(tracks.shape[0]) \
  57. if not (d in matched_indices[:, 1])]
  58. if self.opt.hungarian:
  59. assert not self.opt.hungarian, 'we only verify centertrack with greedy_assignment'
  60. matches = []
  61. for m in matched_indices:
  62. if dist[m[0], m[1]] > 1e16:
  63. unmatched_dets.append(m[0])
  64. unmatched_tracks.append(m[1])
  65. else:
  66. matches.append(m)
  67. matches = np.array(matches).reshape(-1, 2)
  68. else:
  69. matches = matched_indices
  70. ret = []
  71. for m in matches:
  72. track = results[m[0]]
  73. track['tracking_id'] = self.tracks[m[1]]['tracking_id']
  74. track['age'] = 1
  75. track['active'] = self.tracks[m[1]]['active'] + 1
  76. ret.append(track)
  77. if self.opt.public_det and len(unmatched_dets) > 0:
  78. assert not self.opt.public_det, 'we only verify centertrack with private detection'
  79. # Public detection: only create tracks from provided detections
  80. pub_dets = np.array([d['ct'] for d in public_det], np.float32)
  81. dist3 = ((dets.reshape(-1, 1, 2) - pub_dets.reshape(1, -1, 2)) ** 2).sum(
  82. axis=2)
  83. matched_dets = [d for d in range(dets.shape[0]) \
  84. if not (d in unmatched_dets)]
  85. dist3[matched_dets] = 1e18
  86. for j in range(len(pub_dets)):
  87. i = dist3[:, j].argmin()
  88. if dist3[i, j] < item_size[i]:
  89. dist3[i, :] = 1e18
  90. track = results[i]
  91. if track['score'] > self.opt.new_thresh:
  92. self.id_count += 1
  93. track['tracking_id'] = self.id_count
  94. track['age'] = 1
  95. track['active'] = 1
  96. ret.append(track)
  97. else:
  98. # Private detection: create tracks for all un-matched detections
  99. for i in unmatched_dets:
  100. track = results[i]
  101. if track['score'] > self.opt.new_thresh:
  102. self.id_count += 1
  103. track['tracking_id'] = self.id_count
  104. track['age'] = 1
  105. track['active'] = 1
  106. ret.append(track)
  107. # second association
  108. results_second = [item for item in results_with_low if item['score'] < self.opt.track_thresh]
  109. self_tracks_second = [self.tracks[i] for i in unmatched_tracks if self.tracks[i]['active'] > 0]
  110. second2original = [i for i in unmatched_tracks if self.tracks[i]['active'] > 0]
  111. N = len(results_second)
  112. M = len(self_tracks_second)
  113. if N > 0 and M > 0:
  114. dets = np.array(
  115. [det['ct'] + det['tracking'] for det in results_second], np.float32) # N x 2
  116. track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
  117. (track['bbox'][3] - track['bbox'][1])) \
  118. for track in self_tracks_second], np.float32) # M
  119. track_cat = np.array([track['class'] for track in self_tracks_second], np.int32) # M
  120. item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
  121. (item['bbox'][3] - item['bbox'][1])) \
  122. for item in results_second], np.float32) # N
  123. item_cat = np.array([item['class'] for item in results_second], np.int32) # N
  124. tracks_second = np.array(
  125. [pre_det['ct'] for pre_det in self_tracks_second], np.float32) # M x 2
  126. dist = (((tracks_second.reshape(1, -1, 2) - \
  127. dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
  128. invalid = ((dist > track_size.reshape(1, M)) + \
  129. (dist > item_size.reshape(N, 1)) + \
  130. (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
  131. dist = dist + invalid * 1e18
  132. matched_indices_second = greedy_assignment(copy.deepcopy(dist), 1e8)
  133. unmatched_tracks_second = [d for d in range(tracks_second.shape[0]) \
  134. if not (d in matched_indices_second[:, 1])]
  135. matches_second = matched_indices_second
  136. for m in matches_second:
  137. track = results_second[m[0]]
  138. track['tracking_id'] = self_tracks_second[m[1]]['tracking_id']
  139. track['age'] = 1
  140. track['active'] = self_tracks_second[m[1]]['active'] + 1
  141. ret.append(track)
  142. unmatched_tracks = [second2original[i] for i in unmatched_tracks_second] + \
  143. [i for i in unmatched_tracks if self.tracks[i]['active'] == 0]
  144. #. for debug
  145. # unmatched_tracks = [i for i in unmatched_tracks if self.tracks[i]['active'] > 0] + \
  146. # [i for i in unmatched_tracks if self.tracks[i]['active'] == 0]
  147. for i in unmatched_tracks:
  148. track = self.tracks[i]
  149. if track['age'] < self.opt.max_age:
  150. track['age'] += 1
  151. track['active'] = 0
  152. bbox = track['bbox']
  153. ct = track['ct']
  154. v = [0, 0]
  155. track['bbox'] = [
  156. bbox[0] + v[0], bbox[1] + v[1],
  157. bbox[2] + v[0], bbox[3] + v[1]]
  158. track['ct'] = [ct[0] + v[0], ct[1] + v[1]]
  159. ret.append(track)
  160. self.tracks = ret
  161. return ret
  162. def greedy_assignment(dist, thresh=1e16):
  163. matched_indices = []
  164. if dist.shape[1] == 0:
  165. return np.array(matched_indices, np.int32).reshape(-1, 2)
  166. for i in range(dist.shape[0]):
  167. j = dist[i].argmin()
  168. if dist[i][j] < thresh:
  169. dist[:, j] = 1e18
  170. matched_indices.append([i, j])
  171. return np.array(matched_indices, np.int32).reshape(-1, 2)