Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracker.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. from collections import deque
  2. import torch
  3. import numpy as np
  4. from utils.kalman_filter import KalmanFilter
  5. from utils.log import logger
  6. from models import *
  7. from tracker import matching
  8. from .basetrack import BaseTrack, TrackState
  class STrack(BaseTrack):
      """One tracklet: a Kalman-filtered box plus a smoothed appearance embedding.

      Detections arrive as ``tlwh`` boxes (top-left x, top-left y, width,
      height). The Kalman filter is fed ``xyah`` measurements produced by
      :meth:`tlwh_to_xyah`, so ``self.mean[:4]`` is read back as
      (center x, center y, aspect ratio, height) in :attr:`tlwh`.
      """

      def __init__(self, tlwh, score, temp_feat, buffer_size=30):
          # Waiting for activation: no Kalman state until activate() is called.
          # np.float64 replaces the np.float alias, which NumPy 1.24 removed.
          self._tlwh = np.asarray(tlwh, dtype=np.float64)
          self.kalman_filter = None
          self.mean, self.covariance = None, None
          self.is_activated = False

          self.score = score
          self.tracklet_len = 0

          # The embedding state must exist *before* the first update_features()
          # call, which appends to self.features. Previously the deque was
          # created afterwards, so that first append went to whatever
          # `features` attribute BaseTrack exposes (if any) instead of this
          # track's own bounded buffer.
          self.smooth_feat = None
          self.features = deque([], maxlen=buffer_size)
          self.alpha = 0.9  # EMA weight kept on the previous smoothed embedding
          self.update_features(temp_feat)

      def update_features(self, feat):
          """Fold a new appearance embedding into the track.

          ``feat`` is L2-normalised into a new array (the caller's array is
          not modified), stored as ``curr_feat`` and appended to the bounded
          ``features`` history. ``smooth_feat`` becomes an exponential moving
          average (weight ``alpha`` on the old value) re-normalised to unit
          length; the first call simply seeds it with ``feat``.
          """
          feat = feat / np.linalg.norm(feat)
          self.curr_feat = feat
          if self.smooth_feat is None:
              smooth = feat
          else:
              smooth = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
          self.features.append(feat)
          self.smooth_feat = smooth / np.linalg.norm(smooth)

      def predict(self):
          """Advance this track's Kalman state by one step."""
          mean_state = self.mean.copy()
          if self.state != TrackState.Tracked:
              # Zero component 7 of the state (presumably the height velocity)
              # for tracks that are not currently tracked -- verify against
              # KalmanFilter.
              mean_state[7] = 0
          self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

      @staticmethod
      def multi_predict(stracks, kalman_filter):
          """Vectorised :meth:`predict` over ``stracks``; each track is updated in place."""
          if len(stracks) > 0:
              multi_mean = np.asarray([st.mean.copy() for st in stracks])
              multi_covariance = np.asarray([st.covariance for st in stracks])
              for i, st in enumerate(stracks):
                  if st.state != TrackState.Tracked:
                      multi_mean[i][7] = 0
              multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
              for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                  stracks[i].mean = mean
                  stracks[i].covariance = cov

      def activate(self, kalman_filter, frame_id):
          """Start a new tracklet.

          Assigns a fresh track id and initialises the Kalman state from the
          stored box. ``is_activated`` is deliberately left False here: only a
          later update() sets it, and the tracker outputs activated tracks only.
          """
          self.kalman_filter = kalman_filter
          self.track_id = self.next_id()
          self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
          self.tracklet_len = 0
          self.state = TrackState.Tracked
          self.frame_id = frame_id
          self.start_frame = frame_id

      def re_activate(self, new_track, frame_id, new_id=False):
          """Revive a lost track with the matched detection ``new_track``.

          Applies a Kalman measurement update, refreshes the embedding and the
          score (consistent with update()), resets ``tracklet_len`` and marks
          the track Tracked and activated. With ``new_id=True`` a fresh track
          id is drawn.
          """
          self.mean, self.covariance = self.kalman_filter.update(
              self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
          )
          self.update_features(new_track.curr_feat)
          self.tracklet_len = 0
          self.state = TrackState.Tracked
          self.is_activated = True
          self.frame_id = frame_id
          self.score = new_track.score
          if new_id:
              self.track_id = self.next_id()

      def update(self, new_track, frame_id, update_feature=True):
          """Update a matched track with the detection ``new_track``.

          :type new_track: STrack
          :type frame_id: int
          :type update_feature: bool -- also fold in ``new_track``'s embedding
          """
          self.frame_id = frame_id
          self.tracklet_len += 1
          self.mean, self.covariance = self.kalman_filter.update(
              self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
          self.state = TrackState.Tracked
          self.is_activated = True
          self.score = new_track.score
          if update_feature:
              self.update_features(new_track.curr_feat)

      @property
      def tlwh(self):
          """Current box as a new array ``(top-left x, top-left y, width, height)``."""
          if self.mean is None:
              return self._tlwh.copy()
          ret = self.mean[:4].copy()
          ret[2] *= ret[3]        # width = aspect ratio * height
          ret[:2] -= ret[2:] / 2  # center -> top-left corner
          return ret

      @property
      def tlbr(self):
          """Current box as ``(min x, min y, max x, max y)``, i.e. (top left, bottom right)."""
          ret = self.tlwh  # tlwh always returns a fresh array
          ret[2:] += ret[:2]
          return ret

      @staticmethod
      def tlwh_to_xyah(tlwh):
          """Convert a tlwh box to ``(center x, center y, width / height, height)``."""
          ret = np.asarray(tlwh).copy()
          ret[:2] += ret[2:] / 2
          ret[2] /= ret[3]
          return ret

      def to_xyah(self):
          """Current box in xyah format (see :meth:`tlwh_to_xyah`)."""
          return self.tlwh_to_xyah(self.tlwh)

      @staticmethod
      def tlbr_to_tlwh(tlbr):
          """Convert ``(x1, y1, x2, y2)`` to ``(x1, y1, width, height)``."""
          ret = np.asarray(tlbr).copy()
          ret[2:] -= ret[:2]
          return ret

      @staticmethod
      def tlwh_to_tlbr(tlwh):
          """Convert ``(x1, y1, width, height)`` to ``(x1, y1, x2, y2)``."""
          ret = np.asarray(tlwh).copy()
          ret[2:] += ret[:2]
          return ret

      def __repr__(self):
          return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  class JDETracker(object):
      """Online multi-object tracker built on a JDE (Darknet) detector.

      Each call to :meth:`update` runs the network on one frame, then
      associates detections with tracklets in stages: appearance + motion for
      confident detections, IoU for what remains, an IoU pass over
      low-confidence detections, and IoU for unconfirmed tracklets. Leftover
      confident detections start new tracklets.
      """

      def __init__(self, opt, frame_rate=30):
          self.opt = opt
          # NOTE(review): nID is hard-coded; presumably it must match the
          # identity count the checkpoint in opt.weights was trained with.
          self.model = Darknet(opt.cfg, nID=14455)
          # NOTE: torch.load unpickles the file -- only load trusted weights.
          self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
          self.model.cuda().eval()

          self.tracked_stracks = []  # type: list[STrack]
          self.lost_stracks = []  # type: list[STrack]
          self.removed_stracks = []  # type: list[STrack]

          self.frame_id = 0
          self.det_thresh = opt.conf_thres
          # Only detections scoring at least this much may start a new track.
          self.init_thresh = self.det_thresh + 0.2
          # Scores strictly between low_thresh and det_thresh feed the
          # low-score IoU association.
          self.low_thresh = 0.4
          # track_buffer is expressed for 30 fps and scaled to frame_rate.
          self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
          self.max_time_lost = self.buffer_size
          self.kalman_filter = KalmanFilter()

      def _apply_matches(self, matches, tracks, detections, activated, refind):
          """Update each matched (track, detection) pair and record the track in
          ``activated`` (it was Tracked) or ``refind`` (it was re-activated)."""
          for itracked, idet in matches:
              track = tracks[itracked]
              det = detections[idet]
              if track.state == TrackState.Tracked:
                  track.update(det, self.frame_id)
                  activated.append(track)
              else:
                  track.re_activate(det, self.frame_id, new_id=False)
                  refind.append(track)

      def update(self, im_blob, img0):
          """Process one frame and return the confirmed, currently tracked tracklets.

          Runs the detector, associates its detections with tracked, lost and
          unconfirmed tracklets, starts new tracklets and ages out lost ones.

          Parameters
          ----------
          im_blob : torch.float32
              Network input tensor; by default of shape [1, 3, 608, 1088].
          img0 : ndarray
              Original frame. Its ``shape`` is the target for rescaling the
              boxes (presumably (H, W, 3)).

          Returns
          -------
          output_stracks : list of STrack
              Tracklets in the Tracked state whose ``is_activated`` is set.
          """
          self.frame_id += 1
          activated_stracks = []  # tracks matched (or newly started) in this frame
          refind_stracks = []     # lost tracks re-found in this frame
          lost_stracks = []       # tracked tracks that got no detection this frame
          removed_stracks = []    # tracks dropped in this frame

          # Step 1: network forward, get detections & embeddings.
          with torch.no_grad():
              pred = self.model(im_blob)
          # Keep only proposals whose confidence (column 4) exceeds low_thresh.
          pred = pred[pred[:, :, 4] > self.low_thresh]
          if len(pred) > 0:
              dets = non_max_suppression(pred.unsqueeze(0), self.low_thresh, self.opt.nms_thres)[0].cpu()
              # scale_coords presumably rescales dets[:, :4] in place to img0's
              # size (its return value is not used). NOTE(review): the
              # .round() result is discarded, so boxes are not actually rounded.
              scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
              # Columns used below: [:4] box (x1, y1, x2, y2), [4] score,
              # [6:] appearance embedding.
              dets = dets.numpy()
              scores = dets[:, 4]
              remain_inds = scores > self.det_thresh
              inds_second = np.logical_and(scores > self.low_thresh, scores < self.det_thresh)
              dets_second = dets[inds_second]
              dets = dets[remain_inds]
              detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                            for (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
          else:
              detections = []
              dets_second = []

          # Split previous tracks: those created by activate() but never
          # matched again are unconfirmed (is_activated still False).
          unconfirmed = []
          tracked_stracks = []  # type: list[STrack]
          for track in self.tracked_stracks:
              if track.is_activated:
                  tracked_stracks.append(track)
              else:
                  unconfirmed.append(track)

          # Step 2: first association, appearance embedding fused with motion,
          # over confirmed tracked + lost tracks.
          strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
          STrack.multi_predict(strack_pool, self.kalman_filter)
          dists = matching.embedding_distance(strack_pool, detections)
          dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
          matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
          self._apply_matches(matches, strack_pool, detections, activated_stracks, refind_stracks)

          # Step 3: second association, IoU between still-unmatched Tracked
          # tracks and still-unmatched confident detections.
          detections = [detections[i] for i in u_detection]
          r_tracked_stracks = [strack_pool[i] for i in u_track
                               if strack_pool[i].state == TrackState.Tracked]
          dists = matching.iou_distance(r_tracked_stracks, detections)
          matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
          self._apply_matches(matches, r_tracked_stracks, detections, activated_stracks, refind_stracks)

          # Step 3b: associate the remaining tracks with the low-score detections.
          if len(dets_second) > 0:
              detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                                   for (tlbrs, f) in zip(dets_second[:, :5], dets_second[:, 6:])]
          else:
              detections_second = []
          second_tracked_stracks = [r_tracked_stracks[i] for i in u_track
                                    if r_tracked_stracks[i].state == TrackState.Tracked]
          dists = matching.iou_distance(second_tracked_stracks, detections_second)
          matches, u_track, _ = matching.linear_assignment(dists, thresh=0.4)
          self._apply_matches(matches, second_tracked_stracks, detections_second,
                              activated_stracks, refind_stracks)
          # Tracks that matched nothing at all this frame become lost.
          for it in u_track:
              track = second_tracked_stracks[it]
              if track.state != TrackState.Lost:
                  track.mark_lost()
                  lost_stracks.append(track)

          # Unconfirmed tracks get one IoU chance against the confident
          # detections still unmatched after Step 3; otherwise they are removed.
          detections = [detections[i] for i in u_detection]
          dists = matching.iou_distance(unconfirmed, detections)
          matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
          for itracked, idet in matches:
              unconfirmed[itracked].update(detections[idet], self.frame_id)
              activated_stracks.append(unconfirmed[itracked])
          for it in u_unconfirmed:
              track = unconfirmed[it]
              track.mark_removed()
              removed_stracks.append(track)

          # Step 4: start new tracks from leftover, sufficiently confident detections.
          for inew in u_detection:
              track = detections[inew]
              if track.score < self.init_thresh:
                  continue
              track.activate(self.kalman_filter, self.frame_id)
              activated_stracks.append(track)

          # Step 5: drop tracks lost for more than max_time_lost frames.
          for track in self.lost_stracks:
              if self.frame_id - track.end_frame > self.max_time_lost:
                  track.mark_removed()
                  removed_stracks.append(track)

          # Update the persistent track lists.
          self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
          self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
          self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
          self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
          self.lost_stracks.extend(lost_stracks)
          # Record this frame's removals *before* subtracting them from the lost
          # list. The previous order left them in lost_stracks for another
          # frame, where they could be re-found and be removed a second time.
          self.removed_stracks.extend(removed_stracks)
          self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
          self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)

          output_stracks = [track for track in self.tracked_stracks if track.is_activated]

          logger.debug('===========Frame {}=========='.format(self.frame_id))
          logger.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
          logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
          logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
          logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

          return output_stracks
  def joint_stracks(tlista, tlistb):
      """Concatenate two track lists without repeating track ids from ``tlistb``.

      Every element of ``tlista`` is kept, in order. An element of ``tlistb``
      is appended only if its ``track_id`` has not appeared earlier in either
      list.
      """
      merged = list(tlista)
      seen_ids = {track.track_id for track in merged}
      for track in tlistb:
          if track.track_id in seen_ids:
              continue
          seen_ids.add(track.track_id)
          merged.append(track)
      return merged
  def sub_stracks(tlista, tlistb):
      """Return the tracks of ``tlista`` whose ``track_id`` does not occur in ``tlistb``.

      ``tlista`` is deduplicated by ``track_id``. Each id keeps the position
      of its first occurrence but the object of its last occurrence.
      """
      latest_by_id = {track.track_id: track for track in tlista}
      excluded_ids = {track.track_id for track in tlistb}
      return [track for tid, track in latest_by_id.items() if tid not in excluded_ids]
  def remove_duplicate_stracks(stracksa, stracksb):
      """Drop near-duplicate tracks that appear in both lists.

      Pairs whose ``matching.iou_distance`` is below 0.15 are treated as the
      same object (presumably IoU above 0.85 if the distance is 1 - IoU). Of
      each such pair, the track alive for fewer frames
      (``frame_id - start_frame``) is dropped. On a tie, the entry from
      ``stracksa`` is dropped.

      Returns:
          (resa, resb): the filtered lists, each in its original order.
      """
      pdist = matching.iou_distance(stracksa, stracksb)
      pairs = np.where(pdist < 0.15)
      # Sets keep the membership tests in the filters below O(1); the previous
      # lists made each lookup linear in the number of duplicates.
      dupa, dupb = set(), set()
      for p, q in zip(*pairs):
          timep = stracksa[p].frame_id - stracksa[p].start_frame
          timeq = stracksb[q].frame_id - stracksb[q].start_frame
          if timep > timeq:
              dupb.add(int(q))
          else:
              dupa.add(int(p))
      resa = [t for i, t in enumerate(stracksa) if i not in dupa]
      resb = [t for i, t in enumerate(stracksb) if i not in dupb]
      return resa, resb