Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

BYTETracker.cpp 6.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. #include "BYTETracker.h"
  2. #include <fstream>
  3. BYTETracker::BYTETracker(int frame_rate, int track_buffer)
  4. {
  5. track_thresh = 0.5;
  6. high_thresh = 0.6;
  7. match_thresh = 0.8;
  8. frame_id = 0;
  9. max_time_lost = int(frame_rate / 30.0 * track_buffer);
  10. cout << "Init ByteTrack!" << endl;
  11. }
  12. BYTETracker::~BYTETracker()
  13. {
  14. }
  15. vector<STrack> BYTETracker::update(const vector<Object>& objects)
  16. {
  17. ////////////////// Step 1: Get detections //////////////////
  18. this->frame_id++;
  19. vector<STrack> activated_stracks;
  20. vector<STrack> refind_stracks;
  21. vector<STrack> removed_stracks;
  22. vector<STrack> lost_stracks;
  23. vector<STrack> detections;
  24. vector<STrack> detections_low;
  25. vector<STrack> detections_cp;
  26. vector<STrack> tracked_stracks_swap;
  27. vector<STrack> resa, resb;
  28. vector<STrack> output_stracks;
  29. vector<STrack*> unconfirmed;
  30. vector<STrack*> tracked_stracks;
  31. vector<STrack*> strack_pool;
  32. vector<STrack*> r_tracked_stracks;
  33. if (objects.size() > 0)
  34. {
  35. for (int i = 0; i < objects.size(); i++)
  36. {
  37. vector<float> tlbr_;
  38. tlbr_.resize(4);
  39. tlbr_[0] = objects[i].rect.x;
  40. tlbr_[1] = objects[i].rect.y;
  41. tlbr_[2] = objects[i].rect.x + objects[i].rect.width;
  42. tlbr_[3] = objects[i].rect.y + objects[i].rect.height;
  43. float score = objects[i].prob;
  44. STrack strack(STrack::tlbr_to_tlwh(tlbr_), score);
  45. if (score >= track_thresh)
  46. {
  47. detections.push_back(strack);
  48. }
  49. else
  50. {
  51. detections_low.push_back(strack);
  52. }
  53. }
  54. }
  55. // Add newly detected tracklets to tracked_stracks
  56. for (int i = 0; i < this->tracked_stracks.size(); i++)
  57. {
  58. if (!this->tracked_stracks[i].is_activated)
  59. unconfirmed.push_back(&this->tracked_stracks[i]);
  60. else
  61. tracked_stracks.push_back(&this->tracked_stracks[i]);
  62. }
  63. ////////////////// Step 2: First association, with IoU //////////////////
  64. strack_pool = joint_stracks(tracked_stracks, this->lost_stracks);
  65. STrack::multi_predict(strack_pool, this->kalman_filter);
  66. vector<vector<float> > dists;
  67. int dist_size = 0, dist_size_size = 0;
  68. dists = iou_distance(strack_pool, detections, dist_size, dist_size_size);
  69. vector<vector<int> > matches;
  70. vector<int> u_track, u_detection;
  71. linear_assignment(dists, dist_size, dist_size_size, match_thresh, matches, u_track, u_detection);
  72. for (int i = 0; i < matches.size(); i++)
  73. {
  74. STrack *track = strack_pool[matches[i][0]];
  75. STrack *det = &detections[matches[i][1]];
  76. if (track->state == TrackState::Tracked)
  77. {
  78. track->update(*det, this->frame_id);
  79. activated_stracks.push_back(*track);
  80. }
  81. else
  82. {
  83. track->re_activate(*det, this->frame_id, false);
  84. refind_stracks.push_back(*track);
  85. }
  86. }
  87. ////////////////// Step 3: Second association, using low score dets //////////////////
  88. for (int i = 0; i < u_detection.size(); i++)
  89. {
  90. detections_cp.push_back(detections[u_detection[i]]);
  91. }
  92. detections.clear();
  93. detections.assign(detections_low.begin(), detections_low.end());
  94. for (int i = 0; i < u_track.size(); i++)
  95. {
  96. if (strack_pool[u_track[i]]->state == TrackState::Tracked)
  97. {
  98. r_tracked_stracks.push_back(strack_pool[u_track[i]]);
  99. }
  100. }
  101. dists.clear();
  102. dists = iou_distance(r_tracked_stracks, detections, dist_size, dist_size_size);
  103. matches.clear();
  104. u_track.clear();
  105. u_detection.clear();
  106. linear_assignment(dists, dist_size, dist_size_size, 0.5, matches, u_track, u_detection);
  107. for (int i = 0; i < matches.size(); i++)
  108. {
  109. STrack *track = r_tracked_stracks[matches[i][0]];
  110. STrack *det = &detections[matches[i][1]];
  111. if (track->state == TrackState::Tracked)
  112. {
  113. track->update(*det, this->frame_id);
  114. activated_stracks.push_back(*track);
  115. }
  116. else
  117. {
  118. track->re_activate(*det, this->frame_id, false);
  119. refind_stracks.push_back(*track);
  120. }
  121. }
  122. for (int i = 0; i < u_track.size(); i++)
  123. {
  124. STrack *track = r_tracked_stracks[u_track[i]];
  125. if (track->state != TrackState::Lost)
  126. {
  127. track->mark_lost();
  128. lost_stracks.push_back(*track);
  129. }
  130. }
  131. // Deal with unconfirmed tracks, usually tracks with only one beginning frame
  132. detections.clear();
  133. detections.assign(detections_cp.begin(), detections_cp.end());
  134. dists.clear();
  135. dists = iou_distance(unconfirmed, detections, dist_size, dist_size_size);
  136. matches.clear();
  137. vector<int> u_unconfirmed;
  138. u_detection.clear();
  139. linear_assignment(dists, dist_size, dist_size_size, 0.7, matches, u_unconfirmed, u_detection);
  140. for (int i = 0; i < matches.size(); i++)
  141. {
  142. unconfirmed[matches[i][0]]->update(detections[matches[i][1]], this->frame_id);
  143. activated_stracks.push_back(*unconfirmed[matches[i][0]]);
  144. }
  145. for (int i = 0; i < u_unconfirmed.size(); i++)
  146. {
  147. STrack *track = unconfirmed[u_unconfirmed[i]];
  148. track->mark_removed();
  149. removed_stracks.push_back(*track);
  150. }
  151. ////////////////// Step 4: Init new stracks //////////////////
  152. for (int i = 0; i < u_detection.size(); i++)
  153. {
  154. STrack *track = &detections[u_detection[i]];
  155. if (track->score < this->high_thresh)
  156. continue;
  157. track->activate(this->kalman_filter, this->frame_id);
  158. activated_stracks.push_back(*track);
  159. }
  160. ////////////////// Step 5: Update state //////////////////
  161. for (int i = 0; i < this->lost_stracks.size(); i++)
  162. {
  163. if (this->frame_id - this->lost_stracks[i].end_frame() > this->max_time_lost)
  164. {
  165. this->lost_stracks[i].mark_removed();
  166. removed_stracks.push_back(this->lost_stracks[i]);
  167. }
  168. }
  169. for (int i = 0; i < this->tracked_stracks.size(); i++)
  170. {
  171. if (this->tracked_stracks[i].state == TrackState::Tracked)
  172. {
  173. tracked_stracks_swap.push_back(this->tracked_stracks[i]);
  174. }
  175. }
  176. this->tracked_stracks.clear();
  177. this->tracked_stracks.assign(tracked_stracks_swap.begin(), tracked_stracks_swap.end());
  178. this->tracked_stracks = joint_stracks(this->tracked_stracks, activated_stracks);
  179. this->tracked_stracks = joint_stracks(this->tracked_stracks, refind_stracks);
  180. //std::cout << activated_stracks.size() << std::endl;
  181. this->lost_stracks = sub_stracks(this->lost_stracks, this->tracked_stracks);
  182. for (int i = 0; i < lost_stracks.size(); i++)
  183. {
  184. this->lost_stracks.push_back(lost_stracks[i]);
  185. }
  186. this->lost_stracks = sub_stracks(this->lost_stracks, this->removed_stracks);
  187. for (int i = 0; i < removed_stracks.size(); i++)
  188. {
  189. this->removed_stracks.push_back(removed_stracks[i]);
  190. }
  191. remove_duplicate_stracks(resa, resb, this->tracked_stracks, this->lost_stracks);
  192. this->tracked_stracks.clear();
  193. this->tracked_stracks.assign(resa.begin(), resa.end());
  194. this->lost_stracks.clear();
  195. this->lost_stracks.assign(resb.begin(), resb.end());
  196. for (int i = 0; i < this->tracked_stracks.size(); i++)
  197. {
  198. if (this->tracked_stracks[i].is_activated)
  199. {
  200. output_stracks.push_back(this->tracked_stracks[i]);
  201. }
  202. }
  203. return output_stracks;
  204. }