Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

deepsort.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import numpy as np
  2. import torch
  3. import cv2
  4. import os
  5. from .reid_model import Extractor
  6. from yolox.deepsort_tracker import kalman_filter, linear_assignment, iou_matching
  7. from yolox.data.dataloading import get_yolox_datadir
  8. from .detection import Detection
  9. from .track import Track
  10. def _cosine_distance(a, b, data_is_normalized=False):
  11. if not data_is_normalized:
  12. a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
  13. b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
  14. return 1. - np.dot(a, b.T)
  def _nn_cosine_distance(x, y):
      """Nearest-neighbour cosine distance.

      Parameters
      ----------
      x : array_like
          Stored samples, one per row.
      y : array_like
          Query vectors, one per row.

      Returns
      -------
      ndarray
          For each row of ``y``, the smallest cosine distance to any row of ``x``.
      """
      pairwise = _cosine_distance(x, y)
      nearest = np.min(pairwise, axis=0)
      return nearest
  class Tracker:
      """Multi-target tracker: Kalman prediction plus appearance/IoU association.

      Parameters
      ----------
      metric : NearestNeighborDistanceMetric
          Appearance metric. This class uses its ``distance``,
          ``partial_fit`` and ``matching_threshold`` members.
      max_iou_distance : float
          Threshold handed to the IoU-based ``min_cost_matching`` stage.
      max_age : int
          Handed to ``matching_cascade`` and to every new ``Track``
          (presumably the number of missed frames before deletion — see Track).
      n_init : int
          Handed to every new ``Track`` (presumably the number of hits needed
          before a track is confirmed — see Track).
      """

      def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
          self.metric = metric
          self.max_iou_distance = max_iou_distance
          self.max_age = max_age
          self.n_init = n_init

          # One filter instance is shared by every track.
          self.kf = kalman_filter.KalmanFilter()
          self.tracks = []
          # Identifier given to the next track that gets created.
          self._next_id = 1

      def predict(self):
          """Propagate every track's state distribution one time step forward.

          Call this once per time step, before ``update``.
          """
          kf = self.kf
          for trk in self.tracks:
              trk.predict(kf)

      def increment_ages(self):
          """Age every track by one step and mark it missed, without matching."""
          for trk in self.tracks:
              trk.increment_age()
              trk.mark_missed()

      def update(self, detections, classes):
          """Associate detections with tracks, then manage the track set.

          Parameters
          ----------
          detections : List[deep_sort.detection.Detection]
              Detections for the current time step.
          classes : sequence
              One class id per detection, aligned with ``detections``;
              ``.item()`` is called on the entry of every detection that starts
              a new track (so entries are expected to be NumPy scalars).
          """
          matched, lost_track_ids, fresh_det_ids = self._match(detections)

          for t_idx, d_idx in matched:
              self.tracks[t_idx].update(self.kf, detections[d_idx])
          for t_idx in lost_track_ids:
              self.tracks[t_idx].mark_missed()
          for d_idx in fresh_det_ids:
              self._initiate_track(detections[d_idx], classes[d_idx].item())
          self.tracks = [trk for trk in self.tracks if not trk.is_deleted()]

          # Feed the appearance features gathered by confirmed tracks into the
          # metric; every other target id is dropped from its gallery.
          confirmed = [trk for trk in self.tracks if trk.is_confirmed()]
          active_targets = [trk.track_id for trk in confirmed]
          gallery_feats = []
          gallery_ids = []
          for trk in confirmed:
              gallery_feats.extend(trk.features)
              gallery_ids.extend([trk.track_id] * len(trk.features))
              trk.features = []
          self.metric.partial_fit(
              np.asarray(gallery_feats), np.asarray(gallery_ids), active_targets)

      def _match(self, detections):
          """Two-stage association.

          Stage 1 matches confirmed tracks by gated appearance distance through
          the matching cascade. Stage 2 matches unconfirmed tracks, plus
          cascade leftovers whose ``time_since_update == 1``, by IoU.

          Returns
          -------
          (matches, unmatched_tracks, unmatched_detections)
          """

          def gated_metric(tracks, dets, track_indices, detection_indices):
              query = np.array([dets[i].feature for i in detection_indices])
              target_ids = np.array([tracks[i].track_id for i in track_indices])
              cost = self.metric.distance(query, target_ids)
              return linear_assignment.gate_cost_matrix(
                  self.kf, cost, tracks, dets, track_indices, detection_indices)

          confirmed_tracks = []
          unconfirmed_tracks = []
          for idx, trk in enumerate(self.tracks):
              if trk.is_confirmed():
                  confirmed_tracks.append(idx)
              else:
                  unconfirmed_tracks.append(idx)

          # Stage 1: appearance-based cascade over confirmed tracks.
          matches_a, unmatched_tracks_a, unmatched_detections = \
              linear_assignment.matching_cascade(
                  gated_metric, self.metric.matching_threshold, self.max_age,
                  self.tracks, detections, confirmed_tracks)

          # Split cascade leftovers: tracks updated on the previous step get a
          # second chance through IoU, the rest stay unmatched.
          recent = []
          stale = []
          for k in unmatched_tracks_a:
              if self.tracks[k].time_since_update == 1:
                  recent.append(k)
              else:
                  stale.append(k)
          iou_track_candidates = unconfirmed_tracks + recent

          # Stage 2: IoU matching over the remaining candidates.
          matches_b, unmatched_tracks_b, unmatched_detections = \
              linear_assignment.min_cost_matching(
                  iou_matching.iou_cost, self.max_iou_distance, self.tracks,
                  detections, iou_track_candidates, unmatched_detections)

          matches = matches_a + matches_b
          unmatched_tracks = list(set(stale + unmatched_tracks_b))
          return matches, unmatched_tracks, unmatched_detections

      def _initiate_track(self, detection, class_id):
          """Start a new track from ``detection`` and advance the id counter."""
          mean, covariance = self.kf.initiate(detection.to_xyah())
          new_track = Track(
              mean, covariance, self._next_id, class_id, self.n_init,
              self.max_age, detection.feature)
          self.tracks.append(new_track)
          self._next_id += 1
  class NearestNeighborDistanceMetric(object):
      """Nearest-neighbour appearance metric over a per-target feature gallery.

      For every target id the metric keeps the features passed to
      ``partial_fit`` (only the most recent ``budget`` of them when a budget is
      set). ``distance`` reports, per query feature, the smallest distance to
      any stored sample of each target.

      Parameters
      ----------
      metric : str
          ``"cosine"`` for cosine distance or ``"euclidean"`` for squared
          Euclidean distance.
      matching_threshold : float
          Stored as ``self.matching_threshold``; not used inside this class,
          callers read it as the association gate.
      budget : int or None
          If not None, keep at most the last ``budget`` samples per target.

      Raises
      ------
      ValueError
          If ``metric`` is neither ``"euclidean"`` nor ``"cosine"``.
      """

      def __init__(self, metric, matching_threshold, budget=None):
          if metric == "euclidean":
              self._metric = self._nn_euclidean_distance
          elif metric == "cosine":
              self._metric = _nn_cosine_distance
          else:
              raise ValueError(
                  "Invalid metric; must be either 'euclidean' or 'cosine'")
          self.matching_threshold = matching_threshold
          self.budget = budget
          # target id -> list of stored feature vectors
          self.samples = {}

      @staticmethod
      def _nn_euclidean_distance(x, y):
          """Per row of ``y``, the smallest squared Euclidean distance to a row of ``x``.

          Raises ValueError (from the ``min`` reduction) when ``x`` is empty,
          matching the behaviour of the cosine variant.
          """
          x, y = np.asarray(x), np.asarray(y)
          if len(x) == 0 or len(y) == 0:
              sq_dist = np.zeros((len(x), len(y)))
          else:
              x_sq = np.square(x).sum(axis=1)
              y_sq = np.square(y).sum(axis=1)
              sq_dist = -2.0 * np.dot(x, y.T) + x_sq[:, None] + y_sq[None, :]
              # The expanded form can dip slightly below zero from round-off.
              sq_dist = np.clip(sq_dist, 0.0, None)
          return sq_dist.min(axis=0)

      def partial_fit(self, features, targets, active_targets):
          """Add new samples to the gallery and drop inactive targets.

          Parameters
          ----------
          features : array_like
              Feature vectors, one per row.
          targets : array_like
              Target id for each row of ``features``.
          active_targets : iterable
              Ids to keep. Every other id is removed from the gallery; each id
              listed here must already have samples (KeyError otherwise).
          """
          for feature, target in zip(features, targets):
              self.samples.setdefault(target, []).append(feature)
              if self.budget is not None:
                  self.samples[target] = self.samples[target][-self.budget:]
          self.samples = {k: self.samples[k] for k in active_targets}

      def distance(self, features, targets):
          """Return a ``len(targets) x len(features)`` cost matrix.

          Element (i, j) is the smallest distance between ``features[j]`` and
          any stored sample of ``targets[i]``.
          """
          cost_matrix = np.zeros((len(targets), len(features)))
          for i, target in enumerate(targets):
              cost_matrix[i, :] = self._metric(self.samples[target], features)
          return cost_matrix
  class DeepSort(object):
      """DeepSORT front end: detections + frame -> tracked boxes with ids.

      ``update`` reads the frame from disk, rescales and thresholds the raw
      detections, extracts appearance features for each remaining box with
      ``Extractor``, and runs them through a ``Tracker``.
      """

      def __init__(self, model_path, max_dist=0.1, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=30, n_init=3, nn_budget=100, use_cuda=True):
          self.min_confidence = min_confidence
          # NOTE(review): stored for API compatibility but never used; update()
          # performs no non-maximum suppression.
          self.nms_max_overlap = nms_max_overlap

          self.extractor = Extractor(model_path, use_cuda=use_cuda)

          max_cosine_distance = max_dist
          metric = NearestNeighborDistanceMetric(
              "cosine", max_cosine_distance, nn_budget)
          self.tracker = Tracker(
              metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)

      def update(self, output_results, img_info, img_size, img_file_name):
          """Run one tracking step for a single frame.

          Parameters
          ----------
          output_results : torch.Tensor
              Detections, one per row. Columns 0-3 are x1, y1, x2, y2 in
              network-input coordinates. Columns 4 and 5 are multiplied into
              the detection confidence (presumably objectness and class
              score — confirm against the detector).
          img_info : sequence
              ``img_info[0]`` / ``img_info[1]`` are the original image height /
              width.
          img_size : sequence
              Network input size; ``img_size[0]`` is compared with the height
              and ``img_size[1]`` with the width.
          img_file_name : str
              Path relative to ``<yolox datadir>/mot/train`` of the frame.

          Returns
          -------
          ndarray or list
              ``(K, 6)`` integer array of ``[x1, y1, x2, y2, track_id, class_id]``
              for confirmed tracks updated this step, or an empty list when
              there are none.

          Raises
          ------
          FileNotFoundError
              If the frame cannot be read.
          """
          img_file_name = os.path.join(get_yolox_datadir(), 'mot', 'train', img_file_name)
          ori_img = cv2.imread(img_file_name)
          if ori_img is None:
              # cv2.imread signals failure by returning None instead of raising.
              raise FileNotFoundError("could not read image: {}".format(img_file_name))
          # Used by _tlwh_to_xyxy / _xywh_to_xyxy for clipping.
          self.height, self.width = ori_img.shape[:2]

          # Post-process detections.
          output_results = output_results.cpu().numpy()
          confidences = output_results[:, 4] * output_results[:, 5]
          img_h, img_w = img_info[0], img_info[1]
          scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))
          # Out-of-place on purpose: for a CPU tensor, .cpu().numpy() shares
          # memory with the caller's tensor, so an in-place ``/=`` would also
          # rescale the caller's detections.
          bbox_xyxy = output_results[:, :4] / scale
          bbox_tlwh = self._xyxy_to_tlwh_array(bbox_xyxy)

          remain_inds = confidences > self.min_confidence
          bbox_tlwh = bbox_tlwh[remain_inds]
          confidences = confidences[remain_inds]

          # Build detections; every remaining confidence already passed the threshold.
          features = self._get_features(bbox_tlwh, ori_img)
          detections = [Detection(bbox_tlwh[i], conf, features[i])
                        for i, conf in enumerate(confidences)]
          classes = np.zeros((len(detections), ))

          # Update tracker.
          self.tracker.predict()
          self.tracker.update(detections, classes)

          # Report confirmed tracks that were matched on this step.
          outputs = []
          for track in self.tracker.tracks:
              if not track.is_confirmed() or track.time_since_update > 1:
                  continue
              x1, y1, x2, y2 = self._tlwh_to_xyxy_noclip(track.to_tlwh())
              # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
              outputs.append(np.array(
                  [x1, y1, x2, y2, track.track_id, track.class_id], dtype=int))
          if len(outputs) > 0:
              outputs = np.stack(outputs, axis=0)
          return outputs

      @staticmethod
      def _xywh_to_tlwh(bbox_xywh):
          """Convert boxes from (center x, center y, w, h) to (top-left x, top-left y, w, h).

          An ndarray or torch.Tensor input is copied, not modified. Any other
          array-like input is converted to a float ndarray first.
          """
          if isinstance(bbox_xywh, np.ndarray):
              bbox_tlwh = bbox_xywh.copy()
          elif isinstance(bbox_xywh, torch.Tensor):
              bbox_tlwh = bbox_xywh.clone()
          else:
              bbox_xywh = np.asarray(bbox_xywh, dtype=float)
              bbox_tlwh = bbox_xywh.copy()
          bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
          bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
          return bbox_tlwh

      @staticmethod
      def _xyxy_to_tlwh_array(bbox_xyxy):
          """Convert boxes from (x1, y1, x2, y2) to (x1, y1, w, h), one box per row.

          An ndarray or torch.Tensor input is copied, not modified. Any other
          array-like input is converted to a float ndarray first.
          """
          if isinstance(bbox_xyxy, np.ndarray):
              bbox_tlwh = bbox_xyxy.copy()
          elif isinstance(bbox_xyxy, torch.Tensor):
              bbox_tlwh = bbox_xyxy.clone()
          else:
              bbox_xyxy = np.asarray(bbox_xyxy, dtype=float)
              bbox_tlwh = bbox_xyxy.copy()
          bbox_tlwh[:, 2] = bbox_xyxy[:, 2] - bbox_xyxy[:, 0]
          bbox_tlwh[:, 3] = bbox_xyxy[:, 3] - bbox_xyxy[:, 1]
          return bbox_tlwh

      def _xywh_to_xyxy(self, bbox_xywh):
          """Convert one (center x, center y, w, h) box to integer (x1, y1, x2, y2).

          The result is clipped to the image size recorded by the last ``update``.
          """
          x, y, w, h = bbox_xywh
          x1 = max(int(x - w / 2), 0)
          x2 = min(int(x + w / 2), self.width - 1)
          y1 = max(int(y - h / 2), 0)
          y2 = min(int(y + h / 2), self.height - 1)
          return x1, y1, x2, y2

      def _tlwh_to_xyxy(self, bbox_tlwh):
          """Convert one (top-left x, top-left y, w, h) box to integer (x1, y1, x2, y2).

          The result is clipped to the image size recorded by the last ``update``.
          """
          x, y, w, h = bbox_tlwh
          x1 = max(int(x), 0)
          x2 = min(int(x + w), self.width - 1)
          y1 = max(int(y), 0)
          y2 = min(int(y + h), self.height - 1)
          return x1, y1, x2, y2

      def _tlwh_to_xyxy_noclip(self, bbox_tlwh):
          """Convert one (top-left x, top-left y, w, h) box to (x1, y1, x2, y2).

          There is no clipping and no integer conversion.
          """
          x, y, w, h = bbox_tlwh
          return x, y, x + w, y + h

      def increment_ages(self):
          """Age every track and mark it missed, without running association."""
          self.tracker.increment_ages()

      def _xyxy_to_tlwh(self, bbox_xyxy):
          """Convert one (x1, y1, x2, y2) box to (x1, y1, w, h), truncating w and h to int."""
          x1, y1, x2, y2 = bbox_xyxy
          w = int(x2 - x1)
          h = int(y2 - y1)
          return x1, y1, w, h

      def _get_features(self, bbox_xywh, ori_img):
          """Extract appearance features for each box from ``ori_img``.

          Despite its name, ``bbox_xywh`` holds top-left/width/height boxes:
          each one is converted with ``_tlwh_to_xyxy`` (clipped) before
          cropping. Returns the extractor output, or an empty array when there
          are no boxes.

          NOTE(review): a box lying entirely outside the image produces an
          empty crop; whether ``Extractor`` tolerates that is not visible here.
          """
          im_crops = []
          for box in bbox_xywh:
              x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
              im_crops.append(ori_img[y1:y2, x1:x2])
          if im_crops:
              features = self.extractor(im_crops)
          else:
              features = np.array([])
          return features