Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

kalmanFilter.cpp 4.4KB

Lines 1–152
  1. #include "kalmanFilter.h"
  2. #include <Eigen/Cholesky>
  3. namespace byte_kalman
  4. {
  5. const double KalmanFilter::chi2inv95[10] = {
  6. 0,
  7. 3.8415,
  8. 5.9915,
  9. 7.8147,
  10. 9.4877,
  11. 11.070,
  12. 12.592,
  13. 14.067,
  14. 15.507,
  15. 16.919
  16. };
  17. KalmanFilter::KalmanFilter()
  18. {
  19. int ndim = 4;
  20. double dt = 1.;
  21. _motion_mat = Eigen::MatrixXf::Identity(8, 8);
  22. for (int i = 0; i < ndim; i++) {
  23. _motion_mat(i, ndim + i) = dt;
  24. }
  25. _update_mat = Eigen::MatrixXf::Identity(4, 8);
  26. this->_std_weight_position = 1. / 20;
  27. this->_std_weight_velocity = 1. / 160;
  28. }
  29. KAL_DATA KalmanFilter::initiate(const DETECTBOX &measurement)
  30. {
  31. DETECTBOX mean_pos = measurement;
  32. DETECTBOX mean_vel;
  33. for (int i = 0; i < 4; i++) mean_vel(i) = 0;
  34. KAL_MEAN mean;
  35. for (int i = 0; i < 8; i++) {
  36. if (i < 4) mean(i) = mean_pos(i);
  37. else mean(i) = mean_vel(i - 4);
  38. }
  39. KAL_MEAN std;
  40. std(0) = 2 * _std_weight_position * measurement[3];
  41. std(1) = 2 * _std_weight_position * measurement[3];
  42. std(2) = 1e-2;
  43. std(3) = 2 * _std_weight_position * measurement[3];
  44. std(4) = 10 * _std_weight_velocity * measurement[3];
  45. std(5) = 10 * _std_weight_velocity * measurement[3];
  46. std(6) = 1e-5;
  47. std(7) = 10 * _std_weight_velocity * measurement[3];
  48. KAL_MEAN tmp = std.array().square();
  49. KAL_COVA var = tmp.asDiagonal();
  50. return std::make_pair(mean, var);
  51. }
  52. void KalmanFilter::predict(KAL_MEAN &mean, KAL_COVA &covariance)
  53. {
  54. //revise the data;
  55. DETECTBOX std_pos;
  56. std_pos << _std_weight_position * mean(3),
  57. _std_weight_position * mean(3),
  58. 1e-2,
  59. _std_weight_position * mean(3);
  60. DETECTBOX std_vel;
  61. std_vel << _std_weight_velocity * mean(3),
  62. _std_weight_velocity * mean(3),
  63. 1e-5,
  64. _std_weight_velocity * mean(3);
  65. KAL_MEAN tmp;
  66. tmp.block<1, 4>(0, 0) = std_pos;
  67. tmp.block<1, 4>(0, 4) = std_vel;
  68. tmp = tmp.array().square();
  69. KAL_COVA motion_cov = tmp.asDiagonal();
  70. KAL_MEAN mean1 = this->_motion_mat * mean.transpose();
  71. KAL_COVA covariance1 = this->_motion_mat * covariance *(_motion_mat.transpose());
  72. covariance1 += motion_cov;
  73. mean = mean1;
  74. covariance = covariance1;
  75. }
  76. KAL_HDATA KalmanFilter::project(const KAL_MEAN &mean, const KAL_COVA &covariance)
  77. {
  78. DETECTBOX std;
  79. std << _std_weight_position * mean(3), _std_weight_position * mean(3),
  80. 1e-1, _std_weight_position * mean(3);
  81. KAL_HMEAN mean1 = _update_mat * mean.transpose();
  82. KAL_HCOVA covariance1 = _update_mat * covariance * (_update_mat.transpose());
  83. Eigen::Matrix<float, 4, 4> diag = std.asDiagonal();
  84. diag = diag.array().square().matrix();
  85. covariance1 += diag;
  86. // covariance1.diagonal() << diag;
  87. return std::make_pair(mean1, covariance1);
  88. }
  89. KAL_DATA
  90. KalmanFilter::update(
  91. const KAL_MEAN &mean,
  92. const KAL_COVA &covariance,
  93. const DETECTBOX &measurement)
  94. {
  95. KAL_HDATA pa = project(mean, covariance);
  96. KAL_HMEAN projected_mean = pa.first;
  97. KAL_HCOVA projected_cov = pa.second;
  98. //chol_factor, lower =
  99. //scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
  100. //kalmain_gain =
  101. //scipy.linalg.cho_solve((cho_factor, lower),
  102. //np.dot(covariance, self._upadte_mat.T).T,
  103. //check_finite=False).T
  104. Eigen::Matrix<float, 4, 8> B = (covariance * (_update_mat.transpose())).transpose();
  105. Eigen::Matrix<float, 8, 4> kalman_gain = (projected_cov.llt().solve(B)).transpose(); // eg.8x4
  106. Eigen::Matrix<float, 1, 4> innovation = measurement - projected_mean; //eg.1x4
  107. auto tmp = innovation * (kalman_gain.transpose());
  108. KAL_MEAN new_mean = (mean.array() + tmp.array()).matrix();
  109. KAL_COVA new_covariance = covariance - kalman_gain * projected_cov*(kalman_gain.transpose());
  110. return std::make_pair(new_mean, new_covariance);
  111. }
  112. Eigen::Matrix<float, 1, -1>
  113. KalmanFilter::gating_distance(
  114. const KAL_MEAN &mean,
  115. const KAL_COVA &covariance,
  116. const std::vector<DETECTBOX> &measurements,
  117. bool only_position)
  118. {
  119. KAL_HDATA pa = this->project(mean, covariance);
  120. if (only_position) {
  121. printf("not implement!");
  122. exit(0);
  123. }
  124. KAL_HMEAN mean1 = pa.first;
  125. KAL_HCOVA covariance1 = pa.second;
  126. // Eigen::Matrix<float, -1, 4, Eigen::RowMajor> d(size, 4);
  127. DETECTBOXSS d(measurements.size(), 4);
  128. int pos = 0;
  129. for (DETECTBOX box : measurements) {
  130. d.row(pos++) = box - mean1;
  131. }
  132. Eigen::Matrix<float, -1, -1, Eigen::RowMajor> factor = covariance1.llt().matrixL();
  133. Eigen::Matrix<float, -1, -1> z = factor.triangularView<Eigen::Lower>().solve<Eigen::OnTheRight>(d).transpose();
  134. auto zz = ((z.array())*(z.array())).matrix();
  135. auto square_maha = zz.colwise().sum();
  136. return square_maha;
  137. }
  138. }