Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 8.9KB


  1. #include "BYTETracker.h"
  2. #include "lapjv.h"
  3. vector<STrack*> BYTETracker::joint_stracks(vector<STrack*> &tlista, vector<STrack> &tlistb)
  4. {
  5. map<int, int> exists;
  6. vector<STrack*> res;
  7. for (int i = 0; i < tlista.size(); i++)
  8. {
  9. exists.insert(pair<int, int>(tlista[i]->track_id, 1));
  10. res.push_back(tlista[i]);
  11. }
  12. for (int i = 0; i < tlistb.size(); i++)
  13. {
  14. int tid = tlistb[i].track_id;
  15. if (!exists[tid] || exists.count(tid) == 0)
  16. {
  17. exists[tid] = 1;
  18. res.push_back(&tlistb[i]);
  19. }
  20. }
  21. return res;
  22. }
  23. vector<STrack> BYTETracker::joint_stracks(vector<STrack> &tlista, vector<STrack> &tlistb)
  24. {
  25. map<int, int> exists;
  26. vector<STrack> res;
  27. for (int i = 0; i < tlista.size(); i++)
  28. {
  29. exists.insert(pair<int, int>(tlista[i].track_id, 1));
  30. res.push_back(tlista[i]);
  31. }
  32. for (int i = 0; i < tlistb.size(); i++)
  33. {
  34. int tid = tlistb[i].track_id;
  35. if (!exists[tid] || exists.count(tid) == 0)
  36. {
  37. exists[tid] = 1;
  38. res.push_back(tlistb[i]);
  39. }
  40. }
  41. return res;
  42. }
  43. vector<STrack> BYTETracker::sub_stracks(vector<STrack> &tlista, vector<STrack> &tlistb)
  44. {
  45. map<int, STrack> stracks;
  46. for (int i = 0; i < tlista.size(); i++)
  47. {
  48. stracks.insert(pair<int, STrack>(tlista[i].track_id, tlista[i]));
  49. }
  50. for (int i = 0; i < tlistb.size(); i++)
  51. {
  52. int tid = tlistb[i].track_id;
  53. if (stracks.count(tid) != 0)
  54. {
  55. stracks.erase(tid);
  56. }
  57. }
  58. vector<STrack> res;
  59. std::map<int, STrack>::iterator it;
  60. for (it = stracks.begin(); it != stracks.end(); ++it)
  61. {
  62. res.push_back(it->second);
  63. }
  64. return res;
  65. }
  66. void BYTETracker::remove_duplicate_stracks(vector<STrack> &resa, vector<STrack> &resb, vector<STrack> &stracksa, vector<STrack> &stracksb)
  67. {
  68. vector<vector<float> > pdist = iou_distance(stracksa, stracksb);
  69. vector<pair<int, int> > pairs;
  70. for (int i = 0; i < pdist.size(); i++)
  71. {
  72. for (int j = 0; j < pdist[i].size(); j++)
  73. {
  74. if (pdist[i][j] < 0.15)
  75. {
  76. pairs.push_back(pair<int, int>(i, j));
  77. }
  78. }
  79. }
  80. vector<int> dupa, dupb;
  81. for (int i = 0; i < pairs.size(); i++)
  82. {
  83. int timep = stracksa[pairs[i].first].frame_id - stracksa[pairs[i].first].start_frame;
  84. int timeq = stracksb[pairs[i].second].frame_id - stracksb[pairs[i].second].start_frame;
  85. if (timep > timeq)
  86. dupb.push_back(pairs[i].second);
  87. else
  88. dupa.push_back(pairs[i].first);
  89. }
  90. for (int i = 0; i < stracksa.size(); i++)
  91. {
  92. vector<int>::iterator iter = find(dupa.begin(), dupa.end(), i);
  93. if (iter == dupa.end())
  94. {
  95. resa.push_back(stracksa[i]);
  96. }
  97. }
  98. for (int i = 0; i < stracksb.size(); i++)
  99. {
  100. vector<int>::iterator iter = find(dupb.begin(), dupb.end(), i);
  101. if (iter == dupb.end())
  102. {
  103. resb.push_back(stracksb[i]);
  104. }
  105. }
  106. }
  107. void BYTETracker::linear_assignment(vector<vector<float> > &cost_matrix, int cost_matrix_size, int cost_matrix_size_size, float thresh,
  108. vector<vector<int> > &matches, vector<int> &unmatched_a, vector<int> &unmatched_b)
  109. {
  110. if (cost_matrix.size() == 0)
  111. {
  112. for (int i = 0; i < cost_matrix_size; i++)
  113. {
  114. unmatched_a.push_back(i);
  115. }
  116. for (int i = 0; i < cost_matrix_size_size; i++)
  117. {
  118. unmatched_b.push_back(i);
  119. }
  120. return;
  121. }
  122. vector<int> rowsol; vector<int> colsol;
  123. float c = lapjv(cost_matrix, rowsol, colsol, true, thresh);
  124. for (int i = 0; i < rowsol.size(); i++)
  125. {
  126. if (rowsol[i] >= 0)
  127. {
  128. vector<int> match;
  129. match.push_back(i);
  130. match.push_back(rowsol[i]);
  131. matches.push_back(match);
  132. }
  133. else
  134. {
  135. unmatched_a.push_back(i);
  136. }
  137. }
  138. for (int i = 0; i < colsol.size(); i++)
  139. {
  140. if (colsol[i] < 0)
  141. {
  142. unmatched_b.push_back(i);
  143. }
  144. }
  145. }
  146. vector<vector<float> > BYTETracker::ious(vector<vector<float> > &atlbrs, vector<vector<float> > &btlbrs)
  147. {
  148. vector<vector<float> > ious;
  149. if (atlbrs.size()*btlbrs.size() == 0)
  150. return ious;
  151. ious.resize(atlbrs.size());
  152. for (int i = 0; i < ious.size(); i++)
  153. {
  154. ious[i].resize(btlbrs.size());
  155. }
  156. //bbox_ious
  157. for (int k = 0; k < btlbrs.size(); k++)
  158. {
  159. vector<float> ious_tmp;
  160. float box_area = (btlbrs[k][2] - btlbrs[k][0] + 1)*(btlbrs[k][3] - btlbrs[k][1] + 1);
  161. for (int n = 0; n < atlbrs.size(); n++)
  162. {
  163. float iw = min(atlbrs[n][2], btlbrs[k][2]) - max(atlbrs[n][0], btlbrs[k][0]) + 1;
  164. if (iw > 0)
  165. {
  166. float ih = min(atlbrs[n][3], btlbrs[k][3]) - max(atlbrs[n][1], btlbrs[k][1]) + 1;
  167. if(ih > 0)
  168. {
  169. float ua = (atlbrs[n][2] - atlbrs[n][0] + 1)*(atlbrs[n][3] - atlbrs[n][1] + 1) + box_area - iw * ih;
  170. ious[n][k] = iw * ih / ua;
  171. }
  172. else
  173. {
  174. ious[n][k] = 0.0;
  175. }
  176. }
  177. else
  178. {
  179. ious[n][k] = 0.0;
  180. }
  181. }
  182. }
  183. return ious;
  184. }
  185. vector<vector<float> > BYTETracker::iou_distance(vector<STrack*> &atracks, vector<STrack> &btracks, int &dist_size, int &dist_size_size)
  186. {
  187. vector<vector<float> > cost_matrix;
  188. if (atracks.size() * btracks.size() == 0)
  189. {
  190. dist_size = atracks.size();
  191. dist_size_size = btracks.size();
  192. return cost_matrix;
  193. }
  194. vector<vector<float> > atlbrs, btlbrs;
  195. for (int i = 0; i < atracks.size(); i++)
  196. {
  197. atlbrs.push_back(atracks[i]->tlbr);
  198. }
  199. for (int i = 0; i < btracks.size(); i++)
  200. {
  201. btlbrs.push_back(btracks[i].tlbr);
  202. }
  203. dist_size = atracks.size();
  204. dist_size_size = btracks.size();
  205. vector<vector<float> > _ious = ious(atlbrs, btlbrs);
  206. for (int i = 0; i < _ious.size();i++)
  207. {
  208. vector<float> _iou;
  209. for (int j = 0; j < _ious[i].size(); j++)
  210. {
  211. _iou.push_back(1 - _ious[i][j]);
  212. }
  213. cost_matrix.push_back(_iou);
  214. }
  215. return cost_matrix;
  216. }
  217. vector<vector<float> > BYTETracker::iou_distance(vector<STrack> &atracks, vector<STrack> &btracks)
  218. {
  219. vector<vector<float> > atlbrs, btlbrs;
  220. for (int i = 0; i < atracks.size(); i++)
  221. {
  222. atlbrs.push_back(atracks[i].tlbr);
  223. }
  224. for (int i = 0; i < btracks.size(); i++)
  225. {
  226. btlbrs.push_back(btracks[i].tlbr);
  227. }
  228. vector<vector<float> > _ious = ious(atlbrs, btlbrs);
  229. vector<vector<float> > cost_matrix;
  230. for (int i = 0; i < _ious.size(); i++)
  231. {
  232. vector<float> _iou;
  233. for (int j = 0; j < _ious[i].size(); j++)
  234. {
  235. _iou.push_back(1 - _ious[i][j]);
  236. }
  237. cost_matrix.push_back(_iou);
  238. }
  239. return cost_matrix;
  240. }
  241. double BYTETracker::lapjv(const vector<vector<float> > &cost, vector<int> &rowsol, vector<int> &colsol,
  242. bool extend_cost, float cost_limit, bool return_cost)
  243. {
  244. vector<vector<float> > cost_c;
  245. cost_c.assign(cost.begin(), cost.end());
  246. vector<vector<float> > cost_c_extended;
  247. int n_rows = cost.size();
  248. int n_cols = cost[0].size();
  249. rowsol.resize(n_rows);
  250. colsol.resize(n_cols);
  251. int n = 0;
  252. if (n_rows == n_cols)
  253. {
  254. n = n_rows;
  255. }
  256. else
  257. {
  258. if (!extend_cost)
  259. {
  260. cout << "set extend_cost=True" << endl;
  261. system("pause");
  262. exit(0);
  263. }
  264. }
  265. if (extend_cost || cost_limit < LONG_MAX)
  266. {
  267. n = n_rows + n_cols;
  268. cost_c_extended.resize(n);
  269. for (int i = 0; i < cost_c_extended.size(); i++)
  270. cost_c_extended[i].resize(n);
  271. if (cost_limit < LONG_MAX)
  272. {
  273. for (int i = 0; i < cost_c_extended.size(); i++)
  274. {
  275. for (int j = 0; j < cost_c_extended[i].size(); j++)
  276. {
  277. cost_c_extended[i][j] = cost_limit / 2.0;
  278. }
  279. }
  280. }
  281. else
  282. {
  283. float cost_max = -1;
  284. for (int i = 0; i < cost_c.size(); i++)
  285. {
  286. for (int j = 0; j < cost_c[i].size(); j++)
  287. {
  288. if (cost_c[i][j] > cost_max)
  289. cost_max = cost_c[i][j];
  290. }
  291. }
  292. for (int i = 0; i < cost_c_extended.size(); i++)
  293. {
  294. for (int j = 0; j < cost_c_extended[i].size(); j++)
  295. {
  296. cost_c_extended[i][j] = cost_max + 1;
  297. }
  298. }
  299. }
  300. for (int i = n_rows; i < cost_c_extended.size(); i++)
  301. {
  302. for (int j = n_cols; j < cost_c_extended[i].size(); j++)
  303. {
  304. cost_c_extended[i][j] = 0;
  305. }
  306. }
  307. for (int i = 0; i < n_rows; i++)
  308. {
  309. for (int j = 0; j < n_cols; j++)
  310. {
  311. cost_c_extended[i][j] = cost_c[i][j];
  312. }
  313. }
  314. cost_c.clear();
  315. cost_c.assign(cost_c_extended.begin(), cost_c_extended.end());
  316. }
  317. double **cost_ptr;
  318. cost_ptr = new double *[sizeof(double *) * n];
  319. for (int i = 0; i < n; i++)
  320. cost_ptr[i] = new double[sizeof(double) * n];
  321. for (int i = 0; i < n; i++)
  322. {
  323. for (int j = 0; j < n; j++)
  324. {
  325. cost_ptr[i][j] = cost_c[i][j];
  326. }
  327. }
  328. int* x_c = new int[sizeof(int) * n];
  329. int *y_c = new int[sizeof(int) * n];
  330. int ret = lapjv_internal(n, cost_ptr, x_c, y_c);
  331. if (ret != 0)
  332. {
  333. cout << "Calculate Wrong!" << endl;
  334. system("pause");
  335. exit(0);
  336. }
  337. double opt = 0.0;
  338. if (n != n_rows)
  339. {
  340. for (int i = 0; i < n; i++)
  341. {
  342. if (x_c[i] >= n_cols)
  343. x_c[i] = -1;
  344. if (y_c[i] >= n_rows)
  345. y_c[i] = -1;
  346. }
  347. for (int i = 0; i < n_rows; i++)
  348. {
  349. rowsol[i] = x_c[i];
  350. }
  351. for (int i = 0; i < n_cols; i++)
  352. {
  353. colsol[i] = y_c[i];
  354. }
  355. if (return_cost)
  356. {
  357. for (int i = 0; i < rowsol.size(); i++)
  358. {
  359. if (rowsol[i] != -1)
  360. {
  361. //cout << i << "\t" << rowsol[i] << "\t" << cost_ptr[i][rowsol[i]] << endl;
  362. opt += cost_ptr[i][rowsol[i]];
  363. }
  364. }
  365. }
  366. }
  367. else if (return_cost)
  368. {
  369. for (int i = 0; i < rowsol.size(); i++)
  370. {
  371. opt += cost_ptr[i][rowsol[i]];
  372. }
  373. }
  374. for (int i = 0; i < n; i++)
  375. {
  376. delete[]cost_ptr[i];
  377. }
  378. delete[]cost_ptr;
  379. delete[]x_c;
  380. delete[]y_c;
  381. return opt;
  382. }
  383. Scalar BYTETracker::get_color(int idx)
  384. {
  385. idx += 3;
  386. return Scalar(37 * idx % 255, 17 * idx % 255, 29 * idx % 255);
  387. }