Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evaluation.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import os
  2. import numpy as np
  3. import copy
  4. import motmetrics as mm
  5. mm.lap.default_solver = 'lap'
  6. from yolox.tracking_utils.io import read_results, unzip_objs
  7. class Evaluator(object):
  8. def __init__(self, data_root, seq_name, data_type):
  9. self.data_root = data_root
  10. self.seq_name = seq_name
  11. self.data_type = data_type
  12. self.load_annotations()
  13. self.reset_accumulator()
  14. def load_annotations(self):
  15. assert self.data_type == 'mot'
  16. gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
  17. self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
  18. self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
  19. def reset_accumulator(self):
  20. self.acc = mm.MOTAccumulator(auto_id=True)
  21. def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
  22. # results
  23. trk_tlwhs = np.copy(trk_tlwhs)
  24. trk_ids = np.copy(trk_ids)
  25. # gts
  26. gt_objs = self.gt_frame_dict.get(frame_id, [])
  27. gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
  28. # ignore boxes
  29. ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
  30. ignore_tlwhs = unzip_objs(ignore_objs)[0]
  31. # remove ignored results
  32. keep = np.ones(len(trk_tlwhs), dtype=bool)
  33. iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
  34. if len(iou_distance) > 0:
  35. match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  36. match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  37. match_ious = iou_distance[match_is, match_js]
  38. match_js = np.asarray(match_js, dtype=int)
  39. match_js = match_js[np.logical_not(np.isnan(match_ious))]
  40. keep[match_js] = False
  41. trk_tlwhs = trk_tlwhs[keep]
  42. trk_ids = trk_ids[keep]
  43. #match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  44. #match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  45. #match_ious = iou_distance[match_is, match_js]
  46. #match_js = np.asarray(match_js, dtype=int)
  47. #match_js = match_js[np.logical_not(np.isnan(match_ious))]
  48. #keep[match_js] = False
  49. #trk_tlwhs = trk_tlwhs[keep]
  50. #trk_ids = trk_ids[keep]
  51. # get distance matrix
  52. iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
  53. # acc
  54. self.acc.update(gt_ids, trk_ids, iou_distance)
  55. if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
  56. events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
  57. else:
  58. events = None
  59. return events
  60. def eval_file(self, filename):
  61. self.reset_accumulator()
  62. result_frame_dict = read_results(filename, self.data_type, is_gt=False)
  63. #frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
  64. frames = sorted(list(set(result_frame_dict.keys())))
  65. for frame_id in frames:
  66. trk_objs = result_frame_dict.get(frame_id, [])
  67. trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
  68. self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
  69. return self.acc
  70. @staticmethod
  71. def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
  72. names = copy.deepcopy(names)
  73. if metrics is None:
  74. metrics = mm.metrics.motchallenge_metrics
  75. metrics = copy.deepcopy(metrics)
  76. mh = mm.metrics.create()
  77. summary = mh.compute_many(
  78. accs,
  79. metrics=metrics,
  80. names=names,
  81. generate_overall=True
  82. )
  83. return summary
  84. @staticmethod
  85. def save_summary(summary, filename):
  86. import pandas as pd
  87. writer = pd.ExcelWriter(filename)
  88. summary.to_excel(writer)
  89. writer.save()