Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

byte_tracker.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import numpy as np
  2. from collections import deque
  3. import itertools
  4. import os
  5. import os.path as osp
  6. import time
  7. import torch
  8. import cv2
  9. import torch.nn.functional as F
  10. from models.model import create_model, load_model
  11. from models.decode import mot_decode
  12. from tracking_utils.utils import *
  13. from tracking_utils.log import logger
  14. from tracking_utils.kalman_filter import KalmanFilter
  15. from models import *
  16. from tracker import matching
  17. from .basetrack import BaseTrack, TrackState
  18. from utils.post_process import ctdet_post_process
  19. from utils.image import get_affine_transform
  20. from models.utils import _tranpose_and_gather_feat
  21. class STrack(BaseTrack):
  22. shared_kalman = KalmanFilter()
  23. def __init__(self, tlwh, score):
  24. # wait activate
  25. self._tlwh = np.asarray(tlwh, dtype=np.float)
  26. self.kalman_filter = None
  27. self.mean, self.covariance = None, None
  28. self.is_activated = False
  29. self.score = score
  30. self.tracklet_len = 0
  31. def predict(self):
  32. mean_state = self.mean.copy()
  33. if self.state != TrackState.Tracked:
  34. mean_state[7] = 0
  35. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  36. @staticmethod
  37. def multi_predict(stracks):
  38. if len(stracks) > 0:
  39. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  40. multi_covariance = np.asarray([st.covariance for st in stracks])
  41. for i, st in enumerate(stracks):
  42. if st.state != TrackState.Tracked:
  43. multi_mean[i][7] = 0
  44. multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
  45. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  46. stracks[i].mean = mean
  47. stracks[i].covariance = cov
  48. def activate(self, kalman_filter, frame_id):
  49. """Start a new tracklet"""
  50. self.kalman_filter = kalman_filter
  51. self.track_id = self.next_id()
  52. self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
  53. self.tracklet_len = 0
  54. self.state = TrackState.Tracked
  55. if frame_id == 1:
  56. self.is_activated = True
  57. #self.is_activated = True
  58. self.frame_id = frame_id
  59. self.start_frame = frame_id
  60. def re_activate(self, new_track, frame_id, new_id=False):
  61. self.mean, self.covariance = self.kalman_filter.update(
  62. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
  63. )
  64. self.tracklet_len = 0
  65. self.state = TrackState.Tracked
  66. self.is_activated = True
  67. self.frame_id = frame_id
  68. if new_id:
  69. self.track_id = self.next_id()
  70. self.score = new_track.score
  71. def update(self, new_track, frame_id):
  72. """
  73. Update a matched track
  74. :type new_track: STrack
  75. :type frame_id: int
  76. :type update_feature: bool
  77. :return:
  78. """
  79. self.frame_id = frame_id
  80. self.tracklet_len += 1
  81. new_tlwh = new_track.tlwh
  82. self.mean, self.covariance = self.kalman_filter.update(
  83. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  84. self.state = TrackState.Tracked
  85. self.is_activated = True
  86. self.score = new_track.score
  87. @property
  88. # @jit(nopython=True)
  89. def tlwh(self):
  90. """Get current position in bounding box format `(top left x, top left y,
  91. width, height)`.
  92. """
  93. if self.mean is None:
  94. return self._tlwh.copy()
  95. ret = self.mean[:4].copy()
  96. ret[2] *= ret[3]
  97. ret[:2] -= ret[2:] / 2
  98. return ret
  99. @property
  100. # @jit(nopython=True)
  101. def tlbr(self):
  102. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  103. `(top left, bottom right)`.
  104. """
  105. ret = self.tlwh.copy()
  106. ret[2:] += ret[:2]
  107. return ret
  108. @staticmethod
  109. # @jit(nopython=True)
  110. def tlwh_to_xyah(tlwh):
  111. """Convert bounding box to format `(center x, center y, aspect ratio,
  112. height)`, where the aspect ratio is `width / height`.
  113. """
  114. ret = np.asarray(tlwh).copy()
  115. ret[:2] += ret[2:] / 2
  116. ret[2] /= ret[3]
  117. return ret
  118. def to_xyah(self):
  119. return self.tlwh_to_xyah(self.tlwh)
  120. @staticmethod
  121. # @jit(nopython=True)
  122. def tlbr_to_tlwh(tlbr):
  123. ret = np.asarray(tlbr).copy()
  124. ret[2:] -= ret[:2]
  125. return ret
  126. @staticmethod
  127. # @jit(nopython=True)
  128. def tlwh_to_tlbr(tlwh):
  129. ret = np.asarray(tlwh).copy()
  130. ret[2:] += ret[:2]
  131. return ret
  132. def __repr__(self):
  133. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  134. class BYTETracker(object):
  135. def __init__(self, opt, frame_rate=30):
  136. self.opt = opt
  137. if opt.gpus[0] >= 0:
  138. opt.device = torch.device('cuda')
  139. else:
  140. opt.device = torch.device('cpu')
  141. print('Creating model...')
  142. self.model = create_model(opt.arch, opt.heads, opt.head_conv)
  143. self.model = load_model(self.model, opt.load_model)
  144. self.model = self.model.to(opt.device)
  145. self.model.eval()
  146. self.tracked_stracks = [] # type: list[STrack]
  147. self.lost_stracks = [] # type: list[STrack]
  148. self.removed_stracks = [] # type: list[STrack]
  149. self.frame_id = 0
  150. #self.det_thresh = opt.conf_thres
  151. self.det_thresh = opt.conf_thres + 0.1
  152. self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
  153. self.max_time_lost = self.buffer_size
  154. self.max_per_image = opt.K
  155. self.mean = np.array(opt.mean, dtype=np.float32).reshape(1, 1, 3)
  156. self.std = np.array(opt.std, dtype=np.float32).reshape(1, 1, 3)
  157. self.kalman_filter = KalmanFilter()
  158. def post_process(self, dets, meta):
  159. dets = dets.detach().cpu().numpy()
  160. dets = dets.reshape(1, -1, dets.shape[2])
  161. dets = ctdet_post_process(
  162. dets.copy(), [meta['c']], [meta['s']],
  163. meta['out_height'], meta['out_width'], self.opt.num_classes)
  164. for j in range(1, self.opt.num_classes + 1):
  165. dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
  166. return dets[0]
  167. def merge_outputs(self, detections):
  168. results = {}
  169. for j in range(1, self.opt.num_classes + 1):
  170. results[j] = np.concatenate(
  171. [detection[j] for detection in detections], axis=0).astype(np.float32)
  172. scores = np.hstack(
  173. [results[j][:, 4] for j in range(1, self.opt.num_classes + 1)])
  174. if len(scores) > self.max_per_image:
  175. kth = len(scores) - self.max_per_image
  176. thresh = np.partition(scores, kth)[kth]
  177. for j in range(1, self.opt.num_classes + 1):
  178. keep_inds = (results[j][:, 4] >= thresh)
  179. results[j] = results[j][keep_inds]
  180. return results
  181. def update(self, im_blob, img0):
  182. self.frame_id += 1
  183. activated_starcks = []
  184. refind_stracks = []
  185. lost_stracks = []
  186. removed_stracks = []
  187. width = img0.shape[1]
  188. height = img0.shape[0]
  189. inp_height = im_blob.shape[2]
  190. inp_width = im_blob.shape[3]
  191. c = np.array([width / 2., height / 2.], dtype=np.float32)
  192. s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
  193. meta = {'c': c, 's': s,
  194. 'out_height': inp_height // self.opt.down_ratio,
  195. 'out_width': inp_width // self.opt.down_ratio}
  196. ''' Step 1: Network forward, get detections & embeddings'''
  197. with torch.no_grad():
  198. output = self.model(im_blob)[-1]
  199. hm = output['hm'].sigmoid_()
  200. wh = output['wh']
  201. reg = output['reg'] if self.opt.reg_offset else None
  202. dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
  203. dets = self.post_process(dets, meta)
  204. dets = self.merge_outputs([dets])[1]
  205. remain_inds = dets[:, 4] > self.opt.conf_thres
  206. inds_low = dets[:, 4] > 0.2
  207. inds_high = dets[:, 4] < self.opt.conf_thres
  208. inds_second = np.logical_and(inds_low, inds_high)
  209. dets_second = dets[inds_second]
  210. dets = dets[remain_inds]
  211. if len(dets) > 0:
  212. '''Detections'''
  213. detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  214. tlbrs in dets[:, :5]]
  215. else:
  216. detections = []
  217. ''' Add newly detected tracklets to tracked_stracks'''
  218. unconfirmed = []
  219. tracked_stracks = [] # type: list[STrack]
  220. for track in self.tracked_stracks:
  221. if not track.is_activated:
  222. unconfirmed.append(track)
  223. else:
  224. tracked_stracks.append(track)
  225. ''' Step 2: First association, with IOU'''
  226. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  227. # Predict the current location with KF
  228. STrack.multi_predict(strack_pool)
  229. dists = matching.iou_distance(strack_pool, detections)
  230. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.opt.match_thres)
  231. for itracked, idet in matches:
  232. track = strack_pool[itracked]
  233. det = detections[idet]
  234. if track.state == TrackState.Tracked:
  235. track.update(detections[idet], self.frame_id)
  236. activated_starcks.append(track)
  237. else:
  238. track.re_activate(det, self.frame_id, new_id=False)
  239. refind_stracks.append(track)
  240. # association the untrack to the low score detections
  241. if len(dets_second) > 0:
  242. '''Detections'''
  243. detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  244. tlbrs in dets_second[:, :5]]
  245. else:
  246. detections_second = []
  247. r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
  248. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  249. matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
  250. for itracked, idet in matches:
  251. track = r_tracked_stracks[itracked]
  252. det = detections_second[idet]
  253. if track.state == TrackState.Tracked:
  254. track.update(det, self.frame_id)
  255. activated_starcks.append(track)
  256. else:
  257. track.re_activate(det, self.frame_id, new_id=False)
  258. refind_stracks.append(track)
  259. for it in u_track:
  260. track = r_tracked_stracks[it]
  261. if not track.state == TrackState.Lost:
  262. track.mark_lost()
  263. lost_stracks.append(track)
  264. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  265. detections = [detections[i] for i in u_detection]
  266. dists = matching.iou_distance(unconfirmed, detections)
  267. matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
  268. for itracked, idet in matches:
  269. unconfirmed[itracked].update(detections[idet], self.frame_id)
  270. activated_starcks.append(unconfirmed[itracked])
  271. for it in u_unconfirmed:
  272. track = unconfirmed[it]
  273. track.mark_removed()
  274. removed_stracks.append(track)
  275. """ Step 4: Init new stracks"""
  276. for inew in u_detection:
  277. track = detections[inew]
  278. if track.score < self.det_thresh:
  279. continue
  280. track.activate(self.kalman_filter, self.frame_id)
  281. activated_starcks.append(track)
  282. """ Step 5: Update state"""
  283. for track in self.lost_stracks:
  284. if self.frame_id - track.end_frame > self.max_time_lost:
  285. track.mark_removed()
  286. removed_stracks.append(track)
  287. # print('Ramained match {} s'.format(t4-t3))
  288. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  289. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  290. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  291. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  292. self.lost_stracks.extend(lost_stracks)
  293. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  294. self.removed_stracks.extend(removed_stracks)
  295. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
  296. #self.tracked_stracks = remove_fp_stracks(self.tracked_stracks)
  297. # get scores of lost tracks
  298. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  299. logger.debug('===========Frame {}=========='.format(self.frame_id))
  300. logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
  301. logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
  302. logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
  303. logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
  304. return output_stracks
  305. def joint_stracks(tlista, tlistb):
  306. exists = {}
  307. res = []
  308. for t in tlista:
  309. exists[t.track_id] = 1
  310. res.append(t)
  311. for t in tlistb:
  312. tid = t.track_id
  313. if not exists.get(tid, 0):
  314. exists[tid] = 1
  315. res.append(t)
  316. return res
  317. def sub_stracks(tlista, tlistb):
  318. stracks = {}
  319. for t in tlista:
  320. stracks[t.track_id] = t
  321. for t in tlistb:
  322. tid = t.track_id
  323. if stracks.get(tid, 0):
  324. del stracks[tid]
  325. return list(stracks.values())
  326. def remove_duplicate_stracks(stracksa, stracksb):
  327. pdist = matching.iou_distance(stracksa, stracksb)
  328. pairs = np.where(pdist < 0.15)
  329. dupa, dupb = list(), list()
  330. for p, q in zip(*pairs):
  331. timep = stracksa[p].frame_id - stracksa[p].start_frame
  332. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  333. if timep > timeq:
  334. dupb.append(q)
  335. else:
  336. dupa.append(p)
  337. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  338. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  339. return resa, resb
  340. def remove_fp_stracks(stracksa, n_frame=10):
  341. remain = []
  342. for t in stracksa:
  343. score_5 = t.score_list[-n_frame:]
  344. score_5 = np.array(score_5, dtype=np.float32)
  345. index = score_5 < 0.45
  346. num = np.sum(index)
  347. if num < n_frame:
  348. remain.append(t)
  349. return remain