Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tracker.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import numpy as np
  2. from sklearn.utils.linear_assignment_ import linear_assignment
  3. import copy
  4. from sklearn.metrics.pairwise import cosine_similarity as cosine
  5. class Tracker(object):
  6. def __init__(self, opt):
  7. self.opt = opt
  8. self.reset()
  9. self.nID = 10000
  10. self.alpha = 0.1
  11. def init_track(self, results):
  12. for item in results:
  13. if item['score'] > self.opt.new_thresh:
  14. self.id_count += 1
  15. # active and age are never used in the paper
  16. item['active'] = 1
  17. item['age'] = 1
  18. item['tracking_id'] = self.id_count
  19. if not ('ct' in item):
  20. bbox = item['bbox']
  21. item['ct'] = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
  22. self.tracks.append(item)
  23. self.nID = 10000
  24. self.embedding_bank = np.zeros((self.nID, 128))
  25. self.cat_bank = np.zeros((self.nID), dtype=np.int)
  26. def reset(self):
  27. self.id_count = 0
  28. self.nID = 10000
  29. self.tracks = []
  30. self.embedding_bank = np.zeros((self.nID, 128))
  31. self.cat_bank = np.zeros((self.nID), dtype=np.int)
  32. self.tracklet_ages = np.zeros((self.nID), dtype=np.int)
  33. self.alive = []
  34. def step(self, results_with_low, public_det=None):
  35. results = [item for item in results_with_low if item['score'] >= self.opt.track_thresh]
  36. # first association
  37. N = len(results)
  38. M = len(self.tracks)
  39. self.alive = []
  40. track_boxes = np.array([[track['bbox'][0], track['bbox'][1],
  41. track['bbox'][2], track['bbox'][3]] for track in self.tracks], np.float32) # M x 4
  42. det_boxes = np.array([[item['bbox'][0], item['bbox'][1],
  43. item['bbox'][2], item['bbox'][3]] for item in results], np.float32) # N x 4
  44. box_ious = self.bbox_overlaps_py(det_boxes, track_boxes)
  45. dets = np.array(
  46. [det['ct'] + det['tracking'] for det in results], np.float32) # N x 2
  47. track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
  48. (track['bbox'][3] - track['bbox'][1])) \
  49. for track in self.tracks], np.float32) # M
  50. track_cat = np.array([track['class'] for track in self.tracks], np.int32) # M
  51. item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
  52. (item['bbox'][3] - item['bbox'][1])) \
  53. for item in results], np.float32) # N
  54. item_cat = np.array([item['class'] for item in results], np.int32) # N
  55. tracks = np.array(
  56. [pre_det['ct'] for pre_det in self.tracks], np.float32) # M x 2
  57. dist = (((tracks.reshape(1, -1, 2) - \
  58. dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
  59. if self.opt.dataset == 'youtube_vis':
  60. invalid = ((dist > track_size.reshape(1, M)) + \
  61. (dist > item_size.reshape(N, 1)) + (box_ious < self.opt.overlap_thresh)) > 0
  62. else:
  63. invalid = ((dist > track_size.reshape(1, M)) + \
  64. (dist > item_size.reshape(N, 1)) + \
  65. (item_cat.reshape(N, 1) != track_cat.reshape(1, M)) + (box_ious < self.opt.overlap_thresh)) > 0
  66. dist = dist + invalid * 1e18
  67. if self.opt.hungarian:
  68. item_score = np.array([item['score'] for item in results], np.float32) # N
  69. dist[dist > 1e18] = 1e18
  70. matched_indices = linear_assignment(dist)
  71. else:
  72. matched_indices = greedy_assignment(copy.deepcopy(dist))
  73. unmatched_dets = [d for d in range(dets.shape[0]) \
  74. if not (d in matched_indices[:, 0])]
  75. unmatched_tracks = [d for d in range(tracks.shape[0]) \
  76. if not (d in matched_indices[:, 1])]
  77. if self.opt.hungarian:
  78. matches = []
  79. for m in matched_indices:
  80. if dist[m[0], m[1]] > 1e16:
  81. unmatched_dets.append(m[0])
  82. unmatched_tracks.append(m[1])
  83. else:
  84. matches.append(m)
  85. matches = np.array(matches).reshape(-1, 2)
  86. else:
  87. matches = matched_indices
  88. ret = []
  89. for m in matches:
  90. track = results[m[0]]
  91. track['tracking_id'] = self.tracks[m[1]]['tracking_id']
  92. track['age'] = 1
  93. track['active'] = self.tracks[m[1]]['active'] + 1
  94. if 'embedding' in track:
  95. self.alive.append(track['tracking_id'])
  96. self.embedding_bank[self.tracks[m[1]]['tracking_id'] - 1, :] = self.alpha * track['embedding'] \
  97. + (1 - self.alpha) * self.embedding_bank[
  98. self.tracks[m[1]][
  99. 'tracking_id'] - 1,
  100. :]
  101. self.cat_bank[self.tracks[m[1]]['tracking_id'] - 1] = track['class']
  102. ret.append(track)
  103. if self.opt.public_det and len(unmatched_dets) > 0:
  104. # Public detection: only create tracks from provided detections
  105. pub_dets = np.array([d['ct'] for d in public_det], np.float32)
  106. dist3 = ((dets.reshape(-1, 1, 2) - pub_dets.reshape(1, -1, 2)) ** 2).sum(
  107. axis=2)
  108. matched_dets = [d for d in range(dets.shape[0]) \
  109. if not (d in unmatched_dets)]
  110. dist3[matched_dets] = 1e18
  111. for j in range(len(pub_dets)):
  112. i = dist3[:, j].argmin()
  113. if dist3[i, j] < item_size[i]:
  114. dist3[i, :] = 1e18
  115. track = results[i]
  116. if track['score'] > self.opt.new_thresh:
  117. self.id_count += 1
  118. track['tracking_id'] = self.id_count
  119. track['age'] = 1
  120. track['active'] = 1
  121. ret.append(track)
  122. else:
  123. # Private detection: create tracks for all un-matched detections
  124. for i in unmatched_dets:
  125. track = results[i]
  126. if track['score'] > self.opt.new_thresh:
  127. if 'embedding' in track:
  128. max_id, max_cos = self.get_similarity(track['embedding'], False, track['class'])
  129. if max_cos >= 0.3 and self.tracklet_ages[max_id - 1] < self.opt.window_size:
  130. track['tracking_id'] = max_id
  131. track['age'] = 1
  132. track['active'] = 1
  133. self.embedding_bank[track['tracking_id'] - 1, :] = self.alpha * track['embedding'] \
  134. + (1 - self.alpha) * self.embedding_bank[track['tracking_id'] - 1,:]
  135. else:
  136. self.id_count += 1
  137. track['tracking_id'] = self.id_count
  138. track['age'] = 1
  139. track['active'] = 1
  140. self.embedding_bank[self.id_count - 1, :] = track['embedding']
  141. self.cat_bank[self.id_count - 1] = track['class']
  142. self.alive.append(track['tracking_id'])
  143. ret.append(track)
  144. else:
  145. self.id_count += 1
  146. track['tracking_id'] = self.id_count
  147. track['age'] = 1
  148. track['active'] = 1
  149. ret.append(track)
  150. self.tracklet_ages[:self.id_count] = self.tracklet_ages[:self.id_count] + 1
  151. for track in ret:
  152. self.tracklet_ages[track['tracking_id'] - 1] = 1
  153. # second association
  154. results_second = [item for item in results_with_low if item['score'] < self.opt.track_thresh]
  155. self_tracks_second = [self.tracks[i] for i in unmatched_tracks if self.tracks[i]['active'] > 0]
  156. second2original = [i for i in unmatched_tracks if self.tracks[i]['active'] > 0]
  157. N = len(results_second)
  158. M = len(self_tracks_second)
  159. if N > 0 and M > 0:
  160. track_boxes_second = np.array([[track['bbox'][0], track['bbox'][1],
  161. track['bbox'][2], track['bbox'][3]] for track in self_tracks_second], np.float32) # M x 4
  162. det_boxes_second = np.array([[item['bbox'][0], item['bbox'][1],
  163. item['bbox'][2], item['bbox'][3]] for item in results_second], np.float32) # N x 4
  164. box_ious_second = self.bbox_overlaps_py(det_boxes_second, track_boxes_second)
  165. dets = np.array(
  166. [det['ct'] + det['tracking'] for det in results_second], np.float32) # N x 2
  167. track_size = np.array([((track['bbox'][2] - track['bbox'][0]) * \
  168. (track['bbox'][3] - track['bbox'][1])) \
  169. for track in self_tracks_second], np.float32) # M
  170. track_cat = np.array([track['class'] for track in self_tracks_second], np.int32) # M
  171. item_size = np.array([((item['bbox'][2] - item['bbox'][0]) * \
  172. (item['bbox'][3] - item['bbox'][1])) \
  173. for item in results_second], np.float32) # N
  174. item_cat = np.array([item['class'] for item in results_second], np.int32) # N
  175. tracks_second = np.array(
  176. [pre_det['ct'] for pre_det in self_tracks_second], np.float32) # M x 2
  177. dist = (((tracks_second.reshape(1, -1, 2) - \
  178. dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
  179. invalid = ((dist > track_size.reshape(1, M)) + \
  180. (dist > item_size.reshape(N, 1)) + \
  181. (item_cat.reshape(N, 1) != track_cat.reshape(1, M)) + (box_ious_second < 0.3)) > 0
  182. dist = dist + invalid * 1e18
  183. matched_indices_second = greedy_assignment(copy.deepcopy(dist), 1e8)
  184. unmatched_tracks_second = [d for d in range(tracks_second.shape[0]) \
  185. if not (d in matched_indices_second[:, 1])]
  186. matches_second = matched_indices_second
  187. for m in matches_second:
  188. track = results_second[m[0]]
  189. track['tracking_id'] = self_tracks_second[m[1]]['tracking_id']
  190. track['age'] = 1
  191. track['active'] = self_tracks_second[m[1]]['active'] + 1
  192. if 'embedding' in track:
  193. self.alive.append(track['tracking_id'])
  194. self.embedding_bank[self_tracks_second[m[1]]['tracking_id'] - 1, :] = self.alpha * track['embedding'] \
  195. + (1 - self.alpha) * self.embedding_bank[self_tracks_second[m[1]]['tracking_id'] - 1,:]
  196. self.cat_bank[self_tracks_second[m[1]]['tracking_id'] - 1] = track['class']
  197. ret.append(track)
  198. unmatched_tracks = [second2original[i] for i in unmatched_tracks_second] + \
  199. [i for i in unmatched_tracks if self.tracks[i]['active'] == 0]
  200. # Never used
  201. for i in unmatched_tracks:
  202. track = self.tracks[i]
  203. if track['age'] < self.opt.max_age:
  204. track['age'] += 1
  205. track['active'] = 1 # 0
  206. bbox = track['bbox']
  207. ct = track['ct']
  208. v = [0, 0]
  209. track['bbox'] = [
  210. bbox[0] + v[0], bbox[1] + v[1],
  211. bbox[2] + v[0], bbox[3] + v[1]]
  212. track['ct'] = [ct[0] + v[0], ct[1] + v[1]]
  213. ret.append(track)
  214. for r_ in ret:
  215. del r_['embedding']
  216. self.tracks = ret
  217. return ret
  218. def get_similarity(self, feat, stat, cls):
  219. max_id = -1
  220. max_cos = -1
  221. if stat:
  222. nID = self.id_count
  223. else:
  224. nID = self.id_count
  225. a = feat[None, :]
  226. b = self.embedding_bank[:nID, :]
  227. if len(b) > 0:
  228. alive = np.array(self.alive, dtype=np.int) - 1
  229. cosim = cosine(a, b)
  230. cosim = np.reshape(cosim, newshape=(-1))
  231. cosim[alive] = -2
  232. cosim[nID - 1] = -2
  233. cosim[np.where(self.cat_bank[:nID] != cls)[0]] = -2
  234. max_id = int(np.argmax(cosim) + 1)
  235. max_cos = np.max(cosim)
  236. return max_id, max_cos
  237. def bbox_overlaps_py(self, boxes, query_boxes):
  238. """
  239. determine overlaps between boxes and query_boxes
  240. :param boxes: n * 4 bounding boxes
  241. :param query_boxes: k * 4 bounding boxes
  242. :return: overlaps: n * k overlaps
  243. """
  244. n_ = boxes.shape[0]
  245. k_ = query_boxes.shape[0]
  246. overlaps = np.zeros((n_, k_), dtype=np.float)
  247. for k in range(k_):
  248. query_box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) * (query_boxes[k, 3] - query_boxes[k, 1] + 1)
  249. for n in range(n_):
  250. iw = min(boxes[n, 2], query_boxes[k, 2]) - max(boxes[n, 0], query_boxes[k, 0]) + 1
  251. if iw > 0:
  252. ih = min(boxes[n, 3], query_boxes[k, 3]) - max(boxes[n, 1], query_boxes[k, 1]) + 1
  253. if ih > 0:
  254. box_area = (boxes[n, 2] - boxes[n, 0] + 1) * (boxes[n, 3] - boxes[n, 1] + 1)
  255. all_area = float(box_area + query_box_area - iw * ih)
  256. overlaps[n, k] = iw * ih / all_area
  257. return overlaps
  258. def greedy_assignment(dist, thresh=1e16):
  259. matched_indices = []
  260. if dist.shape[1] == 0:
  261. return np.array(matched_indices, np.int32).reshape(-1, 2)
  262. for i in range(dist.shape[0]):
  263. j = dist[i].argmin()
  264. if dist[i][j] < thresh:
  265. dist[:, j] = 1e18
  266. matched_indices.append([i, j])
  267. return np.array(matched_indices, np.int32).reshape(-1, 2)