Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

linear_assignment.py 7.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  from __future__ import absolute_import

  import numpy as np
  # Legacy import, superseded by the scipy solver below:
  # from sklearn.utils.linear_assignment_ import linear_assignment
  from scipy.optimize import linear_sum_assignment as linear_assignment

  from yolox.deepsort_tracker import kalman_filter
  # Cost written into infeasible track/detection entries; it is the default
  # ``gated_cost`` of gate_cost_matrix. Presumably chosen to exceed any real
  # association cost -- confirm against the metrics in use.
  INFTY_COST = 1e+5
  def min_cost_matching(
          distance_metric, max_distance, tracks, detections, track_indices=None,
          detection_indices=None):
      """Solve the linear assignment problem between tracks and detections.

      Parameters
      ----------
      distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
          The distance metric is given a list of tracks and detections as well as
          a list of N track indices and M detection indices. The metric should
          return the NxM dimensional cost matrix, where element (i, j) is the
          association cost between the i-th track in the given track indices and
          the j-th detection in the given detection_indices.
      max_distance : float
          Gating threshold. Associations with cost larger than this value are
          disregarded.
      tracks : List[track.Track]
          A list of predicted tracks at the current time step.
      detections : List[detection.Detection]
          A list of detections at the current time step.
      track_indices : Optional[List[int]]
          List of track indices that maps rows in `cost_matrix` to tracks in
          `tracks` (see description above). Defaults to all tracks.
      detection_indices : Optional[List[int]]
          List of detection indices that maps columns in `cost_matrix` to
          detections in `detections` (see description above). Defaults to all
          detections.

      Returns
      -------
      (List[(int, int)], List[int], List[int])
          Returns a tuple with the following three entries:
          * A list of matched (track index, detection index) pairs.
          * A list of unmatched track indices.
          * A list of unmatched detection indices.
          All three are always new lists; the input index sequences are never
          returned by reference.
      """
      if track_indices is None:
          track_indices = np.arange(len(tracks))
      if detection_indices is None:
          detection_indices = np.arange(len(detections))

      if len(detection_indices) == 0 or len(track_indices) == 0:
          # Nothing to match. Copy into fresh lists so the result type is
          # always a list and callers never alias the sequences they passed in.
          return [], list(track_indices), list(detection_indices)

      # float64 guarantees that `max_distance + 1e-5` below is stored as a value
      # strictly greater than max_distance. In an integer (or, for large
      # thresholds, float32) matrix it would be truncated/rounded back down to
      # max_distance and the infeasible pair would then be accepted as a match.
      cost_matrix = np.asarray(
          distance_metric(tracks, detections, track_indices, detection_indices),
          dtype=np.float64)
      # Give every over-threshold pair the same cost just above the gate, so
      # infeasible pairs are indistinguishable to the solver; any it still
      # selects are rejected in the loop at the end.
      cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
      row_indices, col_indices = linear_assignment(cost_matrix)

      # Sets give O(1) membership tests; `x in ndarray` scans the whole array.
      assigned_rows = set(row_indices.tolist())
      assigned_cols = set(col_indices.tolist())

      matches = []
      unmatched_detections = [
          detection_idx for col, detection_idx in enumerate(detection_indices)
          if col not in assigned_cols]
      unmatched_tracks = [
          track_idx for row, track_idx in enumerate(track_indices)
          if row not in assigned_rows]

      for row, col in zip(row_indices, col_indices):
          track_idx = track_indices[row]
          detection_idx = detection_indices[col]
          if cost_matrix[row, col] > max_distance:
              # linear_sum_assignment pairs up every row/column of the smaller
              # dimension, so over-the-gate pairs can show up here; reject them.
              unmatched_tracks.append(track_idx)
              unmatched_detections.append(detection_idx)
          else:
              matches.append((track_idx, detection_idx))
      return matches, unmatched_tracks, unmatched_detections
  def matching_cascade(
          distance_metric, max_distance, cascade_depth, tracks, detections,
          track_indices=None, detection_indices=None):
      """Run matching cascade.

      At cascade level ``level`` (0 .. cascade_depth-1) only the tracks whose
      ``time_since_update == level + 1`` are matched, via min_cost_matching,
      against the detections that earlier levels left unmatched. Tracks whose
      ``time_since_update`` falls outside 1..cascade_depth are never matched
      here and are reported as unmatched.

      Parameters
      ----------
      distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
          The distance metric is given a list of tracks and detections as well as
          a list of N track indices and M detection indices. The metric should
          return the NxM dimensional cost matrix, where element (i, j) is the
          association cost between the i-th track in the given track indices and
          the j-th detection in the given detection indices.
      max_distance : float
          Gating threshold. Associations with cost larger than this value are
          disregarded.
      cascade_depth : int
          The cascade depth, should be set to the maximum track age.
      tracks : List[track.Track]
          A list of predicted tracks at the current time step.
      detections : List[detection.Detection]
          A list of detections at the current time step.
      track_indices : Optional[List[int]]
          List of track indices that maps rows in `cost_matrix` to tracks in
          `tracks` (see description above). Defaults to all tracks.
      detection_indices : Optional[List[int]]
          List of detection indices that maps columns in `cost_matrix` to
          detections in `detections` (see description above). Defaults to all
          detections.

      Returns
      -------
      (List[(int, int)], List[int], List[int])
          Returns a tuple with the following three entries:
          * A list of matched (track index, detection index) pairs.
          * A list of unmatched track indices, in the order of `track_indices`.
          * A list of unmatched detection indices.
      """
      if track_indices is None:
          track_indices = list(range(len(tracks)))
      if detection_indices is None:
          detection_indices = list(range(len(detections)))

      # Work on a copy so the caller's list is never handed back as our result
      # (it would be when no level runs, e.g. cascade_depth == 0).
      unmatched_detections = list(detection_indices)
      matches = []
      for level in range(cascade_depth):
          if len(unmatched_detections) == 0:  # No detections left
              break

          track_indices_l = [
              k for k in track_indices
              if tracks[k].time_since_update == 1 + level
          ]
          if len(track_indices_l) == 0:  # Nothing to match at this level
              continue

          matches_l, _, unmatched_detections = min_cost_matching(
              distance_metric, max_distance, tracks, detections,
              track_indices_l, unmatched_detections)
          matches += matches_l

      # Keep the input order (a plain set difference yields an arbitrary order)
      # while still dropping duplicate indices, as the set difference did.
      matched_tracks = {k for k, _ in matches}
      unmatched_tracks = list(dict.fromkeys(
          k for k in track_indices if k not in matched_tracks))
      return matches, unmatched_tracks, unmatched_detections
  def gate_cost_matrix(
          kf, cost_matrix, tracks, detections, track_indices, detection_indices,
          gated_cost=INFTY_COST, only_position=False):
      """Invalidate infeasible entries in cost matrix based on the state
      distributions obtained by Kalman filtering.

      The matrix is modified in place and also returned.

      Parameters
      ----------
      kf : The Kalman filter.
          Must provide ``gating_distance(mean, covariance, measurements,
          only_position)`` returning one distance per measurement.
      cost_matrix : ndarray
          The NxM dimensional cost matrix, where N is the number of track indices
          and M is the number of detection indices, such that entry (i, j) is the
          association cost between `tracks[track_indices[i]]` and
          `detections[detection_indices[j]]`.
      tracks : List[track.Track]
          A list of predicted tracks at the current time step.
      detections : List[detection.Detection]
          A list of detections at the current time step.
      track_indices : List[int]
          List of track indices that maps rows in `cost_matrix` to tracks in
          `tracks` (see description above).
      detection_indices : List[int]
          List of detection indices that maps columns in `cost_matrix` to
          detections in `detections` (see description above).
      gated_cost : Optional[float]
          Entries in the cost matrix corresponding to infeasible associations are
          set to this value. Defaults to INFTY_COST (1e+5).
      only_position : Optional[bool]
          If True, only the x, y position of the state distribution is considered
          during gating. Defaults to False.

      Returns
      -------
      ndarray
          Returns the modified cost matrix.
      """
      if len(track_indices) == 0 or len(detection_indices) == 0:
          # Nothing to gate. This also avoids passing kf.gating_distance an empty
          # (0,)-shaped measurement array, which it may not handle.
          return cost_matrix

      # Presumably kalman_filter.chi2inv95 maps degrees of freedom to the 95%
      # chi-square quantile (2 = position only, 4 = full measurement) -- confirm.
      gating_dim = 2 if only_position else 4
      gating_threshold = kalman_filter.chi2inv95[gating_dim]

      # Measurements are built once and reused for every track.
      measurements = np.asarray(
          [detections[i].to_xyah() for i in detection_indices])
      for row, track_idx in enumerate(track_indices):
          track = tracks[track_idx]
          gating_distance = kf.gating_distance(
              track.mean, track.covariance, measurements, only_position)
          # Boolean mask over this row's columns: gate out detections that are
          # too far from the track's predicted state.
          cost_matrix[row, gating_distance > gating_threshold] = gated_cost
      return cost_matrix