|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- # ------------------------------------------------------------------------
- # Copyright (c) 2021 megvii-model. All Rights Reserved.
- # ------------------------------------------------------------------------
- # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
- # Copyright (c) 2020 SenseTime. All Rights Reserved.
- # ------------------------------------------------------------------------
- # Modified from DETR (https://github.com/facebookresearch/detr)
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- # ------------------------------------------------------------------------
-
-
- import os
- import numpy as np
- import copy
- import motmetrics as mm
- mm.lap.default_solver = 'lap'
- import os
- from typing import Dict
- import numpy as np
- import logging
-
def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
    """Read a tracking/annotation text file into a per-frame dictionary.

    Args:
        filename: path of the text file to parse.
        data_type: file format; ``'mot'`` and ``'lab'`` are both parsed with
            :func:`read_mot_results`.
        is_gt: forwarded to :func:`read_mot_results` (ground-truth filtering).
        is_ignore: forwarded to :func:`read_mot_results` (ignore-region rows).

    Returns:
        whatever :func:`read_mot_results` returns for ``filename``.

    Raises:
        ValueError: if ``data_type`` is not a supported format.
    """
    if data_type not in ('mot', 'lab'):
        raise ValueError('Unknown data type: {}'.format(data_type))
    return read_mot_results(filename, is_gt, is_ignore)
-
- # def read_mot_results(filename, is_gt, is_ignore):
- # results_dict = dict()
- # if os.path.isfile(filename):
- # with open(filename, 'r') as f:
- # for line in f.readlines():
- # linelist = line.split(',')
- # if len(linelist) < 7:
- # continue
- # fid = int(linelist[0])
- # if fid < 1:
- # continue
- # results_dict.setdefault(fid, list())
-
- # if is_gt:
- # mark = int(float(linelist[6]))
- # if mark == 0 :
- # continue
- # score = 1
- # elif is_ignore:
- # score = 1
- # else:
- # score = float(linelist[6])
-
- # tlwh = tuple(map(float, linelist[2:6]))
- # target_id = int(float(linelist[1]))
- # results_dict[fid].append((tlwh, target_id, score))
-
- # return results_dict
-
def read_mot_results(filename, is_gt, is_ignore):
    """Parse a MOTChallenge-style CSV file into ``{frame_id: [(tlwh, target_id, score), ...]}``.

    Each row starts with ``frame, id, x, y, w, h, <column 6>, ...``. Rows with
    fewer than 7 columns or a frame id below 1 are skipped. Integer columns
    are parsed via ``int(float(...))`` so ids written as ``"3"`` or ``"3.0"``
    are both accepted.

    Args:
        filename: path of the text file. If it does not exist, a warning is
            logged and an empty dict is returned.
        is_gt: keep rows that are scored as ground truth (``score`` = 1). For
            MOT16/MOT17 files (detected from ``filename``), only rows whose
            column 6 ("mark") is non-zero and whose column 7 label is 1 are
            kept; for other files every row is kept.
        is_ignore: only consulted when ``is_gt`` is false. Keep only
            ignore-region rows (``score`` = 1). For MOT16/MOT17 a row is kept
            when its column 7 label is in ``ignore_labels`` or its column 8
            value is negative. For MOT15 a row is kept when column 6 is in
            ``ignore_labels``. For any other file no rows are kept.
            When both flags are false (tracker results), every row is kept
            and ``score`` is column 6 as a float.

    Returns:
        dict mapping frame id to a list of ``(tlwh, target_id, score)``, where
        ``tlwh`` is a 4-tuple of floats taken from columns 2-5. A key is
        created for every parsed frame id even if all its rows are filtered.
    """
    valid_labels = {1}
    ignore_labels = {0, 2, 7, 8, 12}
    results_dict = dict()
    if not os.path.isfile(filename):
        # Previously a silent empty result; make a wrong path visible.
        logging.getLogger(__name__).warning('MOT file not found: %s', filename)
        return results_dict

    # The dataset flavour depends only on the file name: decide it once.
    is_mot16_17 = 'MOT16-' in filename or 'MOT17-' in filename
    is_mot15 = 'MOT15' in filename

    with open(filename, 'r') as f:
        for line in f:
            linelist = line.split(',')
            if len(linelist) < 7:
                continue
            fid = int(float(linelist[0]))
            if fid < 1:
                continue
            results_dict.setdefault(fid, list())

            if is_gt:
                if is_mot16_17:
                    label = int(float(linelist[7]))
                    mark = int(float(linelist[6]))
                    if mark == 0 or label not in valid_labels:
                        continue
                score = 1
            elif is_ignore:
                if is_mot16_17:
                    label = int(float(linelist[7]))
                    vis_ratio = float(linelist[8])
                    # Keep rows with an ignore label or a negative column-8 value.
                    if label not in ignore_labels and vis_ratio >= 0:
                        continue
                elif is_mot15:
                    label = int(float(linelist[6]))
                    if label not in ignore_labels:
                        continue
                else:
                    continue
                score = 1
            else:
                score = float(linelist[6])

            tlwh = tuple(map(float, linelist[2:6]))
            # Same int(float(...)) parsing as the other integer columns.
            target_id = int(float(linelist[1]))

            results_dict[fid].append((tlwh, target_id, score))

    return results_dict
-
def unzip_objs(objs):
    """Split ``(tlwh, target_id, score)`` records into three parallel sequences.

    Args:
        objs: sequence of ``(tlwh, target_id, score)`` records, e.g. one
            frame's entry from :func:`read_mot_results`.

    Returns:
        ``(tlwhs, ids, scores)``. ``tlwhs`` is always a float ndarray of shape
        (N, 4), i.e. (0, 4) for empty input. ``ids`` and ``scores`` are tuples,
        or empty lists when ``objs`` is empty.
    """
    if len(objs) == 0:
        boxes, track_ids, confidences = [], [], []
    else:
        boxes, track_ids, confidences = zip(*objs)
    box_array = np.asarray(boxes, dtype=float).reshape(-1, 4)
    return box_array, track_ids, confidences
-
-
class Evaluator(object):
    """Accumulate py-motmetrics events for one sequence against its ground truth.

    Ground truth is read from ``<data_root>/<seq_name>/gt/gt.txt`` when the
    evaluator is constructed. Tracker output is then evaluated per frame
    (:meth:`eval_frame`) or per result file (:meth:`eval_file`).
    """

    def __init__(self, data_root, seq_name, data_type='mot'):
        """Load the sequence annotations and create an empty accumulator.

        Args:
            data_root: directory containing one sub-directory per sequence.
            seq_name: name of the sequence sub-directory.
            data_type: annotation format; :meth:`load_annotations` only
                accepts ``'mot'``.
        """
        self.data_root = data_root
        self.seq_name = seq_name
        self.data_type = data_type

        self.load_annotations()
        self.reset_accumulator()

    def load_annotations(self):
        """Read the scored ground truth and the ignore regions from ``gt/gt.txt``."""
        assert self.data_type == 'mot'

        gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
        # Boxes that tracker output is matched and scored against.
        self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
        # Boxes whose overlapping tracker output is discarded before scoring.
        self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)

    def reset_accumulator(self):
        """Replace the accumulator with a fresh one (frame ids auto-assigned)."""
        self.acc = mm.MOTAccumulator(auto_id=True)

    def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
        """Add one frame of tracker output to the accumulator.

        Args:
            frame_id: frame number used to look up ground truth.
            trk_tlwhs: tracker boxes, an (N, 4) array-like in the same tlwh
                layout as the ground truth.
            trk_ids: the N tracker ids, parallel to ``trk_tlwhs``.
            rtn_events: if True, return the events recorded for this frame.

        Returns:
            ``self.acc.last_mot_events`` when ``rtn_events`` is set, the
            ground-truth/tracker distance matrix is non-empty and the
            accumulator has that attribute; otherwise ``None``.
        """
        # Copies so that masking below never touches the caller's arrays.
        trk_tlwhs = np.copy(trk_tlwhs)
        trk_ids = np.copy(trk_ids)

        gt_objs = self.gt_frame_dict.get(frame_id, [])
        gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]

        # Remove tracker boxes assigned to an ignore region within the 0.5
        # IoU-distance threshold, so they are scored neither as TP nor as FP.
        ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
        ignore_tlwhs = unzip_objs(ignore_objs)[0]
        keep = np.ones(len(trk_tlwhs), dtype=bool)
        iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
        if len(iou_distance) > 0:
            match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
            match_is = np.asarray(match_is, dtype=int)
            match_js = np.asarray(match_js, dtype=int)
            match_ious = iou_distance[match_is, match_js]

            # NaN distances mark pairs beyond max_iou: not a real overlap.
            match_js = match_js[np.logical_not(np.isnan(match_ious))]
            keep[match_js] = False
            trk_tlwhs = trk_tlwhs[keep]
            trk_ids = trk_ids[keep]

        iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)

        self.acc.update(gt_ids, trk_ids, iou_distance)

        if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
            events = self.acc.last_mot_events  # only supported by https://github.com/longcw/py-motmetrics
        else:
            events = None
        return events

    @staticmethod
    def _select_frames(result_frame_dict, gt_frame_dict, include_gt_frames):
        """Return the sorted frame ids to evaluate for one result file."""
        frames = set(result_frame_dict)
        if include_gt_frames:
            frames |= set(gt_frame_dict)
        return sorted(frames)

    def eval_file(self, filename, include_gt_frames=False):
        """Evaluate one tracker result file and return the filled accumulator.

        The accumulator is reset first, so each call covers exactly one file.

        Args:
            filename: tracker output in MOTChallenge text format.
            include_gt_frames: when False (default, original behaviour), only
                frames present in ``filename`` are evaluated, so ground-truth
                objects in frames without any tracker output are never
                counted. When True, every frame that has ground truth is also
                evaluated (with no tracker boxes), so those objects count as
                misses.

        Returns:
            ``self.acc`` after all selected frames have been added.
        """
        self.reset_accumulator()

        result_frame_dict = read_results(filename, self.data_type, is_gt=False)
        frames = self._select_frames(result_frame_dict, self.gt_frame_dict, include_gt_frames)

        for frame_id in frames:
            trk_objs = result_frame_dict.get(frame_id, [])
            trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
            self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)

        return self.acc

    @staticmethod
    def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
        """Compute a metrics summary for several accumulators.

        Args:
            accs: accumulators to summarise (e.g. one per sequence).
            names: row names, parallel to ``accs``.
            metrics: metric names to compute; ``None`` selects
                ``mm.metrics.motchallenge_metrics``.

        Returns:
            the result of ``mh.compute_many`` with one row per name plus an
            overall row (``generate_overall=True``).
        """
        names = copy.deepcopy(names)
        if metrics is None:
            metrics = mm.metrics.motchallenge_metrics
        metrics = copy.deepcopy(metrics)

        mh = mm.metrics.create()
        summary = mh.compute_many(
            accs,
            metrics=metrics,
            names=names,
            generate_overall=True
        )

        return summary

    @staticmethod
    def save_summary(summary, filename):
        """Write ``summary`` to the Excel workbook ``filename``."""
        import pandas as pd
        # ExcelWriter.save() was removed in pandas 2.0; leaving the context
        # manager closes the writer, which writes the workbook.
        with pd.ExcelWriter(filename) as writer:
            summary.to_excel(writer)
|