Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bytetrack.cpp 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. #include <fstream>
  2. #include <iostream>
  3. #include <sstream>
  4. #include <numeric>
  5. #include <chrono>
  6. #include <vector>
  7. #include <opencv2/opencv.hpp>
  8. #include <dirent.h>
  9. #include "NvInfer.h"
  10. #include "cuda_runtime_api.h"
  11. #include "logging.h"
  12. #include "BYTETracker.h"
// Abort with a diagnostic if a CUDA runtime call returns a non-zero status.
#define CHECK(status) \
do\
{\
    auto ret = (status);\
    if (ret != 0)\
    {\
        cerr << "Cuda failure: " << ret << endl;\
        abort();\
    }\
} while (0)

#define DEVICE 0 // GPU id
#define NMS_THRESH 0.7 // IoU threshold for non-maximum suppression
#define BBOX_CONF_THRESH 0.1 // minimum objectness * class score to keep a proposal

using namespace nvinfer1;

// stuff we know about the network and the input/output blobs
static const int INPUT_W = 1088; // network input width (letterbox canvas)
static const int INPUT_H = 608;  // network input height (letterbox canvas)
const char* INPUT_BLOB_NAME = "input_0";   // TensorRT input binding name
const char* OUTPUT_BLOB_NAME = "output_0"; // TensorRT output binding name
static Logger gLogger; // TensorRT logger (declared in logging.h)
  33. Mat static_resize(Mat& img) {
  34. float r = min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
  35. // r = std::min(r, 1.0f);
  36. int unpad_w = r * img.cols;
  37. int unpad_h = r * img.rows;
  38. Mat re(unpad_h, unpad_w, CV_8UC3);
  39. resize(img, re, re.size());
  40. Mat out(INPUT_H, INPUT_W, CV_8UC3, Scalar(114, 114, 114));
  41. re.copyTo(out(Rect(0, 0, re.cols, re.rows)));
  42. return out;
  43. }
// One anchor location of the YOLOX head: cell coordinates within the
// feature map plus the stride of that feature level (8, 16 or 32).
struct GridAndStride
{
    int grid0;  // cell x index
    int grid1;  // cell y index
    int stride; // downsampling factor of this feature level
};
  50. static void generate_grids_and_stride(const int target_w, const int target_h, vector<int>& strides, vector<GridAndStride>& grid_strides)
  51. {
  52. for (auto stride : strides)
  53. {
  54. int num_grid_w = target_w / stride;
  55. int num_grid_h = target_h / stride;
  56. for (int g1 = 0; g1 < num_grid_h; g1++)
  57. {
  58. for (int g0 = 0; g0 < num_grid_w; g0++)
  59. {
  60. grid_strides.push_back((GridAndStride){g0, g1, stride});
  61. }
  62. }
  63. }
  64. }
// Area of overlap between two detections' boxes; 0 when they are disjoint
// (cv::Rect's `&` operator yields an empty rect in that case).
static inline float intersection_area(const Object& a, const Object& b)
{
    Rect_<float> inter = a.rect & b.rect;
    return inter.area();
}
// In-place quicksort of faceobjects[left..right] by descending confidence
// (`prob`). Hoare partition around the middle element; the two halves
// recurse in parallel OpenMP sections when OpenMP is enabled (the pragmas
// are no-ops otherwise).
static void qsort_descent_inplace(vector<Object>& faceobjects, int left, int right)
{
    int i = left;
    int j = right;
    float p = faceobjects[(left + right) / 2].prob; // pivot confidence
    while (i <= j)
    {
        // skip elements already on the correct side of the pivot
        while (faceobjects[i].prob > p)
            i++;
        while (faceobjects[j].prob < p)
            j--;
        if (i <= j)
        {
            // swap
            swap(faceobjects[i], faceobjects[j]);
            i++;
            j--;
        }
    }
    #pragma omp parallel sections
    {
        #pragma omp section
        {
            if (left < j) qsort_descent_inplace(faceobjects, left, j);
        }
        #pragma omp section
        {
            if (i < right) qsort_descent_inplace(faceobjects, i, right);
        }
    }
}
  101. static void qsort_descent_inplace(vector<Object>& objects)
  102. {
  103. if (objects.empty())
  104. return;
  105. qsort_descent_inplace(objects, 0, objects.size() - 1);
  106. }
  107. static void nms_sorted_bboxes(const vector<Object>& faceobjects, vector<int>& picked, float nms_threshold)
  108. {
  109. picked.clear();
  110. const int n = faceobjects.size();
  111. vector<float> areas(n);
  112. for (int i = 0; i < n; i++)
  113. {
  114. areas[i] = faceobjects[i].rect.area();
  115. }
  116. for (int i = 0; i < n; i++)
  117. {
  118. const Object& a = faceobjects[i];
  119. int keep = 1;
  120. for (int j = 0; j < (int)picked.size(); j++)
  121. {
  122. const Object& b = faceobjects[picked[j]];
  123. // intersection over union
  124. float inter_area = intersection_area(a, b);
  125. float union_area = areas[i] + areas[picked[j]] - inter_area;
  126. // float IoU = inter_area / union_area
  127. if (inter_area / union_area > nms_threshold)
  128. keep = 0;
  129. }
  130. if (keep)
  131. picked.push_back(i);
  132. }
  133. }
  134. static void generate_yolox_proposals(vector<GridAndStride> grid_strides, float* feat_blob, float prob_threshold, vector<Object>& objects)
  135. {
  136. const int num_class = 1;
  137. const int num_anchors = grid_strides.size();
  138. for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
  139. {
  140. const int grid0 = grid_strides[anchor_idx].grid0;
  141. const int grid1 = grid_strides[anchor_idx].grid1;
  142. const int stride = grid_strides[anchor_idx].stride;
  143. const int basic_pos = anchor_idx * (num_class + 5);
  144. // yolox/models/yolo_head.py decode logic
  145. float x_center = (feat_blob[basic_pos+0] + grid0) * stride;
  146. float y_center = (feat_blob[basic_pos+1] + grid1) * stride;
  147. float w = exp(feat_blob[basic_pos+2]) * stride;
  148. float h = exp(feat_blob[basic_pos+3]) * stride;
  149. float x0 = x_center - w * 0.5f;
  150. float y0 = y_center - h * 0.5f;
  151. float box_objectness = feat_blob[basic_pos+4];
  152. for (int class_idx = 0; class_idx < num_class; class_idx++)
  153. {
  154. float box_cls_score = feat_blob[basic_pos + 5 + class_idx];
  155. float box_prob = box_objectness * box_cls_score;
  156. if (box_prob > prob_threshold)
  157. {
  158. Object obj;
  159. obj.rect.x = x0;
  160. obj.rect.y = y0;
  161. obj.rect.width = w;
  162. obj.rect.height = h;
  163. obj.label = class_idx;
  164. obj.prob = box_prob;
  165. objects.push_back(obj);
  166. }
  167. } // class loop
  168. } // point anchor loop
  169. }
  170. float* blobFromImage(Mat& img){
  171. cvtColor(img, img, COLOR_BGR2RGB);
  172. float* blob = new float[img.total()*3];
  173. int channels = 3;
  174. int img_h = img.rows;
  175. int img_w = img.cols;
  176. vector<float> mean = {0.485, 0.456, 0.406};
  177. vector<float> std = {0.229, 0.224, 0.225};
  178. for (size_t c = 0; c < channels; c++)
  179. {
  180. for (size_t h = 0; h < img_h; h++)
  181. {
  182. for (size_t w = 0; w < img_w; w++)
  183. {
  184. blob[c * img_w * img_h + h * img_w + w] =
  185. (((float)img.at<Vec3b>(h, w)[c]) / 255.0f - mean[c]) / std[c];
  186. }
  187. }
  188. }
  189. return blob;
  190. }
  191. static void decode_outputs(float* prob, vector<Object>& objects, float scale, const int img_w, const int img_h) {
  192. vector<Object> proposals;
  193. vector<int> strides = {8, 16, 32};
  194. vector<GridAndStride> grid_strides;
  195. generate_grids_and_stride(INPUT_W, INPUT_H, strides, grid_strides);
  196. generate_yolox_proposals(grid_strides, prob, BBOX_CONF_THRESH, proposals);
  197. //std::cout << "num of boxes before nms: " << proposals.size() << std::endl;
  198. qsort_descent_inplace(proposals);
  199. vector<int> picked;
  200. nms_sorted_bboxes(proposals, picked, NMS_THRESH);
  201. int count = picked.size();
  202. //std::cout << "num of boxes: " << count << std::endl;
  203. objects.resize(count);
  204. for (int i = 0; i < count; i++)
  205. {
  206. objects[i] = proposals[picked[i]];
  207. // adjust offset to original unpadded
  208. float x0 = (objects[i].rect.x) / scale;
  209. float y0 = (objects[i].rect.y) / scale;
  210. float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
  211. float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;
  212. // clip
  213. // x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
  214. // y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
  215. // x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
  216. // y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);
  217. objects[i].rect.x = x0;
  218. objects[i].rect.y = y0;
  219. objects[i].rect.width = x1 - x0;
  220. objects[i].rect.height = y1 - y0;
  221. }
  222. }
// Fixed palette of 80 RGB triplets (values in [0,1]) used to color detections;
// matches the standard 80-entry palette used across YOLOX demos.
const float color_list[80][3] =
{
    {0.000, 0.447, 0.741},
    {0.850, 0.325, 0.098},
    {0.929, 0.694, 0.125},
    {0.494, 0.184, 0.556},
    {0.466, 0.674, 0.188},
    {0.301, 0.745, 0.933},
    {0.635, 0.078, 0.184},
    {0.300, 0.300, 0.300},
    {0.600, 0.600, 0.600},
    {1.000, 0.000, 0.000},
    {1.000, 0.500, 0.000},
    {0.749, 0.749, 0.000},
    {0.000, 1.000, 0.000},
    {0.000, 0.000, 1.000},
    {0.667, 0.000, 1.000},
    {0.333, 0.333, 0.000},
    {0.333, 0.667, 0.000},
    {0.333, 1.000, 0.000},
    {0.667, 0.333, 0.000},
    {0.667, 0.667, 0.000},
    {0.667, 1.000, 0.000},
    {1.000, 0.333, 0.000},
    {1.000, 0.667, 0.000},
    {1.000, 1.000, 0.000},
    {0.000, 0.333, 0.500},
    {0.000, 0.667, 0.500},
    {0.000, 1.000, 0.500},
    {0.333, 0.000, 0.500},
    {0.333, 0.333, 0.500},
    {0.333, 0.667, 0.500},
    {0.333, 1.000, 0.500},
    {0.667, 0.000, 0.500},
    {0.667, 0.333, 0.500},
    {0.667, 0.667, 0.500},
    {0.667, 1.000, 0.500},
    {1.000, 0.000, 0.500},
    {1.000, 0.333, 0.500},
    {1.000, 0.667, 0.500},
    {1.000, 1.000, 0.500},
    {0.000, 0.333, 1.000},
    {0.000, 0.667, 1.000},
    {0.000, 1.000, 1.000},
    {0.333, 0.000, 1.000},
    {0.333, 0.333, 1.000},
    {0.333, 0.667, 1.000},
    {0.333, 1.000, 1.000},
    {0.667, 0.000, 1.000},
    {0.667, 0.333, 1.000},
    {0.667, 0.667, 1.000},
    {0.667, 1.000, 1.000},
    {1.000, 0.000, 1.000},
    {1.000, 0.333, 1.000},
    {1.000, 0.667, 1.000},
    {0.333, 0.000, 0.000},
    {0.500, 0.000, 0.000},
    {0.667, 0.000, 0.000},
    {0.833, 0.000, 0.000},
    {1.000, 0.000, 0.000},
    {0.000, 0.167, 0.000},
    {0.000, 0.333, 0.000},
    {0.000, 0.500, 0.000},
    {0.000, 0.667, 0.000},
    {0.000, 0.833, 0.000},
    {0.000, 1.000, 0.000},
    {0.000, 0.000, 0.167},
    {0.000, 0.000, 0.333},
    {0.000, 0.000, 0.500},
    {0.000, 0.000, 0.667},
    {0.000, 0.000, 0.833},
    {0.000, 0.000, 1.000},
    {0.000, 0.000, 0.000},
    {0.143, 0.143, 0.143},
    {0.286, 0.286, 0.286},
    {0.429, 0.429, 0.429},
    {0.571, 0.571, 0.571},
    {0.714, 0.714, 0.714},
    {0.857, 0.857, 0.857},
    {0.000, 0.447, 0.741},
    {0.314, 0.717, 0.741},
    {0.50, 0.5, 0}
};
  306. void doInference(IExecutionContext& context, float* input, float* output, const int output_size, Size input_shape) {
  307. const ICudaEngine& engine = context.getEngine();
  308. // Pointers to input and output device buffers to pass to engine.
  309. // Engine requires exactly IEngine::getNbBindings() number of buffers.
  310. assert(engine.getNbBindings() == 2);
  311. void* buffers[2];
  312. // In order to bind the buffers, we need to know the names of the input and output tensors.
  313. // Note that indices are guaranteed to be less than IEngine::getNbBindings()
  314. const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
  315. assert(engine.getBindingDataType(inputIndex) == nvinfer1::DataType::kFLOAT);
  316. const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
  317. assert(engine.getBindingDataType(outputIndex) == nvinfer1::DataType::kFLOAT);
  318. int mBatchSize = engine.getMaxBatchSize();
  319. // Create GPU buffers on device
  320. CHECK(cudaMalloc(&buffers[inputIndex], 3 * input_shape.height * input_shape.width * sizeof(float)));
  321. CHECK(cudaMalloc(&buffers[outputIndex], output_size*sizeof(float)));
  322. // Create stream
  323. cudaStream_t stream;
  324. CHECK(cudaStreamCreate(&stream));
  325. // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
  326. CHECK(cudaMemcpyAsync(buffers[inputIndex], input, 3 * input_shape.height * input_shape.width * sizeof(float), cudaMemcpyHostToDevice, stream));
  327. context.enqueue(1, buffers, stream, nullptr);
  328. CHECK(cudaMemcpyAsync(output, buffers[outputIndex], output_size * sizeof(float), cudaMemcpyDeviceToHost, stream));
  329. cudaStreamSynchronize(stream);
  330. // Release stream and buffers
  331. cudaStreamDestroy(stream);
  332. CHECK(cudaFree(buffers[inputIndex]));
  333. CHECK(cudaFree(buffers[outputIndex]));
  334. }
  335. int main(int argc, char** argv) {
  336. cudaSetDevice(DEVICE);
  337. // create a model using the API directly and serialize it to a stream
  338. char *trtModelStream{nullptr};
  339. size_t size{0};
  340. if (argc == 4 && string(argv[2]) == "-i") {
  341. const string engine_file_path {argv[1]};
  342. ifstream file(engine_file_path, ios::binary);
  343. if (file.good()) {
  344. file.seekg(0, file.end);
  345. size = file.tellg();
  346. file.seekg(0, file.beg);
  347. trtModelStream = new char[size];
  348. assert(trtModelStream);
  349. file.read(trtModelStream, size);
  350. file.close();
  351. }
  352. } else {
  353. cerr << "arguments not right!" << endl;
  354. cerr << "run 'python3 tools/trt.py -f exps/example/mot/yolox_s_mix_det.py -c pretrained/bytetrack_s_mot17.pth.tar' to serialize model first!" << std::endl;
  355. cerr << "Then use the following command:" << endl;
  356. cerr << "cd demo/TensorRT/cpp/build" << endl;
  357. cerr << "./bytetrack ../../../../YOLOX_outputs/yolox_s_mix_det/model_trt.engine -i ../../../../videos/palace.mp4 // deserialize file and run inference" << std::endl;
  358. return -1;
  359. }
  360. const string input_video_path {argv[3]};
  361. IRuntime* runtime = createInferRuntime(gLogger);
  362. assert(runtime != nullptr);
  363. ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
  364. assert(engine != nullptr);
  365. IExecutionContext* context = engine->createExecutionContext();
  366. assert(context != nullptr);
  367. delete[] trtModelStream;
  368. auto out_dims = engine->getBindingDimensions(1);
  369. auto output_size = 1;
  370. for(int j=0;j<out_dims.nbDims;j++) {
  371. output_size *= out_dims.d[j];
  372. }
  373. static float* prob = new float[output_size];
  374. VideoCapture cap(input_video_path);
  375. if (!cap.isOpened())
  376. return 0;
  377. int img_w = cap.get(CV_CAP_PROP_FRAME_WIDTH);
  378. int img_h = cap.get(CV_CAP_PROP_FRAME_HEIGHT);
  379. int fps = cap.get(CV_CAP_PROP_FPS);
  380. long nFrame = static_cast<long>(cap.get(CV_CAP_PROP_FRAME_COUNT));
  381. cout << "Total frames: " << nFrame << endl;
  382. VideoWriter writer("demo.mp4", CV_FOURCC('m', 'p', '4', 'v'), fps, Size(img_w, img_h));
  383. Mat img;
  384. BYTETracker tracker(fps, 30);
  385. int num_frames = 0;
  386. int total_ms = 0;
  387. while (true)
  388. {
  389. if(!cap.read(img))
  390. break;
  391. num_frames ++;
  392. if (num_frames % 20 == 0)
  393. {
  394. cout << "Processing frame " << num_frames << " (" << num_frames * 1000000 / total_ms << " fps)" << endl;
  395. }
  396. if (img.empty())
  397. break;
  398. Mat pr_img = static_resize(img);
  399. float* blob;
  400. blob = blobFromImage(pr_img);
  401. float scale = min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
  402. // run inference
  403. auto start = chrono::system_clock::now();
  404. doInference(*context, blob, prob, output_size, pr_img.size());
  405. vector<Object> objects;
  406. decode_outputs(prob, objects, scale, img_w, img_h);
  407. vector<STrack> output_stracks = tracker.update(objects);
  408. auto end = chrono::system_clock::now();
  409. total_ms = total_ms + chrono::duration_cast<chrono::microseconds>(end - start).count();
  410. for (int i = 0; i < output_stracks.size(); i++)
  411. {
  412. vector<float> tlwh = output_stracks[i].tlwh;
  413. bool vertical = tlwh[2] / tlwh[3] > 1.6;
  414. if (tlwh[2] * tlwh[3] > 20 && !vertical)
  415. {
  416. Scalar s = tracker.get_color(output_stracks[i].track_id);
  417. putText(img, format("%d", output_stracks[i].track_id), Point(tlwh[0], tlwh[1] - 5),
  418. 0, 0.6, Scalar(0, 0, 255), 2, LINE_AA);
  419. rectangle(img, Rect(tlwh[0], tlwh[1], tlwh[2], tlwh[3]), s, 2);
  420. }
  421. }
  422. putText(img, format("frame: %d fps: %d num: %d", num_frames, num_frames * 1000000 / total_ms, output_stracks.size()),
  423. Point(0, 30), 0, 0.6, Scalar(0, 0, 255), 2, LINE_AA);
  424. writer.write(img);
  425. delete blob;
  426. char c = waitKey(1);
  427. if (c > 0)
  428. {
  429. break;
  430. }
  431. }
  432. cap.release();
  433. cout << "FPS: " << num_frames * 1000000 / total_ms << endl;
  434. // destroy the engine
  435. context->destroy();
  436. engine->destroy();
  437. runtime->destroy();
  438. return 0;
  439. }