Meta Byte Track

onnx_inference.py 5.6KB
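ByteTrack ONNX demo: runs a YOLOX detector exported to ONNX through onnxruntime, associates the detections across frames with BYTETracker, and writes the annotated tracking video to the output directory.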

import argparse
import os

import cv2
import numpy as np
from loguru import logger
import onnxruntime

from yolox.data.data_augment import preproc as preprocess
from yolox.utils import multiclass_nms, demo_postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
def make_parser():
    parser = argparse.ArgumentParser("onnxruntime inference sample")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="../../bytetrack_s.onnx",
        help="Input your onnx model.",
    )
    parser.add_argument(
        "-i",
        "--video_path",
        type=str,
        default="../../videos/palace.mp4",
        help="Path to your input video.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default="demo_output",
        help="Path to your output directory.",
    )
    parser.add_argument(
        "-s",
        "--score_thr",
        type=float,
        default=0.1,
        help="Score threshold to filter the result.",
    )
    parser.add_argument(
        "-n",
        "--nms_thr",
        type=float,
        default=0.7,
        help="NMS threshold.",
    )
    parser.add_argument(
        "--input_shape",
        type=str,
        default="608,1088",
        help="Specify an input shape for inference.",
    )
    parser.add_argument(
        "--with_p6",
        action="store_true",
        help="Whether your model uses p6 in FPN/PAN.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="number of frames to keep lost tracks alive")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument("--min-box-area", type=float, default=10, help="filter out tiny boxes")
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
class Predictor(object):
    def __init__(self, args):
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.args = args
        self.session = onnxruntime.InferenceSession(args.model)
        self.input_shape = tuple(map(int, args.input_shape.split(",")))

    def inference(self, ori_img, timer):
        img_info = {"id": 0}
        height, width = ori_img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = ori_img

        img, ratio = preprocess(ori_img, self.input_shape, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        ort_inputs = {self.session.get_inputs()[0].name: img[None, :, :, :]}
        timer.tic()
        output = self.session.run(None, ort_inputs)
        predictions = demo_postprocess(output[0], self.input_shape, p6=self.args.with_p6)[0]

        boxes = predictions[:, :4]
        # Objectness score times per-class score.
        scores = predictions[:, 4:5] * predictions[:, 5:]

        # Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2) and undo the preprocessing resize.
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.args.nms_thr, score_thr=self.args.score_thr)
        if dets is None:
            # multiclass_nms returns None when no box passes the thresholds.
            dets = np.zeros((0, 6), dtype=np.float32)
        # Drop the class column; the tracker expects (x1, y1, x2, y2, score).
        return dets[:, :-1], img_info
def imageflow_demo(predictor, args):
    cap = cv2.VideoCapture(args.video_path)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = args.output_dir
    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, os.path.basename(args.video_path))
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    # BYTETracker reads the tracking flags (track_thresh, track_buffer,
    # match_thresh, mot20) off the same args namespace as the detector flags.
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []  # per-frame results kept in memory; this demo only writes the video
    while True:
        if frame_id % 20 == 0:
            logger.info("Processing frame {} ({:.2f} fps)".format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if not ret_val:
            break
        outputs, img_info = predictor.inference(frame, timer)
        online_targets = tracker.update(
            outputs, [img_info["height"], img_info["width"]], [img_info["height"], img_info["width"]]
        )
        online_tlwhs = []
        online_ids = []
        online_scores = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # Filter tiny boxes and boxes with aspect ratio w/h > 1.6.
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(t.score)
        timer.toc()
        results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
        online_im = plot_tracking(
            img_info["raw_img"], online_tlwhs, online_ids,
            frame_id=frame_id + 1, fps=1. / timer.average_time
        )
        vid_writer.write(online_im)
        ch = cv2.waitKey(1)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
        frame_id += 1
    cap.release()
    vid_writer.release()
if __name__ == "__main__":
    args = make_parser().parse_args()
    predictor = Predictor(args)
    imageflow_demo(predictor, args)
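Example invocation (a minimal sketch: the paths below are just the parser defaults above, and it assumes the ByteTrack yolox package is importable from the working directory):

    python3 onnx_inference.py -m ../../bytetrack_s.onnx -i ../../videos/palace.mp4 -o demo_output -s 0.1 -n 0.7 --input_shape 608,1088

Add --with_p6 only if the exported model uses a P6 head; the annotated video is written to demo_output/ under the input video's filename.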