Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

track.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. from loguru import logger
  2. import torch
  3. import torch.backends.cudnn as cudnn
  4. from torch.nn.parallel import DistributedDataParallel as DDP
  5. from yolox.core import launch
  6. from yolox.exp import get_exp
  7. from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
  8. from yolox.evaluators import MOTEvaluator
  9. import argparse
  10. import os
  11. import random
  12. import warnings
  13. import glob
  14. import motmetrics as mm
  15. from collections import OrderedDict
  16. from pathlib import Path
  17. def make_parser():
  18. parser = argparse.ArgumentParser("YOLOX Eval")
  19. parser.add_argument("-t", "--task", type=str, default="metamot")
  20. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  21. parser.add_argument("-n", "--name", type=str, default=None, help="model name")
  22. # distributed
  23. parser.add_argument(
  24. "--dist-backend", default="nccl", type=str, help="distributed backend"
  25. )
  26. parser.add_argument(
  27. "--dist-url",
  28. default=None,
  29. type=str,
  30. help="url used to set up distributed training",
  31. )
  32. parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
  33. parser.add_argument(
  34. "-d", "--devices", default=None, type=int, help="device for training"
  35. )
  36. parser.add_argument(
  37. "--local_rank", default=0, type=int, help="local rank for dist training"
  38. )
  39. parser.add_argument(
  40. "--num_machines", default=1, type=int, help="num of node for training"
  41. )
  42. parser.add_argument(
  43. "--machine_rank", default=0, type=int, help="node rank for multi-node training"
  44. )
  45. parser.add_argument(
  46. "-f",
  47. "--exp_file",
  48. default=None,
  49. type=str,
  50. help="pls input your expriment description file",
  51. )
  52. parser.add_argument(
  53. "--fp16",
  54. dest="fp16",
  55. default=False,
  56. action="store_true",
  57. help="Adopting mix precision evaluating.",
  58. )
  59. parser.add_argument(
  60. "--fuse",
  61. dest="fuse",
  62. default=False,
  63. action="store_true",
  64. help="Fuse conv and bn for testing.",
  65. )
  66. parser.add_argument(
  67. "--trt",
  68. dest="trt",
  69. default=False,
  70. action="store_true",
  71. help="Using TensorRT model for testing.",
  72. )
  73. parser.add_argument(
  74. "--test",
  75. dest="test",
  76. default=False,
  77. action="store_true",
  78. help="Evaluating on test-dev set.",
  79. )
  80. parser.add_argument(
  81. "--speed",
  82. dest="speed",
  83. default=False,
  84. action="store_true",
  85. help="speed test only.",
  86. )
  87. parser.add_argument(
  88. "opts",
  89. help="Modify config options using the command-line",
  90. default=None,
  91. nargs=argparse.REMAINDER,
  92. )
  93. # det args
  94. parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
  95. parser.add_argument("--conf", default=0.01, type=float, help="test conf")
  96. parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
  97. parser.add_argument("--tsize", default=None, type=int, help="test img size")
  98. parser.add_argument("--seed", default=None, type=int, help="eval seed")
  99. # tracking args
  100. parser.add_argument("--track_thresh", type=float, default=0.6, help="tracking confidence threshold")
  101. parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
  102. parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
  103. parser.add_argument("--min-box-area", type=float, default=100, help='filter out tiny boxes')
  104. parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
  105. return parser
  106. def compare_dataframes(gts, ts):
  107. accs = []
  108. names = []
  109. for k, tsacc in ts.items():
  110. if k in gts:
  111. logger.info('Comparing {}...'.format(k))
  112. accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
  113. names.append(k)
  114. else:
  115. logger.warning('No ground truth for {}, skipping.'.format(k))
  116. return accs, names
def process_loader(args, val_loader, model, is_distributed):
    """Evaluate one validation loader end-to-end and score the results with motmetrics.

    Loads a checkpoint (unless --speed/--trt), optionally fuses the model or
    wraps it in DDP, runs ``MOTEvaluator`` over ``val_loader``, writes tracking
    results under ``<exp.output_dir>/<experiment_name>/track_results``, and
    finally computes CLEAR-MOT / MOTChallenge metrics against the MOT(20)
    ground-truth files on disk.

    NOTE(review): this function reads the module-level ``exp`` (bound in the
    ``__main__`` block) instead of taking it as a parameter — confirm ``exp``
    exists in every process that calls this.

    Args:
        args: parsed namespace from ``make_parser()``.
        val_loader: dataloader for one evaluation split.
        model: detection model; moved to ``cuda(rank)`` and set to eval here.
        is_distributed: wrap ``model`` in DistributedDataParallel when True.
    """
    if args.seed is not None:
        # Seeded evaluation forces deterministic cuDNN kernels.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
        )
    # set environment variables for distributed training
    cudnn.benchmark = True
    rank = args.local_rank
    # rank = get_local_rank()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    if rank == 0:
        os.makedirs(file_name, exist_ok=True)
        # NOTE(review): results_folder is only bound when rank == 0 but is used
        # unconditionally below — non-zero ranks would hit a NameError at the
        # evaluate call; confirm only rank 0 reaches that point.
        results_folder = os.path.join(file_name, "track_results")
        os.makedirs(results_folder, exist_ok=True)
    setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
    logger.info("Args: {}".format(args))
    # Command-line overrides of the experiment's test-time settings.
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)
    evaluator = MOTEvaluator(
        args=args,
        dataloader=val_loader,
        img_size=exp.test_size,
        confthre=exp.test_conf,
        nmsthre=exp.nmsthre,
        num_classes=exp.num_classes,
    )
    torch.cuda.set_device(rank)
    model.cuda(rank)
    model.eval()
    # Load weights unless speed-testing or running a TensorRT engine.
    if not args.speed and not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        loc = "cuda:{}".format(rank)
        ckpt = torch.load(ckpt_file, map_location=loc)
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")
    if is_distributed:
        model = DDP(model, device_ids=[rank])
    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)
    if args.trt:
        # TensorRT path: box decoding happens outside the serialized engine.
        assert (
            not args.fuse and not is_distributed and args.batch_size == 1
        ), "TensorRT model is not support model fusing and distributed inferencing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
    else:
        trt_file = None
        decoder = None
    # start evaluate
    *_, summary = evaluator.evaluate(
        model, is_distributed, args.fp16, trt_file, decoder, exp.test_size, results_folder
    )
    logger.info("\n" + summary)
    # evaluate MOTA
    mm.lap.default_solver = 'lap'
    # Use the half-split ground-truth files when evaluating on val_half.
    if exp.val_ann == 'val_half.json':
        gt_type = '_val_half'
    else:
        gt_type = ''
    print('gt_type', gt_type)
    if args.mot20:
        gtfiles = glob.glob(os.path.join('datasets/MOT20/train', '*/gt/gt{}.txt'.format(gt_type)))
    else:
        gtfiles = glob.glob(os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type)))
    print('gt_files', gtfiles)
    # Result files written by the evaluator; skip any previous 'eval*' summaries.
    tsfiles = [f for f in glob.glob(os.path.join(results_folder, '*.txt')) if
               not os.path.basename(f).startswith('eval')]
    logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
    logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
    logger.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
    logger.info('Loading files.')
    # Keyed by sequence name: third-from-last path part for GT files,
    # file stem for result files.
    gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=1)) for f in gtfiles])
    ts = OrderedDict(
        [(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=-1)) for f in
         tsfiles])
    mh = mm.metrics.create()
    accs, names = compare_dataframes(gt, ts)
    logger.info('Running metrics')
    metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
               'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
               'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
    summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
    # summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
    # print(mm.io.render_summary(
    #     summary, formatters=mh.formatters,
    #     namemap=mm.io.motchallenge_metric_names))
    # Convert raw counts into ratios of their natural denominator so they
    # render as percentages below.
    div_dict = {
        'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
        'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
    for divisor in div_dict:
        for divided in div_dict[divisor]:
            summary[divided] = (summary[divided] / summary[divisor])
    fmt = mh.formatters
    change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations', 'mostly_tracked',
                       'partially_tracked', 'mostly_lost']
    # Reuse the 'mota' percentage formatter for the ratio columns created above.
    for k in change_fmt_list:
        fmt[k] = fmt['mota']
    print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))
    # Second pass: the standard MOTChallenge metric set, rendered with defaults.
    metrics = mm.metrics.motchallenge_metrics + ['num_objects']
    summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
    print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
    logger.info('Completed')
  236. @logger.catch
  237. def main(exp, args, num_gpu):
  238. is_distributed = num_gpu > 1
  239. print('is_distributed', is_distributed)
  240. print('num_gpu', num_gpu)
  241. model = exp.get_model()
  242. logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
  243. # logger.info("Model Structure:\n{}".format(str(model)))
  244. if args.task == 'metamot':
  245. val_loaders = exp.get_eval_loaders(args.batch_size, is_distributed, args.test)
  246. for val_loader in val_loaders:
  247. learner = model.clone()
  248. process_loader(args, val_loader, learner, is_distributed)
  249. else:
  250. val_loader = exp.get_eval_loader(args.batch_size, is_distributed, args.test)
  251. process_loader(args, val_loader, model, is_distributed)
  252. if __name__ == "__main__":
  253. args = make_parser().parse_args()
  254. exp = get_exp(args.exp_file, args.name)
  255. exp.merge(args.opts)
  256. if not args.experiment_name:
  257. args.experiment_name = exp.exp_name
  258. num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
  259. assert num_gpu <= torch.cuda.device_count()
  260. launch(
  261. main,
  262. num_gpu,
  263. args.num_machines,
  264. args.machine_rank,
  265. backend=args.dist_backend,
  266. dist_url=args.dist_url,
  267. args=(exp, args, num_gpu),
  268. )