Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

matching.py 3.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import cv2
  2. import numpy as np
  3. import lap
  4. from scipy.spatial.distance import cdist
  5. from cython_bbox import bbox_overlaps as bbox_ious
  6. from yolox.motdt_tracker import kalman_filter
  7. def _indices_to_matches(cost_matrix, indices, thresh):
  8. matched_cost = cost_matrix[tuple(zip(*indices))]
  9. matched_mask = (matched_cost <= thresh)
  10. matches = indices[matched_mask]
  11. unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
  12. unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
  13. return matches, unmatched_a, unmatched_b
  14. def linear_assignment(cost_matrix, thresh):
  15. if cost_matrix.size == 0:
  16. return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
  17. matches, unmatched_a, unmatched_b = [], [], []
  18. cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
  19. for ix, mx in enumerate(x):
  20. if mx >= 0:
  21. matches.append([ix, mx])
  22. unmatched_a = np.where(x < 0)[0]
  23. unmatched_b = np.where(y < 0)[0]
  24. matches = np.asarray(matches)
  25. return matches, unmatched_a, unmatched_b
  26. def ious(atlbrs, btlbrs):
  27. """
  28. Compute cost based on IoU
  29. :type atlbrs: list[tlbr] | np.ndarray
  30. :type atlbrs: list[tlbr] | np.ndarray
  31. :rtype ious np.ndarray
  32. """
  33. ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
  34. if ious.size == 0:
  35. return ious
  36. ious = bbox_ious(
  37. np.ascontiguousarray(atlbrs, dtype=np.float),
  38. np.ascontiguousarray(btlbrs, dtype=np.float)
  39. )
  40. return ious
  41. def iou_distance(atracks, btracks):
  42. """
  43. Compute cost based on IoU
  44. :type atracks: list[STrack]
  45. :type btracks: list[STrack]
  46. :rtype cost_matrix np.ndarray
  47. """
  48. atlbrs = [track.tlbr for track in atracks]
  49. btlbrs = [track.tlbr for track in btracks]
  50. _ious = ious(atlbrs, btlbrs)
  51. cost_matrix = 1 - _ious
  52. return cost_matrix
  53. def nearest_reid_distance(tracks, detections, metric='cosine'):
  54. """
  55. Compute cost based on ReID features
  56. :type tracks: list[STrack]
  57. :type detections: list[BaseTrack]
  58. :rtype cost_matrix np.ndarray
  59. """
  60. cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
  61. if cost_matrix.size == 0:
  62. return cost_matrix
  63. det_features = np.asarray([track.curr_feature for track in detections], dtype=np.float32)
  64. for i, track in enumerate(tracks):
  65. cost_matrix[i, :] = np.maximum(0.0, cdist(track.features, det_features, metric).min(axis=0))
  66. return cost_matrix
  67. def mean_reid_distance(tracks, detections, metric='cosine'):
  68. """
  69. Compute cost based on ReID features
  70. :type tracks: list[STrack]
  71. :type detections: list[BaseTrack]
  72. :type metric: str
  73. :rtype cost_matrix np.ndarray
  74. """
  75. cost_matrix = np.empty((len(tracks), len(detections)), dtype=np.float)
  76. if cost_matrix.size == 0:
  77. return cost_matrix
  78. track_features = np.asarray([track.curr_feature for track in tracks], dtype=np.float32)
  79. det_features = np.asarray([track.curr_feature for track in detections], dtype=np.float32)
  80. cost_matrix = cdist(track_features, det_features, metric)
  81. return cost_matrix
  82. def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
  83. if cost_matrix.size == 0:
  84. return cost_matrix
  85. gating_dim = 2 if only_position else 4
  86. gating_threshold = kalman_filter.chi2inv95[gating_dim]
  87. measurements = np.asarray([det.to_xyah() for det in detections])
  88. for row, track in enumerate(tracks):
  89. gating_distance = kf.gating_distance(
  90. track.mean, track.covariance, measurements, only_position)
  91. cost_matrix[row, gating_distance > gating_threshold] = np.inf
  92. return cost_matrix