Meta Byte Track

demo_track.py 13KB

import argparse
import os
import time

import cv2
import torch
from loguru import logger

from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
from yolox.utils import fuse_model, get_model_info, postprocess, vis
from yolox.utils.visualize import plot_tracking

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, e.g. image, video or webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        # "--path", default="./datasets/mot/train/MOT17-05-FRCNN/img1", help="path to images or video"
        "--path", default="./videos/palace.mp4", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mixed-precision evaluation.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Use a TensorRT model for testing.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="number of frames to keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


def write_results(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1), s=round(score, 2))
                f.write(line)
    logger.info('save results to {}'.format(filename))
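

# --- Illustrative sketch (not part of the original script) ------------------
# write_results() emits one line per track in the MOT Challenge text format:
#   frame,id,x1,y1,w,h,score,-1,-1,-1
# The values below are made up for demonstration; the output path is a
# placeholder.
def _example_write_results():
    results = [(1, [(100.0, 200.0, 50.0, 120.0)], [7], [0.91])]
    write_results("demo_output.txt", results)
    # demo_output.txt then contains: 1,7,100.0,200.0,50.0,120.0,0.91,-1,-1,-1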


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False
    ):
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        if trt_file is not None:
            from torch2trt import TRTModule
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))
            # run one dummy forward pass, then swap in the TensorRT module
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None
        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            timer.tic()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
            # logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info
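

# --- Illustrative sketch (not part of the original script) ------------------
# One-off inference on a single image, assuming `model` and `exp` were built
# as in main() below; the image path is a placeholder.
def _example_single_inference(model, exp):
    predictor = Predictor(model, exp, device="gpu", fp16=False)
    timer = Timer()
    outputs, img_info = predictor.inference("./videos/frame0001.jpg", timer)
    timer.toc()
    return outputs, img_info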


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    # NOTE: relies on the module-level `args` and `exp` set in __main__ below
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    for image_name in files:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        outputs, img_info = predictor.inference(image_name, timer)
        if outputs[0] is not None:
            online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                vertical = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
            # save results
            results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
            timer.toc()
            online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                      fps=1. / timer.average_time)
        else:
            timer.toc()
            online_im = img_info['raw_img']
        # result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            cv2.imwrite(save_file_name, online_im)
        ch = cv2.waitKey(0)
        frame_id += 1
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
    # write_results(result_filename, results)


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > 1.6
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
                timer.toc()
                online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                          fps=1. / timer.average_time)
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break
        frame_id += 1


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    # define the path unconditionally so the demo functions can reference it
    vis_folder = os.path.join(file_name, "track_vis")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("checkpoint loaded.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()  # to FP16

    if args.trt:
        assert not args.fuse, "TensorRT model does not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT for inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)
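

# --- Illustrative sketch (not part of the original script) ------------------
# Running the demo programmatically instead of from the shell. The exp-file
# and checkpoint paths are placeholders that follow the ByteTrack repository
# layout; substitute your own files.
def _example_programmatic_run():
    global args, exp  # image_demo/imageflow_demo read these module-level names
    args = make_parser().parse_args([
        "video",
        "-f", "exps/example/mot/yolox_x_mix_det.py",
        "-c", "pretrained/bytetrack_x_mot17.pth.tar",
        "--fp16", "--fuse", "--save_result",
    ])
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)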