Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

kalman_filter.py 7.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # vim: expandtab:ts=4:sw=4
  2. import numpy as np
  3. import scipy.linalg
  4. """
  5. Table for the 0.95 quantile of the chi-square distribution with N degrees of
  6. freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
  7. function and used as Mahalanobis gating threshold.
  8. """
  9. chi2inv95 = {
  10. 1: 3.8415,
  11. 2: 5.9915,
  12. 3: 7.8147,
  13. 4: 9.4877,
  14. 5: 11.070,
  15. 6: 12.592,
  16. 7: 14.067,
  17. 8: 15.507,
  18. 9: 16.919}
  19. class KalmanFilter(object):
  20. """
  21. A simple Kalman filter for tracking bounding boxes in image space.
  22. The 8-dimensional state space
  23. x, y, a, h, vx, vy, va, vh
  24. contains the bounding box center position (x, y), aspect ratio a, height h,
  25. and their respective velocities.
  26. Object motion follows a constant velocity model. The bounding box location
  27. (x, y, a, h) is taken as direct observation of the state space (linear
  28. observation model).
  29. """
  30. def __init__(self):
  31. ndim, dt = 4, 1.
  32. # Create Kalman filter model matrices.
  33. self._motion_mat = np.eye(2 * ndim, 2 * ndim)
  34. for i in range(ndim):
  35. self._motion_mat[i, ndim + i] = dt
  36. self._update_mat = np.eye(ndim, 2 * ndim)
  37. # Motion and observation uncertainty are chosen relative to the current
  38. # state estimate. These weights control the amount of uncertainty in
  39. # the model. This is a bit hacky.
  40. self._std_weight_position = 1. / 20
  41. self._std_weight_velocity = 1. / 160
  42. def initiate(self, measurement):
  43. """Create track from unassociated measurement.
  44. Parameters
  45. ----------
  46. measurement : ndarray
  47. Bounding box coordinates (x, y, a, h) with center position (x, y),
  48. aspect ratio a, and height h.
  49. Returns
  50. -------
  51. (ndarray, ndarray)
  52. Returns the mean vector (8 dimensional) and covariance matrix (8x8
  53. dimensional) of the new track. Unobserved velocities are initialized
  54. to 0 mean.
  55. """
  56. mean_pos = measurement
  57. mean_vel = np.zeros_like(mean_pos)
  58. mean = np.r_[mean_pos, mean_vel]
  59. std = [
  60. 2 * self._std_weight_position * measurement[3],
  61. 2 * self._std_weight_position * measurement[3],
  62. 1e-2,
  63. 2 * self._std_weight_position * measurement[3],
  64. 10 * self._std_weight_velocity * measurement[3],
  65. 10 * self._std_weight_velocity * measurement[3],
  66. 1e-5,
  67. 10 * self._std_weight_velocity * measurement[3]]
  68. covariance = np.diag(np.square(std))
  69. return mean, covariance
  70. def predict(self, mean, covariance):
  71. """Run Kalman filter prediction step.
  72. Parameters
  73. ----------
  74. mean : ndarray
  75. The 8 dimensional mean vector of the object state at the previous
  76. time step.
  77. covariance : ndarray
  78. The 8x8 dimensional covariance matrix of the object state at the
  79. previous time step.
  80. Returns
  81. -------
  82. (ndarray, ndarray)
  83. Returns the mean vector and covariance matrix of the predicted
  84. state. Unobserved velocities are initialized to 0 mean.
  85. """
  86. std_pos = [
  87. self._std_weight_position * mean[3],
  88. self._std_weight_position * mean[3],
  89. 1e-2,
  90. self._std_weight_position * mean[3]]
  91. std_vel = [
  92. self._std_weight_velocity * mean[3],
  93. self._std_weight_velocity * mean[3],
  94. 1e-5,
  95. self._std_weight_velocity * mean[3]]
  96. motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
  97. mean = np.dot(self._motion_mat, mean)
  98. covariance = np.linalg.multi_dot((
  99. self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
  100. return mean, covariance
  101. def project(self, mean, covariance):
  102. """Project state distribution to measurement space.
  103. Parameters
  104. ----------
  105. mean : ndarray
  106. The state's mean vector (8 dimensional array).
  107. covariance : ndarray
  108. The state's covariance matrix (8x8 dimensional).
  109. Returns
  110. -------
  111. (ndarray, ndarray)
  112. Returns the projected mean and covariance matrix of the given state
  113. estimate.
  114. """
  115. std = [
  116. self._std_weight_position * mean[3],
  117. self._std_weight_position * mean[3],
  118. 1e-1,
  119. self._std_weight_position * mean[3]]
  120. innovation_cov = np.diag(np.square(std))
  121. mean = np.dot(self._update_mat, mean)
  122. covariance = np.linalg.multi_dot((
  123. self._update_mat, covariance, self._update_mat.T))
  124. return mean, covariance + innovation_cov
  125. def update(self, mean, covariance, measurement):
  126. """Run Kalman filter correction step.
  127. Parameters
  128. ----------
  129. mean : ndarray
  130. The predicted state's mean vector (8 dimensional).
  131. covariance : ndarray
  132. The state's covariance matrix (8x8 dimensional).
  133. measurement : ndarray
  134. The 4 dimensional measurement vector (x, y, a, h), where (x, y)
  135. is the center position, a the aspect ratio, and h the height of the
  136. bounding box.
  137. Returns
  138. -------
  139. (ndarray, ndarray)
  140. Returns the measurement-corrected state distribution.
  141. """
  142. projected_mean, projected_cov = self.project(mean, covariance)
  143. chol_factor, lower = scipy.linalg.cho_factor(
  144. projected_cov, lower=True, check_finite=False)
  145. kalman_gain = scipy.linalg.cho_solve(
  146. (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
  147. check_finite=False).T
  148. innovation = measurement - projected_mean
  149. new_mean = mean + np.dot(innovation, kalman_gain.T)
  150. new_covariance = covariance - np.linalg.multi_dot((
  151. kalman_gain, projected_cov, kalman_gain.T))
  152. return new_mean, new_covariance
  153. def gating_distance(self, mean, covariance, measurements,
  154. only_position=False):
  155. """Compute gating distance between state distribution and measurements.
  156. A suitable distance threshold can be obtained from `chi2inv95`. If
  157. `only_position` is False, the chi-square distribution has 4 degrees of
  158. freedom, otherwise 2.
  159. Parameters
  160. ----------
  161. mean : ndarray
  162. Mean vector over the state distribution (8 dimensional).
  163. covariance : ndarray
  164. Covariance of the state distribution (8x8 dimensional).
  165. measurements : ndarray
  166. An Nx4 dimensional matrix of N measurements, each in
  167. format (x, y, a, h) where (x, y) is the bounding box center
  168. position, a the aspect ratio, and h the height.
  169. only_position : Optional[bool]
  170. If True, distance computation is done with respect to the bounding
  171. box center position only.
  172. Returns
  173. -------
  174. ndarray
  175. Returns an array of length N, where the i-th element contains the
  176. squared Mahalanobis distance between (mean, covariance) and
  177. `measurements[i]`.
  178. """
  179. mean, covariance = self.project(mean, covariance)
  180. if only_position:
  181. mean, covariance = mean[:2], covariance[:2, :2]
  182. measurements = measurements[:, :2]
  183. cholesky_factor = np.linalg.cholesky(covariance)
  184. d = measurements - mean
  185. z = scipy.linalg.solve_triangular(
  186. cholesky_factor, d.T, lower=True, check_finite=False,
  187. overwrite_b=True)
  188. squared_maha = np.sum(z * z, axis=0)
  189. return squared_maha