Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

byte_tracker.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. from collections import deque
  2. import os
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from torchsummary import summary
  8. from core.mot.general import non_max_suppression_and_inds, non_max_suppression_jde, non_max_suppression, scale_coords
  9. from core.mot.torch_utils import intersect_dicts
  10. from models.mot.cstrack import Model
  11. from mot_online import matching
  12. from mot_online.kalman_filter import KalmanFilter
  13. from mot_online.log import logger
  14. from mot_online.utils import *
  15. from mot_online.basetrack import BaseTrack, TrackState
  16. class STrack(BaseTrack):
  17. shared_kalman = KalmanFilter()
  18. def __init__(self, tlwh, score):
  19. # wait activate
  20. self._tlwh = np.asarray(tlwh, dtype=np.float)
  21. self.kalman_filter = None
  22. self.mean, self.covariance = None, None
  23. self.is_activated = False
  24. self.score = score
  25. self.tracklet_len = 0
  26. def predict(self):
  27. mean_state = self.mean.copy()
  28. if self.state != TrackState.Tracked:
  29. mean_state[7] = 0
  30. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  31. @staticmethod
  32. def multi_predict(stracks):
  33. if len(stracks) > 0:
  34. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  35. multi_covariance = np.asarray([st.covariance for st in stracks])
  36. for i, st in enumerate(stracks):
  37. if st.state != TrackState.Tracked:
  38. multi_mean[i][7] = 0
  39. multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
  40. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  41. stracks[i].mean = mean
  42. stracks[i].covariance = cov
  43. def activate(self, kalman_filter, frame_id):
  44. """Start a new tracklet"""
  45. self.kalman_filter = kalman_filter
  46. self.track_id = self.next_id()
  47. self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
  48. self.tracklet_len = 0
  49. self.state = TrackState.Tracked
  50. #self.is_activated = True
  51. self.frame_id = frame_id
  52. self.start_frame = frame_id
  53. def re_activate(self, new_track, frame_id, new_id=False):
  54. self.mean, self.covariance = self.kalman_filter.update(
  55. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
  56. )
  57. self.tracklet_len = 0
  58. self.state = TrackState.Tracked
  59. self.is_activated = True
  60. self.frame_id = frame_id
  61. if new_id:
  62. self.track_id = self.next_id()
  63. def update(self, new_track, frame_id):
  64. """
  65. Update a matched track
  66. :type new_track: STrack
  67. :type frame_id: int
  68. :type update_feature: bool
  69. :return:
  70. """
  71. self.frame_id = frame_id
  72. self.tracklet_len += 1
  73. new_tlwh = new_track.tlwh
  74. self.mean, self.covariance = self.kalman_filter.update(
  75. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  76. self.state = TrackState.Tracked
  77. self.is_activated = True
  78. self.score = new_track.score
  79. @property
  80. # @jit(nopython=True)
  81. def tlwh(self):
  82. """Get current position in bounding box format `(top left x, top left y,
  83. width, height)`.
  84. """
  85. if self.mean is None:
  86. return self._tlwh.copy()
  87. ret = self.mean[:4].copy()
  88. ret[2] *= ret[3]
  89. ret[:2] -= ret[2:] / 2
  90. return ret
  91. @property
  92. # @jit(nopython=True)
  93. def tlbr(self):
  94. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  95. `(top left, bottom right)`.
  96. """
  97. ret = self.tlwh.copy()
  98. ret[2:] += ret[:2]
  99. return ret
  100. @staticmethod
  101. # @jit(nopython=True)
  102. def tlwh_to_xyah(tlwh):
  103. """Convert bounding box to format `(center x, center y, aspect ratio,
  104. height)`, where the aspect ratio is `width / height`.
  105. """
  106. ret = np.asarray(tlwh).copy()
  107. ret[:2] += ret[2:] / 2
  108. ret[2] /= ret[3]
  109. return ret
  110. def to_xyah(self):
  111. return self.tlwh_to_xyah(self.tlwh)
  112. @staticmethod
  113. # @jit(nopython=True)
  114. def tlbr_to_tlwh(tlbr):
  115. ret = np.asarray(tlbr).copy()
  116. ret[2:] -= ret[:2]
  117. return ret
  118. @staticmethod
  119. # @jit(nopython=True)
  120. def tlwh_to_tlbr(tlwh):
  121. ret = np.asarray(tlwh).copy()
  122. ret[2:] += ret[:2]
  123. return ret
  124. def __repr__(self):
  125. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  126. class BYTETracker(object):
  127. def __init__(self, opt, frame_rate=30):
  128. self.opt = opt
  129. if int(opt.gpus[0]) >= 0:
  130. opt.device = torch.device('cuda')
  131. else:
  132. opt.device = torch.device('cpu')
  133. print('Creating model...')
  134. ckpt = torch.load(opt.weights, map_location=opt.device) # load checkpoint
  135. self.model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=1).to(opt.device) # create
  136. exclude = ['anchor'] if opt.cfg else [] # exclude keys
  137. if type(ckpt['model']).__name__ == "OrderedDict":
  138. state_dict = ckpt['model']
  139. else:
  140. state_dict = ckpt['model'].float().state_dict() # to FP32
  141. state_dict = intersect_dicts(state_dict, self.model.state_dict(), exclude=exclude) # intersect
  142. self.model.load_state_dict(state_dict, strict=False) # load
  143. self.model.cuda().eval()
  144. total_params = sum(p.numel() for p in self.model.parameters())
  145. print(f'{total_params:,} total parameters.')
  146. self.tracked_stracks = [] # type: list[STrack]
  147. self.lost_stracks = [] # type: list[STrack]
  148. self.removed_stracks = [] # type: list[STrack]
  149. self.frame_id = 0
  150. self.det_thresh = opt.conf_thres
  151. self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
  152. self.max_time_lost = self.buffer_size
  153. self.mean = np.array(opt.mean, dtype=np.float32).reshape(1, 1, 3)
  154. self.std = np.array(opt.std, dtype=np.float32).reshape(1, 1, 3)
  155. self.kalman_filter = KalmanFilter()
  156. self.low_thres = 0.2
  157. self.high_thres = self.opt.conf_thres + 0.1
  158. def update(self, im_blob, img0,seq_num, save_dir):
  159. self.frame_id += 1
  160. activated_starcks = []
  161. refind_stracks = []
  162. lost_stracks = []
  163. removed_stracks = []
  164. dets = []
  165. ''' Step 1: Network forward, get detections & embeddings'''
  166. with torch.no_grad():
  167. output = self.model(im_blob, augment=False)
  168. pred, train_out = output[1]
  169. pred = pred[pred[:, :, 4] > self.low_thres]
  170. detections = []
  171. if len(pred) > 0:
  172. dets,x_inds,y_inds = non_max_suppression_and_inds(pred[:,:6].unsqueeze(0), 0.1, self.opt.nms_thres,method='cluster_diou')
  173. if len(dets) != 0:
  174. scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
  175. remain_inds = dets[:, 4] > self.opt.conf_thres
  176. inds_low = dets[:, 4] > self.low_thres
  177. inds_high = dets[:, 4] < self.opt.conf_thres
  178. inds_second = np.logical_and(inds_low, inds_high)
  179. dets_second = dets[inds_second]
  180. dets = dets[remain_inds]
  181. detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  182. tlbrs in dets[:, :5]]
  183. else:
  184. detections = []
  185. dets_second = []
  186. id_feature_second = []
  187. ''' Add newly detected tracklets to tracked_stracks'''
  188. unconfirmed = []
  189. tracked_stracks = [] # type: list[STrack]
  190. for track in self.tracked_stracks:
  191. if not track.is_activated:
  192. unconfirmed.append(track)
  193. else:
  194. tracked_stracks.append(track)
  195. ''' Step 2: First association, with embedding'''
  196. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  197. # Predict the current location with KF
  198. STrack.multi_predict(strack_pool)
  199. dists = matching.iou_distance(strack_pool, detections)
  200. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.8)
  201. for itracked, idet in matches:
  202. track = strack_pool[itracked]
  203. det = detections[idet]
  204. if track.state == TrackState.Tracked:
  205. track.update(detections[idet], self.frame_id)
  206. activated_starcks.append(track)
  207. else:
  208. track.re_activate(det, self.frame_id, new_id=False)
  209. refind_stracks.append(track)
  210. # vis
  211. track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track = [],[],[],[],[]
  212. if self.opt.vis_state == 1 and self.frame_id % 20 == 0:
  213. if len(dets) != 0:
  214. for i in range(0, dets.shape[0]):
  215. bbox = dets[i][0:4]
  216. cv2.rectangle(img0, (int(bbox[0]), int(bbox[1])),(int(bbox[2]), int(bbox[3])),(0, 255, 0), 2)
  217. track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track = matching.vis_id_feature_A_distance(strack_pool, detections)
  218. vis_feature(self.frame_id,seq_num,img0,track_features,
  219. det_features, cost_matrix, cost_matrix_det, cost_matrix_track, max_num=5, out_path=save_dir)
  220. ''' Step 3: Second association, with IOU'''
  221. # association the untrack to the low score detections
  222. if len(dets_second) > 0:
  223. detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  224. tlbrs in dets_second[:, :5]]
  225. else:
  226. detections_second = []
  227. r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
  228. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  229. matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
  230. for itracked, idet in matches:
  231. track = r_tracked_stracks[itracked]
  232. det = detections_second[idet]
  233. if track.state == TrackState.Tracked:
  234. track.update(det, self.frame_id)
  235. activated_starcks.append(track)
  236. else:
  237. track.re_activate(det, self.frame_id, new_id=False)
  238. refind_stracks.append(track)
  239. for it in u_track:
  240. track = r_tracked_stracks[it]
  241. if not track.state == TrackState.Lost:
  242. track.mark_lost()
  243. lost_stracks.append(track)
  244. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  245. detections = [detections[i] for i in u_detection]
  246. dists = matching.iou_distance(unconfirmed, detections)
  247. matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
  248. for itracked, idet in matches:
  249. unconfirmed[itracked].update(detections[idet], self.frame_id)
  250. activated_starcks.append(unconfirmed[itracked])
  251. for it in u_unconfirmed:
  252. track = unconfirmed[it]
  253. track.mark_removed()
  254. removed_stracks.append(track)
  255. """ Step 4: Init new stracks"""
  256. for inew in u_detection:
  257. track = detections[inew]
  258. if track.score < self.high_thres:
  259. continue
  260. track.activate(self.kalman_filter, self.frame_id)
  261. activated_starcks.append(track)
  262. """ Step 5: Update state"""
  263. for track in self.lost_stracks:
  264. if self.frame_id - track.end_frame > self.max_time_lost:
  265. track.mark_removed()
  266. removed_stracks.append(track)
  267. # print('Ramained match {} s'.format(t4-t3))
  268. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  269. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  270. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  271. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  272. self.lost_stracks.extend(lost_stracks)
  273. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  274. self.removed_stracks.extend(removed_stracks)
  275. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
  276. # get scores of lost tracks
  277. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  278. logger.debug('===========Frame {}=========='.format(self.frame_id))
  279. logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
  280. logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
  281. logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
  282. logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
  283. return output_stracks
  284. def joint_stracks(tlista, tlistb):
  285. exists = {}
  286. res = []
  287. for t in tlista:
  288. exists[t.track_id] = 1
  289. res.append(t)
  290. for t in tlistb:
  291. tid = t.track_id
  292. if not exists.get(tid, 0):
  293. exists[tid] = 1
  294. res.append(t)
  295. return res
  296. def sub_stracks(tlista, tlistb):
  297. stracks = {}
  298. for t in tlista:
  299. stracks[t.track_id] = t
  300. for t in tlistb:
  301. tid = t.track_id
  302. if stracks.get(tid, 0):
  303. del stracks[tid]
  304. return list(stracks.values())
  305. def remove_duplicate_stracks(stracksa, stracksb):
  306. pdist = matching.iou_distance(stracksa, stracksb)
  307. pairs = np.where(pdist < 0.15)
  308. dupa, dupb = list(), list()
  309. for p, q in zip(*pairs):
  310. timep = stracksa[p].frame_id - stracksa[p].start_frame
  311. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  312. if timep > timeq:
  313. dupb.append(q)
  314. else:
  315. dupa.append(p)
  316. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  317. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  318. return resa, resb
  319. def vis_feature(frame_id,seq_num,img,track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track,max_num=5, out_path='/home/XX/'):
  320. num_zero = ["0000","000","00","0"]
  321. img = cv2.resize(img, (778, 435))
  322. if len(det_features) != 0:
  323. max_f = det_features.max()
  324. min_f = det_features.min()
  325. det_features = np.round((det_features - min_f) / (max_f - min_f) * 255)
  326. det_features = det_features.astype(np.uint8)
  327. d_F_M = []
  328. cutpff_line = [40]*512
  329. for d_f in det_features:
  330. for row in range(45):
  331. d_F_M += [[40]*3+d_f.tolist()+[40]*3]
  332. for row in range(3):
  333. d_F_M += [[40]*3+cutpff_line+[40]*3]
  334. d_F_M = np.array(d_F_M)
  335. d_F_M = d_F_M.astype(np.uint8)
  336. det_features_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  337. feature_img2 = cv2.resize(det_features_img, (435, 435))
  338. #cv2.putText(feature_img2, "det_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  339. else:
  340. feature_img2 = np.zeros((435, 435))
  341. feature_img2 = feature_img2.astype(np.uint8)
  342. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  343. #cv2.putText(feature_img2, "det_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  344. feature_img = np.concatenate((img, feature_img2), axis=1)
  345. if len(cost_matrix_det) != 0 and len(cost_matrix_det[0]) != 0:
  346. max_f = cost_matrix_det.max()
  347. min_f = cost_matrix_det.min()
  348. cost_matrix_det = np.round((cost_matrix_det - min_f) / (max_f - min_f) * 255)
  349. d_F_M = []
  350. cutpff_line = [40]*len(cost_matrix_det)*10
  351. for c_m in cost_matrix_det:
  352. add = []
  353. for row in range(len(c_m)):
  354. add += [255-c_m[row]]*10
  355. for row in range(10):
  356. d_F_M += [[40]+add+[40]]
  357. d_F_M = np.array(d_F_M)
  358. d_F_M = d_F_M.astype(np.uint8)
  359. cost_matrix_det_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  360. feature_img2 = cv2.resize(cost_matrix_det_img, (435, 435))
  361. #cv2.putText(feature_img2, "cost_matrix_det", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  362. else:
  363. feature_img2 = np.zeros((435, 435))
  364. feature_img2 = feature_img2.astype(np.uint8)
  365. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  366. #cv2.putText(feature_img2, "cost_matrix_det", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  367. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  368. if len(track_features) != 0:
  369. max_f = track_features.max()
  370. min_f = track_features.min()
  371. track_features = np.round((track_features - min_f) / (max_f - min_f) * 255)
  372. track_features = track_features.astype(np.uint8)
  373. d_F_M = []
  374. cutpff_line = [40]*512
  375. for d_f in track_features:
  376. for row in range(45):
  377. d_F_M += [[40]*3+d_f.tolist()+[40]*3]
  378. for row in range(3):
  379. d_F_M += [[40]*3+cutpff_line+[40]*3]
  380. d_F_M = np.array(d_F_M)
  381. d_F_M = d_F_M.astype(np.uint8)
  382. track_features_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  383. feature_img2 = cv2.resize(track_features_img, (435, 435))
  384. #cv2.putText(feature_img2, "track_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  385. else:
  386. feature_img2 = np.zeros((435, 435))
  387. feature_img2 = feature_img2.astype(np.uint8)
  388. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  389. #cv2.putText(feature_img2, "track_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  390. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  391. if len(cost_matrix_track) != 0 and len(cost_matrix_track[0]) != 0:
  392. max_f = cost_matrix_track.max()
  393. min_f = cost_matrix_track.min()
  394. cost_matrix_track = np.round((cost_matrix_track - min_f) / (max_f - min_f) * 255)
  395. d_F_M = []
  396. cutpff_line = [40]*len(cost_matrix_track)*10
  397. for c_m in cost_matrix_track:
  398. add = []
  399. for row in range(len(c_m)):
  400. add += [255-c_m[row]]*10
  401. for row in range(10):
  402. d_F_M += [[40]+add+[40]]
  403. d_F_M = np.array(d_F_M)
  404. d_F_M = d_F_M.astype(np.uint8)
  405. cost_matrix_track_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  406. feature_img2 = cv2.resize(cost_matrix_track_img, (435, 435))
  407. #cv2.putText(feature_img2, "cost_matrix_track", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  408. else:
  409. feature_img2 = np.zeros((435, 435))
  410. feature_img2 = feature_img2.astype(np.uint8)
  411. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  412. #cv2.putText(feature_img2, "cost_matrix_track", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  413. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  414. if len(cost_matrix) != 0 and len(cost_matrix[0]) != 0:
  415. max_f = cost_matrix.max()
  416. min_f = cost_matrix.min()
  417. cost_matrix = np.round((cost_matrix - min_f) / (max_f - min_f) * 255)
  418. d_F_M = []
  419. cutpff_line = [40]*len(cost_matrix[0])*10
  420. for c_m in cost_matrix:
  421. add = []
  422. for row in range(len(c_m)):
  423. add += [255-c_m[row]]*10
  424. for row in range(10):
  425. d_F_M += [[40]+add+[40]]
  426. d_F_M = np.array(d_F_M)
  427. d_F_M = d_F_M.astype(np.uint8)
  428. cost_matrix_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  429. feature_img2 = cv2.resize(cost_matrix_img, (435, 435))
  430. #cv2.putText(feature_img2, "cost_matrix", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  431. else:
  432. feature_img2 = np.zeros((435, 435))
  433. feature_img2 = feature_img2.astype(np.uint8)
  434. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  435. #cv2.putText(feature_img2, "cost_matrix", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  436. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  437. dst_path = out_path + "/" + seq_num + "_" + num_zero[len(str(frame_id))-1] + str(frame_id) + '.png'
  438. cv2.imwrite(dst_path, feature_img)