Meta Byte Track

tracker.py 17KB

import numpy as np
from collections import deque
import itertools
import os
import os.path as osp
import time
import torch
import cv2
import torch.nn.functional as F

from models.model import create_model, load_model
from models.decode import mot_decode
from tracking_utils.utils import *
from tracking_utils.log import logger
from tracking_utils.kalman_filter import KalmanFilter
from models import *
from tracker import matching
from .basetrack import BaseTrack, TrackState
from utils.post_process import ctdet_post_process
from utils.image import get_affine_transform
from models.utils import _tranpose_and_gather_feat
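

# A JDE-style (joint detection and embedding) multi-object tracker in the
# spirit of FairMOT, extended with a ByteTrack-style second association pass
# over low-score detections (see JDETracker.update below).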
class STrack(BaseTrack):
    shared_kalman = KalmanFilter()

    def __init__(self, tlwh, score, temp_feat, buffer_size=30):
        # Waits to be activated
        self._tlwh = np.asarray(tlwh, dtype=np.float64)  # np.float is removed in NumPy >= 1.24
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.score_list = []
        self.tracklet_len = 0

        # The feature buffer and EMA weight must exist before the first
        # update_features() call, which appends to self.features.
        self.smooth_feat = None
        self.features = deque([], maxlen=buffer_size)
        self.alpha = 0.9
        self.update_features(temp_feat)
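
    # Embedding maintenance: every matched detection contributes its re-ID
    # feature to an exponential moving average,
    #   smooth_feat = alpha * smooth_feat + (1 - alpha) * feat,  alpha = 0.9,
    # and the result is re-normalized so cosine distances stay well-scaled.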
    def update_features(self, feat):
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # For tracks that are not currently matched, zero the height
            # velocity (index 7 of the 8-dim Kalman state) so the box does
            # not keep growing or shrinking while the track is lost.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
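
    # Batched prediction: a single class-level Kalman filter advances all
    # tracks at once instead of looping over per-track predict() calls.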
    @staticmethod
    def multi_predict(stracks):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        if frame_id == 1:
            self.is_activated = True
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id
        self.score_list.append(self.score)

    def re_activate(self, new_track, frame_id, new_id=False):
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.update_features(new_track.curr_feat)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score
        self.score_list.append(self.score)

    def update(self, new_track, frame_id, update_feature=True):
        """
        Update a matched track.
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score
        self.score_list.append(self.score)
        if update_feature:
            self.update_features(new_track.curr_feat)
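
    # Box conventions used below: `tlwh` is (top-left x, top-left y, width,
    # height); `tlbr` is (min x, min y, max x, max y); `xyah` is (center x,
    # center y, aspect ratio w/h, height), the representation the Kalman
    # filter tracks in its first four state dimensions.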
    @property
    # @jit(nopython=True)
    def tlwh(self):
        """Get the current position in bounding-box format `(top left x,
        top left y, width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    # @jit(nopython=True)
    def tlbr(self):
        """Convert the bounding box to format `(min x, min y, max x, max y)`,
        i.e. `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_xyah(tlwh):
        """Convert the bounding box to format `(center x, center y, aspect
        ratio, height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    # @jit(nopython=True)
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


class JDETracker(object):
    def __init__(self, opt, frame_rate=30):
        self.opt = opt
        if opt.gpus[0] >= 0:
            opt.device = torch.device('cuda')
        else:
            opt.device = torch.device('cpu')
        print('Creating model...')
        self.model = create_model(opt.arch, opt.heads, opt.head_conv)
        self.model = load_model(self.model, opt.load_model)
        self.model = self.model.to(opt.device)
        self.model.eval()

        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        # New tracks are only started from detections scoring above this
        # stricter threshold; association itself uses opt.conf_thres.
        # self.det_thresh = opt.conf_thres
        self.det_thresh = opt.conf_thres + 0.1
        # Scale the lost-track buffer with the frame rate.
        self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size
        self.max_per_image = opt.K
        self.mean = np.array(opt.mean, dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array(opt.std, dtype=np.float32).reshape(1, 1, 3)

        self.kalman_filter = KalmanFilter()
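
    # Decoding pipeline: post_process() maps raw network detections back to
    # original-image coordinates, and merge_outputs() concatenates the
    # per-class results and keeps at most max_per_image boxes by score.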
    def post_process(self, dets, meta):
        dets = dets.detach().cpu().numpy()
        dets = dets.reshape(1, -1, dets.shape[2])
        dets = ctdet_post_process(
            dets.copy(), [meta['c']], [meta['s']],
            meta['out_height'], meta['out_width'], self.opt.num_classes)
        for j in range(1, self.opt.num_classes + 1):
            dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)
        return dets[0]

    def merge_outputs(self, detections):
        results = {}
        for j in range(1, self.opt.num_classes + 1):
            results[j] = np.concatenate(
                [detection[j] for detection in detections], axis=0).astype(np.float32)

        scores = np.hstack(
            [results[j][:, 4] for j in range(1, self.opt.num_classes + 1)])
        if len(scores) > self.max_per_image:
            # Find the score of the cut-off detection and drop everything
            # below it, keeping the top max_per_image boxes overall.
            kth = len(scores) - self.max_per_image
            thresh = np.partition(scores, kth)[kth]
            for j in range(1, self.opt.num_classes + 1):
                keep_inds = (results[j][:, 4] >= thresh)
                results[j] = results[j][keep_inds]
        return results
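
    # One tracking step. The association cascade below is:
    #   Step 1: run the network, decode boxes and re-ID embeddings, and split
    #           detections into high- and low-confidence sets;
    #   Step 2: match existing tracks to high-confidence detections by
    #           embedding distance fused with Kalman motion;
    #   Step 3: match leftovers by IoU, then match still-unmatched tracks to
    #           the low-confidence detections by IoU (ByteTrack-style);
    #   Step 4: start new tracks from strong unmatched detections;
    #   Step 5: age out lost tracks and reconcile the track lists.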
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {'c': c, 's': s,
                'out_height': inp_height // self.opt.down_ratio,
                'out_width': inp_width // self.opt.down_ratio}

        ''' Step 1: Network forward, get detections & embeddings '''
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]
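
        # ByteTrack-style two-threshold split: detections above conf_thres
        # drive the main association; detections scoring between 0.2 and
        # conf_thres are kept aside for a second, IoU-only association pass.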
        remain_inds = dets[:, 4] > self.opt.conf_thres
        inds_low = dets[:, 4] > 0.2
        # inds_low = dets[:, 4] > self.opt.conf_thres
        inds_high = dets[:, 4] < self.opt.conf_thres
        inds_second = np.logical_and(inds_low, inds_high)
        dets_second = dets[inds_second]
        id_feature_second = id_feature[inds_second]
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0 - 1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                          (tlbrs, f) in zip(dets[:, :5], id_feature)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks '''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with embedding '''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        # dists = matching.fuse_iou(dists, strack_pool, detections)
        # dists = matching.iou_distance(strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.opt.match_thres)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with IOU '''
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
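
        # Low-confidence pass: tracks still unmatched after the IoU pass are
        # compared against the low-score detections by IoU alone; re-ID
        # embeddings from weak detections tend to be unreliable, so
        # appearance is not used here.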
        # Associate the remaining tracks with the low-score detections
        if len(dets_second) > 0:
            '''Detections'''
            detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                                 (tlbrs, f) in zip(dets_second[:, :5], id_feature_second)]
        else:
            detections_second = []
        second_tracked_stracks = [r_tracked_stracks[i] for i in u_track if r_tracked_stracks[i].state == TrackState.Tracked]
        dists = matching.iou_distance(second_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
        for itracked, idet in matches:
            track = second_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            # track = r_tracked_stracks[it]
            track = second_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks """
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)
  339. """ Step 5: Update state"""
  340. for track in self.lost_stracks:
  341. if self.frame_id - track.end_frame > self.max_time_lost:
  342. track.mark_removed()
  343. removed_stracks.append(track)
  344. # print('Ramained match {} s'.format(t4-t3))
  345. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  346. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  347. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  348. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  349. self.lost_stracks.extend(lost_stracks)
  350. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  351. self.removed_stracks.extend(removed_stracks)
  352. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
  353. #self.tracked_stracks = remove_fp_stracks(self.tracked_stracks)
  354. # get scores of lost tracks
  355. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  356. logger.debug('===========Frame {}=========='.format(self.frame_id))
  357. logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
  358. logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
  359. logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
  360. logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
  361. return output_stracks
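

# Track-list bookkeeping helpers used by JDETracker.update: union without
# duplicate ids, set difference by id, and cross-list deduplication by IoU
# overlap (keeping whichever track has the longer history).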
def joint_stracks(tlista, tlistb):
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb


def remove_fp_stracks(stracksa, n_frame=10):
    """Drop tracks whose last n_frame scores are all below 0.45; such tracks
    are likely false positives."""
    remain = []
    for t in stracksa:
        scores = np.array(t.score_list[-n_frame:], dtype=np.float32)
        num_low = np.sum(scores < 0.45)
        if num_low < n_frame:
            remain.append(t)
    return remain
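
A minimal driver sketch, for orientation only and not part of tracker.py: it assumes an `opt` namespace exposing the fields the tracker actually reads (arch, heads, head_conv, load_model, gpus, conf_thres, track_buffer, K, mean, std, down_ratio, reg_offset, ltrb, match_thres, num_classes). The `preprocess` helper and the 1088x608 input size below are hypothetical stand-ins for the project's real pre-processing, which is not shown in this file.

    import cv2
    import torch
    import numpy as np

    def preprocess(img0, inp_w, inp_h, mean, std):
        # Hypothetical stand-in for the project's real pre-processing:
        # resize to the network input size, normalize, HWC -> 1xCxHxW.
        img = cv2.resize(img0, (inp_w, inp_h)).astype(np.float32) / 255.0
        img = (img - mean) / std
        img = img.transpose(2, 0, 1)[None]
        return torch.from_numpy(np.ascontiguousarray(img))

    tracker = JDETracker(opt, frame_rate=30)  # opt: see assumed fields above
    cap = cv2.VideoCapture('input.mp4')
    while True:
        ok, img0 = cap.read()
        if not ok:
            break
        im_blob = preprocess(img0, 1088, 608, tracker.mean, tracker.std).to(opt.device)
        for t in tracker.update(im_blob, img0):  # activated tracks only
            x, y, w, h = t.tlwh
            print(t.track_id, int(x), int(y), int(w), int(h), round(t.score, 2))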