Meta Byte Track

eval_motchallenge.py (5.1 KB)

"""py-motmetrics - metrics for multiple object tracker (MOT) benchmarking.

Christoph Heindl, 2017
https://github.com/cheind/py-motmetrics

Modified by Rufeng Zhang
"""
import argparse
import glob
import logging
import os
from collections import OrderedDict
from pathlib import Path

import motmetrics as mm
import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description="""
Compute metrics for trackers using MOTChallenge ground-truth data.

Files
-----
All file content, ground truth and test files, have to comply with the
format described in

Milan, Anton, et al.
"MOT16: A benchmark for multi-object tracking."
arXiv preprint arXiv:1603.00831 (2016).
https://motchallenge.net/

Structure
---------
Layout for ground truth data
    <GT_ROOT>/<SEQUENCE_1>/gt/gt.txt
    <GT_ROOT>/<SEQUENCE_2>/gt/gt.txt
    ...

Layout for test data
    <TEST_ROOT>/<SEQUENCE_1>.txt
    <TEST_ROOT>/<SEQUENCE_2>.txt
    ...

Sequences of ground truth and test will be matched according to the `<SEQUENCE_X>`
string.""", formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--groundtruths', type=str, help='Directory containing ground truth files.')
    parser.add_argument('--tests', type=str, help='Directory containing tracker result files.')
    parser.add_argument('--score_threshold', type=float, help='Score threshold', default=0.5)
    parser.add_argument('--gt_type', type=str, default='')
    parser.add_argument('--eval_official', action='store_true')
    parser.add_argument('--loglevel', type=str, help='Log level', default='info')
    parser.add_argument('--fmt', type=str, help='Data format', default='mot15-2D')
    parser.add_argument('--solver', type=str, help='LAP solver to use')
    return parser.parse_args()


def compare_dataframes(gts, ts):
    """Build one accumulator per sequence by matching tracker output to its ground truth."""
    accs = []
    names = []
    for k, tsacc in ts.items():
        if k in gts:
            logging.info('Comparing {}...'.format(k))
            accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
            names.append(k)
        else:
            logging.warning('No ground truth for {}, skipping.'.format(k))
    return accs, names


if __name__ == '__main__':
    args = parse_args()

    loglevel = getattr(logging, args.loglevel.upper(), None)
    if not isinstance(loglevel, int):
        raise ValueError('Invalid log level: {}'.format(args.loglevel))
    logging.basicConfig(level=loglevel, format='%(asctime)s %(levelname)s - %(message)s', datefmt='%I:%M:%S')

    if args.solver:
        mm.lap.default_solver = args.solver

    gt_type = args.gt_type
    print('gt_type', gt_type)
    gtfiles = glob.glob(
        os.path.join(args.groundtruths, '*/gt/gt_{}.txt'.format(gt_type)))
    print('gt_files', gtfiles)
    tsfiles = [f for f in glob.glob(os.path.join(args.tests, '*.txt')) if not os.path.basename(f).startswith('eval')]

    logging.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
    logging.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
    logging.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
    logging.info('Loading files.')

    # Sequences are keyed by name: ground truth by its sequence directory,
    # tracker results by their file stem, so matching keys line up in compare_dataframes.
    gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt=args.fmt, min_confidence=1)) for f in gtfiles])
    ts = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt=args.fmt, min_confidence=args.score_threshold)) for f in tsfiles])
    # ts = gt

    mh = mm.metrics.create()
    accs, names = compare_dataframes(gt, ts)

    logging.info('Running metrics')
    metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
               'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
               'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
    summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
    # summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
    # print(mm.io.render_summary(
    #     summary, formatters=mh.formatters,
    #     namemap=mm.io.motchallenge_metric_names))

    # Convert absolute counts into ratios: error counts are divided by the total number
    # of objects, track-coverage counts by the number of unique ground-truth objects.
    div_dict = {
        'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
        'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
    for divisor in div_dict:
        for divided in div_dict[divisor]:
            summary[divided] = (summary[divided] / summary[divisor])

    # Render the normalized columns with the same percentage formatter as MOTA.
    fmt = mh.formatters
    change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations',
                       'mostly_tracked', 'partially_tracked', 'mostly_lost']
    for k in change_fmt_list:
        fmt[k] = fmt['mota']
    print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))

    if args.eval_official:
        metrics = mm.metrics.motchallenge_metrics + ['num_objects']
        summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
        print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))

    logging.info('Completed')
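For reference, the metric pipeline the script drives comes from py-motmetrics: per-frame matches between ground truth and tracker output are collected in a MOTAccumulator, and a metrics host turns accumulators into the summary table printed above. The snippet below is a minimal sketch of that flow on made-up data; the point coordinates, distance threshold, and the 'toy' name are illustrative and not taken from the repository.

import motmetrics as mm
import numpy as np

# One accumulator per sequence; auto_id assigns a new frame id on each update().
acc = mm.MOTAccumulator(auto_id=True)

# A single frame with two ground-truth objects and two hypotheses.
# Pairs whose squared distance exceeds max_d2 are treated as "no match".
gt_points = np.array([[0.0, 0.0], [1.0, 1.0]])
hyp_points = np.array([[0.1, 0.1], [5.0, 5.0]])
dist = mm.distances.norm2squared_matrix(gt_points, hyp_points, max_d2=2.0)
acc.update(['gt_a', 'gt_b'], ['hyp_1', 'hyp_2'], dist)

# Same metrics host and summary rendering as in the script.
mh = mm.metrics.create()
summary = mh.compute(acc, metrics=['num_frames', 'num_misses',
                                   'num_false_positives', 'mota'], name='toy')
print(mm.io.render_summary(summary, formatters=mh.formatters,
                           namemap=mm.io.motchallenge_metric_names))

In eval_motchallenge.py the accumulators are instead built by mm.utils.compare_to_groundtruth with IoU matching (distth=0.5), and mh.compute_many aggregates them across all sequences plus an OVERALL row.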