Meta Byte Track

byte_tracker.py 15KB

from collections import deque
import time

import numpy as np
import torch

from utils.kalman_filter import KalmanFilter
from utils.log import logger
from models import *
from tracker import matching
from .basetrack import BaseTrack, TrackState

class STrack(BaseTrack):

    def __init__(self, tlwh, score):
        # Waiting to be activated
        self._tlwh = np.asarray(tlwh, dtype=np.float64)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks, kalman_filter):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id, update_feature=True):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
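
# Illustrative sketch, not part of the original file: a quick sanity check of the
# box-format conversions above. `tlwh` is (top-left x, top-left y, width, height),
# `tlbr` is (min x, min y, max x, max y), and `xyah` is (center x, center y,
# width / height, height). The values in the comments follow from the formulas above.
def _demo_box_conversions():
    tlwh = np.array([10.0, 20.0, 30.0, 60.0])
    tlbr = STrack.tlwh_to_tlbr(tlwh)  # -> [10. 20. 40. 80.]
    xyah = STrack.tlwh_to_xyah(tlwh)  # -> [25. 50. 0.5 60.]
    assert np.allclose(STrack.tlbr_to_tlwh(tlbr), tlwh)  # round trip is lossless
    return tlbr, xyah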

class BYTETracker(object):

    def __init__(self, opt, frame_rate=30):
        self.opt = opt
        self.model = Darknet(opt.cfg, nID=14455)
        # load_darknet_weights(self.model, opt.weights)
        self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
        self.model.cuda().eval()

        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []     # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.det_thresh = opt.conf_thres                   # high-score threshold for the first association
        self.init_thresh = self.det_thresh + 0.2           # stricter threshold for starting new tracks
        self.low_thresh = 0.3                              # floor below which detections are discarded
        self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size              # frames a lost track is kept before removal
        self.kalman_filter = KalmanFilter()

    def update(self, im_blob, img0):
        """
        Processes the image frame and finds bounding-box detections.
        Associates the detections with the corresponding tracklets and also handles
        lost, removed, refound and active tracklets.

        Parameters
        ----------
        im_blob : torch.float32
            Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088]
        img0 : ndarray
            ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3]

        Returns
        -------
        output_stracks : list of STrack instances
            The list contains information regarding the online tracklets for the received image tensor.
        """
        self.frame_id += 1
        activated_starcks = []  # tracks activated in the current frame
        refind_stracks = []     # lost tracks whose detections are recovered in the current frame
        lost_stracks = []       # tracks not matched in the current frame but kept for less than the removal threshold
        removed_stracks = []

        t1 = time.time()
        ''' Step 1: Network forward, get detections & embeddings '''
        with torch.no_grad():
            pred = self.model(im_blob)
        # pred is a tensor of all the proposals (default: 54264), each carrying
        # bounding-box information and embeddings.
        pred = pred[pred[:, :, 4] > self.low_thresh]
        # pred now holds fewer proposals; the rest were rejected on the object confidence score.
        if len(pred) > 0:
            dets = non_max_suppression(pred.unsqueeze(0), self.low_thresh, self.opt.nms_thres)[0].cpu()
            # Final proposals are obtained in dets, with bounding-box information and embeddings included.
            # The next step rescales the detections to the original image coordinates.
            scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
            '''Detections is a list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
            # class_pred holds the embeddings.
            dets = dets.numpy()
            remain_inds = dets[:, 4] > self.det_thresh
            inds_low = dets[:, 4] > self.low_thresh
            inds_high = dets[:, 4] < self.det_thresh
            inds_second = np.logical_and(inds_low, inds_high)  # low-score detections for the second association
            dets_second = dets[inds_second]
            dets = dets[remain_inds]

            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4])
                          for tlbrs in dets[:, :5]]
        else:
            detections = []
            dets_second = []

        t2 = time.time()
        # print('Forward: {} s'.format(t2 - t1))

        ''' Add newly detected tracklets to tracked_stracks '''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                # Previous tracks which are not active in the current frame are added to the unconfirmed list
                unconfirmed.append(track)
            else:
                # Active tracks are added to the local list 'tracked_stracks'
                tracked_stracks.append(track)

        ''' Step 2: First association, with high-score detections (IoU distance) '''
        # Combine currently tracked_stracks and lost_stracks
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with the Kalman filter
        STrack.multi_predict(strack_pool, self.kalman_filter)

        dists = matching.iou_distance(strack_pool, detections)
        # dists is the matrix of IoU distances between the tracks in strack_pool and the detections
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.8)
        # matches is the array of (track, detection) index pairs assigned within strack_pool
        for itracked, idet in matches:
            # itracked is the index of the track and idet the index of the detection
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                # If the track is active, add the detection to the track
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                # A detection was matched to an inactive track, so put the track in the refind_stracks list
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, matching the remaining tracks to the low-score detections '''
        if len(dets_second) > 0:
            detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4])
                                 for tlbrs in dets_second[:, :5]]
        else:
            detections_second = []
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        # Tracks with no detection in either pass (u_track) are marked lost and added to lost_stracks
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        # Unconfirmed tracks that are still unmatched are removed
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks """
        # After all the association steps, a remaining high-score detection starts a new track
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.init_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)

        """ Step 5: Update state """
        # Tracks lost for more frames than the threshold are removed
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # Update self.tracked_stracks and self.lost_stracks with the results of this frame
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)

        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks
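
# Illustrative sketch, not part of the original file: the score split performed in
# Step 1 of update() above, on example scores and thresholds. Scores above det_thresh
# feed the first association; scores in (low_thresh, det_thresh) are held back for
# the second association pass; the rest are discarded.
def _demo_score_split(det_thresh=0.5, low_thresh=0.3):
    scores = np.array([0.9, 0.45, 0.2, 0.6, 0.35])
    first = scores > det_thresh                                         # [ True False False  True False]
    second = np.logical_and(scores > low_thresh, scores < det_thresh)   # [False  True False False  True]
    return scores[first], scores[second]  # -> [0.9 0.6], [0.45 0.35]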

def joint_stracks(tlista, tlistb):
    """Union of two track lists, keeping the first occurrence of each track_id."""
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    """Tracks in tlista whose track_id does not appear in tlistb."""
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    """Drop near-duplicate pairs (IoU distance < 0.15), keeping the longer-lived track."""
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb
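
# Minimal usage sketch, not part of the original file. It assumes an `opt` namespace
# carrying cfg, weights, conf_thres, nms_thres, img_size and track_buffer, and a
# dataloader yielding (im_blob, img0) pairs as produced elsewhere in this project;
# both names are hypothetical here.
def _demo_tracking_loop(opt, dataloader):
    tracker = BYTETracker(opt, frame_rate=30)
    for im_blob, img0 in dataloader:
        online_targets = tracker.update(im_blob, img0)
        for t in online_targets:
            print('frame {}: track {} at tlwh {} (score {:.2f})'.format(
                tracker.frame_id, t.track_id, t.tlwh, t.score))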