Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

kalman_filter.py 9.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # vim: expandtab:ts=4:sw=4
  2. import numpy as np
  3. import scipy.linalg
  4. """
  5. Table for the 0.95 quantile of the chi-square distribution with N degrees of
  6. freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
  7. function and used as Mahalanobis gating threshold.
  8. """
  9. chi2inv95 = {
  10. 1: 3.8415,
  11. 2: 5.9915,
  12. 3: 7.8147,
  13. 4: 9.4877,
  14. 5: 11.070,
  15. 6: 12.592,
  16. 7: 14.067,
  17. 8: 15.507,
  18. 9: 16.919}
  19. class KalmanFilter(object):
  20. """
  21. A simple Kalman filter for tracking bounding boxes in image space.
  22. The 8-dimensional state space
  23. x, y, a, h, vx, vy, va, vh
  24. contains the bounding box center position (x, y), aspect ratio a, height h,
  25. and their respective velocities.
  26. Object motion follows a constant velocity model. The bounding box location
  27. (x, y, a, h) is taken as direct observation of the state space (linear
  28. observation model).
  29. """
  30. def __init__(self):
  31. ndim, dt = 4, 1.
  32. # Create Kalman filter model matrices.
  33. self._motion_mat = np.eye(2 * ndim, 2 * ndim)
  34. for i in range(ndim):
  35. self._motion_mat[i, ndim + i] = dt
  36. self._update_mat = np.eye(ndim, 2 * ndim)
  37. # Motion and observation uncertainty are chosen relative to the current
  38. # state estimate. These weights control the amount of uncertainty in
  39. # the model. This is a bit hacky.
  40. self._std_weight_position = 1. / 20
  41. self._std_weight_velocity = 1. / 160
  42. def initiate(self, measurement):
  43. """Create track from unassociated measurement.
  44. Parameters
  45. ----------
  46. measurement : ndarray
  47. Bounding box coordinates (x, y, a, h) with center position (x, y),
  48. aspect ratio a, and height h.
  49. Returns
  50. -------
  51. (ndarray, ndarray)
  52. Returns the mean vector (8 dimensional) and covariance matrix (8x8
  53. dimensional) of the new track. Unobserved velocities are initialized
  54. to 0 mean.
  55. """
  56. mean_pos = measurement
  57. mean_vel = np.zeros_like(mean_pos)
  58. mean = np.r_[mean_pos, mean_vel]
  59. std = [
  60. 2 * self._std_weight_position * measurement[3],
  61. 2 * self._std_weight_position * measurement[3],
  62. 1e-2,
  63. 2 * self._std_weight_position * measurement[3],
  64. 10 * self._std_weight_velocity * measurement[3],
  65. 10 * self._std_weight_velocity * measurement[3],
  66. 1e-5,
  67. 10 * self._std_weight_velocity * measurement[3]]
  68. covariance = np.diag(np.square(std))
  69. return mean, covariance
  70. def predict(self, mean, covariance):
  71. """Run Kalman filter prediction step.
  72. Parameters
  73. ----------
  74. mean : ndarray
  75. The 8 dimensional mean vector of the object state at the previous
  76. time step.
  77. covariance : ndarray
  78. The 8x8 dimensional covariance matrix of the object state at the
  79. previous time step.
  80. Returns
  81. -------
  82. (ndarray, ndarray)
  83. Returns the mean vector and covariance matrix of the predicted
  84. state. Unobserved velocities are initialized to 0 mean.
  85. """
  86. std_pos = [
  87. self._std_weight_position * mean[3],
  88. self._std_weight_position * mean[3],
  89. 1e-2,
  90. self._std_weight_position * mean[3]]
  91. std_vel = [
  92. self._std_weight_velocity * mean[3],
  93. self._std_weight_velocity * mean[3],
  94. 1e-5,
  95. self._std_weight_velocity * mean[3]]
  96. motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
  97. #mean = np.dot(self._motion_mat, mean)
  98. mean = np.dot(mean, self._motion_mat.T)
  99. covariance = np.linalg.multi_dot((
  100. self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
  101. return mean, covariance
  102. def project(self, mean, covariance):
  103. """Project state distribution to measurement space.
  104. Parameters
  105. ----------
  106. mean : ndarray
  107. The state's mean vector (8 dimensional array).
  108. covariance : ndarray
  109. The state's covariance matrix (8x8 dimensional).
  110. Returns
  111. -------
  112. (ndarray, ndarray)
  113. Returns the projected mean and covariance matrix of the given state
  114. estimate.
  115. """
  116. std = [
  117. self._std_weight_position * mean[3],
  118. self._std_weight_position * mean[3],
  119. 1e-1,
  120. self._std_weight_position * mean[3]]
  121. innovation_cov = np.diag(np.square(std))
  122. mean = np.dot(self._update_mat, mean)
  123. covariance = np.linalg.multi_dot((
  124. self._update_mat, covariance, self._update_mat.T))
  125. return mean, covariance + innovation_cov
  126. def multi_predict(self, mean, covariance):
  127. """Run Kalman filter prediction step (Vectorized version).
  128. Parameters
  129. ----------
  130. mean : ndarray
  131. The Nx8 dimensional mean matrix of the object states at the previous
  132. time step.
  133. covariance : ndarray
  134. The Nx8x8 dimensional covariance matrics of the object states at the
  135. previous time step.
  136. Returns
  137. -------
  138. (ndarray, ndarray)
  139. Returns the mean vector and covariance matrix of the predicted
  140. state. Unobserved velocities are initialized to 0 mean.
  141. """
  142. std_pos = [
  143. self._std_weight_position * mean[:, 3],
  144. self._std_weight_position * mean[:, 3],
  145. 1e-2 * np.ones_like(mean[:, 3]),
  146. self._std_weight_position * mean[:, 3]]
  147. std_vel = [
  148. self._std_weight_velocity * mean[:, 3],
  149. self._std_weight_velocity * mean[:, 3],
  150. 1e-5 * np.ones_like(mean[:, 3]),
  151. self._std_weight_velocity * mean[:, 3]]
  152. sqr = np.square(np.r_[std_pos, std_vel]).T
  153. motion_cov = []
  154. for i in range(len(mean)):
  155. motion_cov.append(np.diag(sqr[i]))
  156. motion_cov = np.asarray(motion_cov)
  157. mean = np.dot(mean, self._motion_mat.T)
  158. left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
  159. covariance = np.dot(left, self._motion_mat.T) + motion_cov
  160. return mean, covariance
  161. def update(self, mean, covariance, measurement):
  162. """Run Kalman filter correction step.
  163. Parameters
  164. ----------
  165. mean : ndarray
  166. The predicted state's mean vector (8 dimensional).
  167. covariance : ndarray
  168. The state's covariance matrix (8x8 dimensional).
  169. measurement : ndarray
  170. The 4 dimensional measurement vector (x, y, a, h), where (x, y)
  171. is the center position, a the aspect ratio, and h the height of the
  172. bounding box.
  173. Returns
  174. -------
  175. (ndarray, ndarray)
  176. Returns the measurement-corrected state distribution.
  177. """
  178. projected_mean, projected_cov = self.project(mean, covariance)
  179. chol_factor, lower = scipy.linalg.cho_factor(
  180. projected_cov, lower=True, check_finite=False)
  181. kalman_gain = scipy.linalg.cho_solve(
  182. (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
  183. check_finite=False).T
  184. innovation = measurement - projected_mean
  185. new_mean = mean + np.dot(innovation, kalman_gain.T)
  186. new_covariance = covariance - np.linalg.multi_dot((
  187. kalman_gain, projected_cov, kalman_gain.T))
  188. return new_mean, new_covariance
  189. def gating_distance(self, mean, covariance, measurements,
  190. only_position=False, metric='maha'):
  191. """Compute gating distance between state distribution and measurements.
  192. A suitable distance threshold can be obtained from `chi2inv95`. If
  193. `only_position` is False, the chi-square distribution has 4 degrees of
  194. freedom, otherwise 2.
  195. Parameters
  196. ----------
  197. mean : ndarray
  198. Mean vector over the state distribution (8 dimensional).
  199. covariance : ndarray
  200. Covariance of the state distribution (8x8 dimensional).
  201. measurements : ndarray
  202. An Nx4 dimensional matrix of N measurements, each in
  203. format (x, y, a, h) where (x, y) is the bounding box center
  204. position, a the aspect ratio, and h the height.
  205. only_position : Optional[bool]
  206. If True, distance computation is done with respect to the bounding
  207. box center position only.
  208. Returns
  209. -------
  210. ndarray
  211. Returns an array of length N, where the i-th element contains the
  212. squared Mahalanobis distance between (mean, covariance) and
  213. `measurements[i]`.
  214. """
  215. mean, covariance = self.project(mean, covariance)
  216. if only_position:
  217. mean, covariance = mean[:2], covariance[:2, :2]
  218. measurements = measurements[:, :2]
  219. d = measurements - mean
  220. if metric == 'gaussian':
  221. return np.sum(d * d, axis=1)
  222. elif metric == 'maha':
  223. cholesky_factor = np.linalg.cholesky(covariance)
  224. z = scipy.linalg.solve_triangular(
  225. cholesky_factor, d.T, lower=True, check_finite=False,
  226. overwrite_b=True)
  227. squared_maha = np.sum(z * z, axis=0)
  228. return squared_maha
  229. else:
  230. raise ValueError('invalid distance metric')