Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

motdt_tracker.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. import numpy as np
  2. #from numba import jit
  3. from collections import OrderedDict, deque
  4. import itertools
  5. import os
  6. import cv2
  7. import torch
  8. from torch._C import dtype
  9. import torchvision
  10. from yolox.motdt_tracker import matching
  11. from .kalman_filter import KalmanFilter
  12. from .reid_model import load_reid_model, extract_reid_features
  13. from yolox.data.dataloading import get_yolox_datadir
  14. from .basetrack import BaseTrack, TrackState
  15. class STrack(BaseTrack):
  16. def __init__(self, tlwh, score, max_n_features=100, from_det=True):
  17. # wait activate
  18. self._tlwh = np.asarray(tlwh, dtype=np.float)
  19. self.kalman_filter = None
  20. self.mean, self.covariance = None, None
  21. self.is_activated = False
  22. self.score = score
  23. self.max_n_features = max_n_features
  24. self.curr_feature = None
  25. self.last_feature = None
  26. self.features = deque([], maxlen=self.max_n_features)
  27. # classification
  28. self.from_det = from_det
  29. self.tracklet_len = 0
  30. self.time_by_tracking = 0
  31. # self-tracking
  32. self.tracker = None
  33. def set_feature(self, feature):
  34. if feature is None:
  35. return False
  36. self.features.append(feature)
  37. self.curr_feature = feature
  38. self.last_feature = feature
  39. # self._p_feature = 0
  40. return True
  41. def predict(self):
  42. if self.time_since_update > 0:
  43. self.tracklet_len = 0
  44. self.time_since_update += 1
  45. mean_state = self.mean.copy()
  46. if self.state != TrackState.Tracked:
  47. mean_state[7] = 0
  48. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  49. if self.tracker:
  50. self.tracker.update_roi(self.tlwh)
  51. def self_tracking(self, image):
  52. tlwh = self.tracker.predict(image) if self.tracker else self.tlwh
  53. return tlwh
  54. def activate(self, kalman_filter, frame_id, image):
  55. """Start a new tracklet"""
  56. self.kalman_filter = kalman_filter # type: KalmanFilter
  57. self.track_id = self.next_id()
  58. # cx, cy, aspect_ratio, height, dx, dy, da, dh
  59. self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
  60. # self.tracker = sot.SingleObjectTracker()
  61. # self.tracker.init(image, self.tlwh)
  62. del self._tlwh
  63. self.time_since_update = 0
  64. self.time_by_tracking = 0
  65. self.tracklet_len = 0
  66. self.state = TrackState.Tracked
  67. # self.is_activated = True
  68. self.frame_id = frame_id
  69. self.start_frame = frame_id
  70. def re_activate(self, new_track, frame_id, image, new_id=False):
  71. # self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(new_track.tlwh))
  72. self.mean, self.covariance = self.kalman_filter.update(
  73. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
  74. )
  75. self.time_since_update = 0
  76. self.time_by_tracking = 0
  77. self.tracklet_len = 0
  78. self.state = TrackState.Tracked
  79. self.is_activated = True
  80. self.frame_id = frame_id
  81. if new_id:
  82. self.track_id = self.next_id()
  83. self.set_feature(new_track.curr_feature)
  84. def update(self, new_track, frame_id, image, update_feature=True):
  85. """
  86. Update a matched track
  87. :type new_track: STrack
  88. :type frame_id: int
  89. :type update_feature: bool
  90. :return:
  91. """
  92. self.frame_id = frame_id
  93. self.time_since_update = 0
  94. if new_track.from_det:
  95. self.time_by_tracking = 0
  96. else:
  97. self.time_by_tracking += 1
  98. self.tracklet_len += 1
  99. new_tlwh = new_track.tlwh
  100. self.mean, self.covariance = self.kalman_filter.update(
  101. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  102. self.state = TrackState.Tracked
  103. self.is_activated = True
  104. self.score = new_track.score
  105. if update_feature:
  106. self.set_feature(new_track.curr_feature)
  107. if self.tracker:
  108. self.tracker.update(image, self.tlwh)
  109. @property
  110. #@jit
  111. def tlwh(self):
  112. """Get current position in bounding box format `(top left x, top left y,
  113. width, height)`.
  114. """
  115. if self.mean is None:
  116. return self._tlwh.copy()
  117. ret = self.mean[:4].copy()
  118. ret[2] *= ret[3]
  119. ret[:2] -= ret[2:] / 2
  120. return ret
  121. @property
  122. #@jit
  123. def tlbr(self):
  124. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  125. `(top left, bottom right)`.
  126. """
  127. ret = self.tlwh.copy()
  128. ret[2:] += ret[:2]
  129. return ret
  130. @staticmethod
  131. #@jit
  132. def tlwh_to_xyah(tlwh):
  133. """Convert bounding box to format `(center x, center y, aspect ratio,
  134. height)`, where the aspect ratio is `width / height`.
  135. """
  136. ret = np.asarray(tlwh).copy()
  137. ret[:2] += ret[2:] / 2
  138. ret[2] /= ret[3]
  139. return ret
  140. def to_xyah(self):
  141. return self.tlwh_to_xyah(self.tlwh)
  142. def tracklet_score(self):
  143. # score = (1 - np.exp(-0.6 * self.hit_streak)) * np.exp(-0.03 * self.time_by_tracking)
  144. score = max(0, 1 - np.log(1 + 0.05 * self.time_by_tracking)) * (self.tracklet_len - self.time_by_tracking > 2)
  145. # score = max(0, 1 - np.log(1 + 0.05 * self.n_tracking)) * (1 - np.exp(-0.6 * self.hit_streak))
  146. return score
  147. def __repr__(self):
  148. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  149. class OnlineTracker(object):
  150. def __init__(self, model_folder, min_cls_score=0.4, min_ap_dist=0.8, max_time_lost=30, use_tracking=True, use_refind=True):
  151. self.min_cls_score = min_cls_score
  152. self.min_ap_dist = min_ap_dist
  153. self.max_time_lost = max_time_lost
  154. self.kalman_filter = KalmanFilter()
  155. self.tracked_stracks = [] # type: list[STrack]
  156. self.lost_stracks = [] # type: list[STrack]
  157. self.removed_stracks = [] # type: list[STrack]
  158. self.use_refind = use_refind
  159. self.use_tracking = use_tracking
  160. self.classifier = None
  161. self.reid_model = load_reid_model(model_folder)
  162. self.frame_id = 0
  163. def update(self, output_results, img_info, img_size, img_file_name):
  164. img_file_name = os.path.join(get_yolox_datadir(), 'mot', 'train', img_file_name)
  165. image = cv2.imread(img_file_name)
  166. # post process detections
  167. output_results = output_results.cpu().numpy()
  168. confidences = output_results[:, 4] * output_results[:, 5]
  169. bboxes = output_results[:, :4] # x1y1x2y2
  170. img_h, img_w = img_info[0], img_info[1]
  171. scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))
  172. bboxes /= scale
  173. bbox_xyxy = bboxes
  174. tlwhs = self._xyxy_to_tlwh_array(bbox_xyxy)
  175. remain_inds = confidences > self.min_cls_score
  176. tlwhs = tlwhs[remain_inds]
  177. det_scores = confidences[remain_inds]
  178. self.frame_id += 1
  179. activated_starcks = []
  180. refind_stracks = []
  181. lost_stracks = []
  182. removed_stracks = []
  183. """step 1: prediction"""
  184. for strack in itertools.chain(self.tracked_stracks, self.lost_stracks):
  185. strack.predict()
  186. """step 2: scoring and selection"""
  187. if det_scores is None:
  188. det_scores = np.ones(len(tlwhs), dtype=float)
  189. detections = [STrack(tlwh, score, from_det=True) for tlwh, score in zip(tlwhs, det_scores)]
  190. if self.use_tracking:
  191. tracks = [STrack(t.self_tracking(image), 0.6 * t.tracklet_score(), from_det=False)
  192. for t in itertools.chain(self.tracked_stracks, self.lost_stracks) if t.is_activated]
  193. detections.extend(tracks)
  194. rois = np.asarray([d.tlbr for d in detections], dtype=np.float32)
  195. scores = np.asarray([d.score for d in detections], dtype=np.float32)
  196. # nms
  197. if len(detections) > 0:
  198. nms_out_index = torchvision.ops.batched_nms(
  199. torch.from_numpy(rois),
  200. torch.from_numpy(scores.reshape(-1)).to(torch.from_numpy(rois).dtype),
  201. torch.zeros_like(torch.from_numpy(scores.reshape(-1))),
  202. 0.7,
  203. )
  204. keep = nms_out_index.numpy()
  205. mask = np.zeros(len(rois), dtype=np.bool)
  206. mask[keep] = True
  207. keep = np.where(mask & (scores >= self.min_cls_score))[0]
  208. detections = [detections[i] for i in keep]
  209. scores = scores[keep]
  210. for d, score in zip(detections, scores):
  211. d.score = score
  212. pred_dets = [d for d in detections if not d.from_det]
  213. detections = [d for d in detections if d.from_det]
  214. # set features
  215. tlbrs = [det.tlbr for det in detections]
  216. features = extract_reid_features(self.reid_model, image, tlbrs)
  217. features = features.cpu().numpy()
  218. for i, det in enumerate(detections):
  219. det.set_feature(features[i])
  220. """step 3: association for tracked"""
  221. # matching for tracked targets
  222. unconfirmed = []
  223. tracked_stracks = [] # type: list[STrack]
  224. for track in self.tracked_stracks:
  225. if not track.is_activated:
  226. unconfirmed.append(track)
  227. else:
  228. tracked_stracks.append(track)
  229. dists = matching.nearest_reid_distance(tracked_stracks, detections, metric='euclidean')
  230. dists = matching.gate_cost_matrix(self.kalman_filter, dists, tracked_stracks, detections)
  231. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.min_ap_dist)
  232. for itracked, idet in matches:
  233. tracked_stracks[itracked].update(detections[idet], self.frame_id, image)
  234. # matching for missing targets
  235. detections = [detections[i] for i in u_detection]
  236. dists = matching.nearest_reid_distance(self.lost_stracks, detections, metric='euclidean')
  237. dists = matching.gate_cost_matrix(self.kalman_filter, dists, self.lost_stracks, detections)
  238. matches, u_lost, u_detection = matching.linear_assignment(dists, thresh=self.min_ap_dist)
  239. for ilost, idet in matches:
  240. track = self.lost_stracks[ilost] # type: STrack
  241. det = detections[idet]
  242. track.re_activate(det, self.frame_id, image, new_id=not self.use_refind)
  243. refind_stracks.append(track)
  244. # remaining tracked
  245. # tracked
  246. len_det = len(u_detection)
  247. detections = [detections[i] for i in u_detection] + pred_dets
  248. r_tracked_stracks = [tracked_stracks[i] for i in u_track]
  249. dists = matching.iou_distance(r_tracked_stracks, detections)
  250. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
  251. for itracked, idet in matches:
  252. r_tracked_stracks[itracked].update(detections[idet], self.frame_id, image, update_feature=True)
  253. for it in u_track:
  254. track = r_tracked_stracks[it]
  255. track.mark_lost()
  256. lost_stracks.append(track)
  257. # unconfirmed
  258. detections = [detections[i] for i in u_detection if i < len_det]
  259. dists = matching.iou_distance(unconfirmed, detections)
  260. matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
  261. for itracked, idet in matches:
  262. unconfirmed[itracked].update(detections[idet], self.frame_id, image, update_feature=True)
  263. for it in u_unconfirmed:
  264. track = unconfirmed[it]
  265. track.mark_removed()
  266. removed_stracks.append(track)
  267. """step 4: init new stracks"""
  268. for inew in u_detection:
  269. track = detections[inew]
  270. if not track.from_det or track.score < 0.6:
  271. continue
  272. track.activate(self.kalman_filter, self.frame_id, image)
  273. activated_starcks.append(track)
  274. """step 6: update state"""
  275. for track in self.lost_stracks:
  276. if self.frame_id - track.end_frame > self.max_time_lost:
  277. track.mark_removed()
  278. removed_stracks.append(track)
  279. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  280. self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
  281. self.tracked_stracks.extend(activated_starcks)
  282. self.tracked_stracks.extend(refind_stracks)
  283. self.lost_stracks.extend(lost_stracks)
  284. self.removed_stracks.extend(removed_stracks)
  285. # output_stracks = self.tracked_stracks + self.lost_stracks
  286. # get scores of lost tracks
  287. output_tracked_stracks = [track for track in self.tracked_stracks if track.is_activated]
  288. output_stracks = output_tracked_stracks
  289. return output_stracks
  290. @staticmethod
  291. def _xyxy_to_tlwh_array(bbox_xyxy):
  292. if isinstance(bbox_xyxy, np.ndarray):
  293. bbox_tlwh = bbox_xyxy.copy()
  294. elif isinstance(bbox_xyxy, torch.Tensor):
  295. bbox_tlwh = bbox_xyxy.clone()
  296. bbox_tlwh[:, 2] = bbox_xyxy[:, 2] - bbox_xyxy[:, 0]
  297. bbox_tlwh[:, 3] = bbox_xyxy[:, 3] - bbox_xyxy[:, 1]
  298. return bbox_tlwh