Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

track.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from loguru import logger
  2. import torch
  3. import torch.backends.cudnn as cudnn
  4. from torch.nn.parallel import DistributedDataParallel as DDP
  5. from yolox.core import launch
  6. from yolox.exp import get_exp
  7. from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
  8. from yolox.evaluators import MOTEvaluator
  9. import argparse
  10. import os
  11. import random
  12. import warnings
  13. import glob
  14. import motmetrics as mm
  15. from collections import OrderedDict
  16. from pathlib import Path
  17. import learn2learn as l2l
  18. import yolox.statics as statics
  19. def make_parser():
  20. parser = argparse.ArgumentParser("YOLOX Eval")
  21. parser.add_argument("-t", "--task", type=str, default="metamot")
  22. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  23. parser.add_argument("-n", "--name", type=str, default=None, help="model name")
  24. parser.add_argument(
  25. "--adaptation_period", default=4, type=int, help="if 4, then adapts to one batch in four batches"
  26. )
  27. # distributed
  28. parser.add_argument(
  29. "--dist-backend", default="nccl", type=str, help="distributed backend"
  30. )
  31. parser.add_argument(
  32. "--dist-url",
  33. default=None,
  34. type=str,
  35. help="url used to set up distributed training",
  36. )
  37. parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
  38. parser.add_argument(
  39. "-d", "--devices", default=None, type=int, help="device for training"
  40. )
  41. parser.add_argument(
  42. "--local_rank", default=0, type=int, help="local rank for dist training"
  43. )
  44. parser.add_argument(
  45. "--num_machines", default=1, type=int, help="num of node for training"
  46. )
  47. parser.add_argument(
  48. "--machine_rank", default=0, type=int, help="node rank for multi-node training"
  49. )
  50. parser.add_argument(
  51. "-f",
  52. "--exp_file",
  53. default=None,
  54. type=str,
  55. help="pls input your expriment description file",
  56. )
  57. parser.add_argument(
  58. "--fp16",
  59. dest="fp16",
  60. default=False,
  61. action="store_true",
  62. help="Adopting mix precision evaluating.",
  63. )
  64. parser.add_argument(
  65. "--fuse",
  66. dest="fuse",
  67. default=False,
  68. action="store_true",
  69. help="Fuse conv and bn for testing.",
  70. )
  71. parser.add_argument(
  72. "--trt",
  73. dest="trt",
  74. default=False,
  75. action="store_true",
  76. help="Using TensorRT model for testing.",
  77. )
  78. parser.add_argument(
  79. "--test",
  80. dest="test",
  81. default=False,
  82. action="store_true",
  83. help="Evaluating on test-dev set.",
  84. )
  85. parser.add_argument(
  86. "--speed",
  87. dest="speed",
  88. default=False,
  89. action="store_true",
  90. help="speed test only.",
  91. )
  92. parser.add_argument(
  93. "opts",
  94. help="Modify config options using the command-line",
  95. default=None,
  96. nargs=argparse.REMAINDER,
  97. )
  98. # det args
  99. parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
  100. parser.add_argument("--conf", default=0.01, type=float, help="test conf")
  101. parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
  102. parser.add_argument("--tsize", default=None, type=int, help="test img size")
  103. parser.add_argument("--seed", default=None, type=int, help="eval seed")
  104. # tracking args
  105. parser.add_argument("--track_thresh", type=float, default=0.6, help="tracking confidence threshold")
  106. parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
  107. parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
  108. parser.add_argument("--min-box-area", type=float, default=100, help='filter out tiny boxes')
  109. parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
  110. parser.add_argument("--use_existing_files", default=False, action="store_true", help="to use already created files")
  111. return parser
  112. def compare_dataframes(gts, ts):
  113. accs = []
  114. names = []
  115. for k, tsacc in ts.items():
  116. if k in gts:
  117. logger.info('Comparing {}...'.format(k))
  118. accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
  119. names.append(k)
  120. else:
  121. logger.warning('No ground truth for {}, skipping.'.format(k))
  122. return accs, names
  123. def process_loader(args, exp, val_loader, model, is_distributed, trt_file, decoder):
  124. file_name = os.path.join(exp.output_dir, args.experiment_name)
  125. rank = args.local_rank
  126. if rank == 0:
  127. os.makedirs(file_name, exist_ok=True)
  128. results_folder = os.path.join(file_name, "track_results")
  129. os.makedirs(results_folder, exist_ok=True)
  130. adaptation_period = None
  131. if args.task == 'metamot':
  132. adaptation_period = args.adaptation_period
  133. evaluator = MOTEvaluator(
  134. args=args,
  135. dataloader=val_loader,
  136. img_size=exp.test_size,
  137. confthre=exp.test_conf,
  138. nmsthre=exp.nmsthre,
  139. num_classes=exp.num_classes,
  140. )
  141. # start evaluate
  142. *_, summary = evaluator.evaluate(
  143. model, is_distributed, args.fp16, trt_file, decoder, exp.test_size, results_folder,
  144. adaptation_period=adaptation_period, eval_det=False
  145. )
  146. logger.info("\n" + summary)
  147. def eval_MOT(args, exp, val_ann=None):
  148. file_name = os.path.join(exp.output_dir, args.experiment_name)
  149. rank = args.local_rank
  150. if rank == 0:
  151. os.makedirs(file_name, exist_ok=True)
  152. results_folder = os.path.join(file_name, "track_results")
  153. os.makedirs(results_folder, exist_ok=True)
  154. # evaluate MOTA
  155. mm.lap.default_solver = 'lap'
  156. if val_ann == 'val_half.json':
  157. gt_type = '_val_half'
  158. else:
  159. gt_type = ''
  160. print('gt_type', gt_type)
  161. if args.mot20:
  162. gtfiles = glob.glob(os.path.join(statics.DATA_PATH, 'MOT20/train', '*/gt/gt{}.txt'.format(gt_type)))
  163. else:
  164. gtfiles = glob.glob(os.path.join(statics.DATA_PATH, 'MOT17/train', '*/gt/gt{}.txt'.format(gt_type)))
  165. print('gt_files', gtfiles)
  166. tsfiles = [f for f in glob.glob(os.path.join(results_folder, '*.txt')) if
  167. not os.path.basename(f).startswith('eval')]
  168. logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
  169. logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
  170. logger.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
  171. logger.info('Loading files.')
  172. gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=1)) for f in gtfiles])
  173. ts = OrderedDict(
  174. [(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=-1)) for f in
  175. tsfiles])
  176. mh = mm.metrics.create()
  177. accs, names = compare_dataframes(gt, ts)
  178. logger.info('Running metrics')
  179. metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
  180. 'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
  181. 'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
  182. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  183. # summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
  184. # print(mm.io.render_summary(
  185. # summary, formatters=mh.formatters,
  186. # namemap=mm.io.motchallenge_metric_names))
  187. div_dict = {
  188. 'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
  189. 'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
  190. for divisor in div_dict:
  191. for divided in div_dict[divisor]:
  192. summary[divided] = (summary[divided] / summary[divisor])
  193. fmt = mh.formatters
  194. change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations', 'mostly_tracked',
  195. 'partially_tracked', 'mostly_lost']
  196. for k in change_fmt_list:
  197. fmt[k] = fmt['mota']
  198. print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))
  199. metrics = mm.metrics.motchallenge_metrics + ['num_objects']
  200. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  201. print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
  202. logger.info('Completed')
  203. def load_model(args, exp, is_distributed):
  204. model = exp.get_model()
  205. logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
  206. if args.seed is not None:
  207. random.seed(args.seed)
  208. torch.manual_seed(args.seed)
  209. cudnn.deterministic = True
  210. warnings.warn(
  211. "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
  212. )
  213. # set environment variables for distributed training
  214. cudnn.benchmark = True
  215. rank = args.local_rank
  216. # rank = get_local_rank()
  217. file_name = os.path.join(exp.output_dir, args.experiment_name)
  218. if rank == 0:
  219. os.makedirs(file_name, exist_ok=True)
  220. setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
  221. logger.info("Args: {}".format(args))
  222. if args.conf is not None:
  223. exp.test_conf = args.conf
  224. if args.nms is not None:
  225. exp.nmsthre = args.nms
  226. if args.tsize is not None:
  227. exp.test_size = (args.tsize, args.tsize)
  228. if args.task == "metamot":
  229. model = l2l.algorithms.MAML(model, lr=exp.inner_lr, first_order=exp.first_order)
  230. torch.cuda.set_device(rank)
  231. model.cuda(rank)
  232. model.eval()
  233. if not args.speed and not args.trt:
  234. if args.ckpt is None:
  235. ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
  236. else:
  237. ckpt_file = args.ckpt
  238. logger.info("loading checkpoint")
  239. loc = "cuda:{}".format(rank)
  240. ckpt = torch.load(ckpt_file, map_location=loc)
  241. # handling meta models
  242. new_dict = {}
  243. if (not list(ckpt["model"].keys())[0].startswith('module')) and args.task == "metamot":
  244. for key in ckpt["model"].keys():
  245. if not key.startswith('module.'):
  246. new_dict['module.' + key] = ckpt["model"][key]
  247. else:
  248. new_dict[key] = ckpt["model"][key]
  249. del ckpt["model"]
  250. ckpt["model"] = new_dict
  251. # load the model state dict
  252. model.load_state_dict(ckpt["model"])
  253. logger.info("loaded checkpoint done.")
  254. if is_distributed:
  255. model = DDP(model, device_ids=[rank])
  256. if args.fuse:
  257. logger.info("\tFusing model...")
  258. model = fuse_model(model)
  259. if args.trt:
  260. assert (
  261. not args.fuse and not is_distributed and args.batch_size == 1
  262. ), "TensorRT model is not support model fusing and distributed inferencing!"
  263. trt_file = os.path.join(file_name, "model_trt.pth")
  264. assert os.path.exists(
  265. trt_file
  266. ), "TensorRT model is not found!\n Run tools/trt.py first!"
  267. model.head.decode_in_inference = False
  268. decoder = model.head.decode_outputs
  269. else:
  270. trt_file = None
  271. decoder = None
  272. return model, trt_file, decoder
  273. @logger.catch
  274. def main(exp, args, num_gpu):
  275. is_distributed = num_gpu > 1
  276. print('is_distributed', is_distributed)
  277. print('num_gpu', num_gpu)
  278. # logger.info("Model Structure:\n{}".format(str(model)))
  279. model, trt_file, decoder = load_model(args, exp, is_distributed)
  280. if args.task == 'metamot':
  281. val_loaders = exp.get_eval_loaders(args.batch_size, is_distributed, args.test)
  282. if not args.use_existing_files:
  283. for val_loader in val_loaders:
  284. logger.info('processing loader...')
  285. process_loader(args, exp, val_loader, model, is_distributed, trt_file, decoder)
  286. eval_MOT(args, exp)
  287. else:
  288. if not args.use_existing_files:
  289. val_loader = exp.get_eval_loader(args.batch_size, is_distributed, args.test)
  290. process_loader(args, exp, val_loader, model, is_distributed, trt_file, decoder)
  291. eval_MOT(args, exp, exp.val_ann)
  292. if __name__ == "__main__":
  293. args = make_parser().parse_args()
  294. exp = get_exp(args.exp_file, args.name)
  295. exp.merge(args.opts)
  296. if not args.experiment_name:
  297. args.experiment_name = exp.exp_name
  298. num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
  299. assert num_gpu <= torch.cuda.device_count()
  300. launch(
  301. main,
  302. num_gpu,
  303. args.num_machines,
  304. args.machine_rank,
  305. backend=args.dist_backend,
  306. dist_url=args.dist_url,
  307. args=(exp, args, num_gpu),
  308. )