Meta Byte Track

eval.py

# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
SORT: A Simple, Online and Realtime Tracker
Copyright (C) 2016-2020 Alex Bewley [email protected]
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import print_function

import os
import numpy as np
import random
import argparse
import torchvision.transforms.functional as F
import torch
import cv2
from tqdm import tqdm
from pathlib import Path
from PIL import Image, ImageDraw
from models import build_model
from util.tool import load_model
from main import get_args_parser
from torch.nn.functional import interpolate
from typing import List
from util.evaluation import Evaluator
import motmetrics as mm
import shutil
from detectron2.structures import Instances
from tracker import BYTETracker

np.random.seed(2020)

COLORS_10 = [(144, 238, 144), (178, 34, 34), (221, 160, 221), (0, 255, 0), (0, 128, 0), (210, 105, 30), (220, 20, 60),
             (192, 192, 192), (255, 228, 196), (50, 205, 50), (139, 0, 139), (100, 149, 237), (138, 43, 226), (238, 130, 238),
             (255, 0, 255), (0, 100, 0), (127, 255, 0), (255, 0, 255), (0, 0, 205), (255, 140, 0), (255, 239, 213),
             (199, 21, 133), (124, 252, 0), (147, 112, 219), (106, 90, 205), (176, 196, 222), (65, 105, 225), (173, 255, 47),
             (255, 20, 147), (219, 112, 147), (186, 85, 211), (199, 21, 133), (148, 0, 211), (255, 99, 71), (144, 238, 144),
             (255, 255, 0), (230, 230, 250), (0, 0, 255), (128, 128, 0), (189, 183, 107), (255, 255, 224), (128, 128, 128),
             (105, 105, 105), (64, 224, 208), (205, 133, 63), (0, 128, 128), (72, 209, 204), (139, 69, 19), (255, 245, 238),
             (250, 240, 230), (152, 251, 152), (0, 255, 255), (135, 206, 235), (0, 191, 255), (176, 224, 230), (0, 250, 154),
             (245, 255, 250), (240, 230, 140), (245, 222, 179), (0, 139, 139), (143, 188, 143), (255, 0, 0), (240, 128, 128),
             (102, 205, 170), (60, 179, 113), (46, 139, 87), (165, 42, 42), (178, 34, 34), (175, 238, 238), (255, 248, 220),
             (218, 165, 32), (255, 250, 240), (253, 245, 230), (244, 164, 96), (210, 105, 30)]

def plot_one_box(x, img, color=None, label=None, score=None, line_thickness=None):
    # Plots one bounding box on image img.
    tl = line_thickness or round(0.002 * max(img.shape[0:2])) + 1  # line thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl)
    # if label:
    #     tf = max(tl - 1, 1)  # font thickness
    #     t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
    #     c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
    #     cv2.rectangle(img, c1, c2, color, -1)  # filled
    #     cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255],
    #                 thickness=tf, lineType=cv2.LINE_AA)
    # if score is not None:
    #     cv2.putText(img, score, (c1[0], c1[1] + 30), 0, tl / 3, [225, 255, 255],
    #                 thickness=tf, lineType=cv2.LINE_AA)
    return img

def draw_bboxes(ori_img, bbox, identities=None, offset=(0, 0), cvt_color=False):
    if cvt_color:
        ori_img = cv2.cvtColor(np.asarray(ori_img), cv2.COLOR_RGB2BGR)
    img = ori_img
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(v) for v in box[:4]]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        if len(box) > 4:
            score = '{:.2f}'.format(box[4])
        else:
            score = None
        # box text and bar
        id = int(identities[i]) if identities is not None else 0
        color = COLORS_10[id % len(COLORS_10)]
        label = '{:d}'.format(id)
        # t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        img = plot_one_box([x1, y1, x2, y2], img, color, label, score=score)
    return img

def draw_points(img: np.ndarray, points: np.ndarray, color=(255, 255, 255)) -> np.ndarray:
    assert len(points.shape) == 2 and points.shape[1] == 2, \
        'invalid points shape: {}'.format(points.shape)
    for i, (x, y) in enumerate(points):
        if i >= 300:
            color = (0, 255, 0)
        cv2.circle(img, (int(x), int(y)), 2, color=color, thickness=2)
    return img

def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    return tensor.detach().cpu().numpy()

class Track(object):
    track_cnt = 0

    def __init__(self, box):
        self.box = box
        self.time_since_update = 0
        self.id = Track.track_cnt
        Track.track_cnt += 1
        self.miss = 0

    def miss_one_frame(self):
        self.miss += 1

    def clear_miss(self):
        self.miss = 0

    def update(self, box):
        self.box = box
        self.clear_miss()

def write_results(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id,
                                          x1=round(x1, 1), y1=round(y1, 1),
                                          w=round(w, 1), h=round(h, 1),
                                          s=round(score, 2))
                f.write(line)
    # NOTE: the original called an undefined `logger`; plain print keeps this runnable.
    print('save results to {}'.format(filename))
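
# Hedged illustration (not part of the original file): with frame_id=1, track_id=3,
# tlwh=(10.0, 20.0, 50.0, 80.0) and score=0.97, the writer above emits one
# MOTChallenge-style line:
#   1,3,10.0,20.0,50.0,80.0,0.97,-1,-1,-1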

class MOTR(object):
    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        self.tracker = BYTETracker()

    def update(self, dt_instances: Instances):
        ret = []
        for i in range(len(dt_instances)):
            label = dt_instances.labels[i]
            if label == 0:
                id = dt_instances.obj_idxes[i]
                box_with_score = np.concatenate([dt_instances.boxes[i], dt_instances.scores[i:i + 1]], axis=-1)
                ret.append(np.concatenate((box_with_score, [id + 1])).reshape(1, -1))  # +1 as MOT benchmark requires positive
        if len(ret) > 0:
            online_targets = self.tracker.update(np.concatenate(ret))
            online_ret = []
            for t in online_targets:
                online_ret.append(np.array([t.tlbr[0], t.tlbr[1], t.tlbr[2], t.tlbr[3], t.score, t.track_id]).reshape(1, -1))
            if len(online_ret) > 0:
                return np.concatenate(online_ret)
        return np.empty((0, 6))
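
# Hedged usage sketch (assumes `dt_instances` carries boxes (N, 4) in xyxy,
# scores (N,), labels (N,) and obj_idxes (N,), as produced by the model below):
#   tracker = MOTR()
#   outputs = tracker.update(dt_instances)
#   # outputs: (M, 6) array of [x1, y1, x2, y2, score, track_id] per online target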

def load_label(label_path: str, img_size: tuple) -> dict:
    labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
    h, w = img_size
    # Normalized cxcywh to pixel xyxy format.
    labels = labels0.copy()
    labels[:, 2] = w * (labels0[:, 2] - labels0[:, 4] / 2)
    labels[:, 3] = h * (labels0[:, 3] - labels0[:, 5] / 2)
    labels[:, 4] = w * (labels0[:, 2] + labels0[:, 4] / 2)
    labels[:, 5] = h * (labels0[:, 3] + labels0[:, 5] / 2)
    targets = {'boxes': [], 'labels': [], 'area': []}
    num_boxes = len(labels)
    visited_ids = set()
    for label in labels[:num_boxes]:
        obj_id = label[1]
        if obj_id in visited_ids:
            continue
        visited_ids.add(obj_id)
        targets['boxes'].append(label[2:6].tolist())
        targets['area'].append(label[4] * label[5])
        targets['labels'].append(0)
    targets['boxes'] = np.asarray(targets['boxes'])
    targets['area'] = np.asarray(targets['area'])
    targets['labels'] = np.asarray(targets['labels'])
    return targets
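
# The conversion above assumes the JDE-style 'labels_with_ids' format, one row per
# box: `class id cx cy w h`, with cx, cy, w, h normalized to [0, 1]. Worked example
# for a 1920x1080 image (assumed values):
#   row `0 7 0.5 0.5 0.1 0.2`  ->  xyxy = (864.0, 432.0, 1056.0, 648.0)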

def filter_pub_det(res_file, pub_det_file, filter_iou=False):
    frame_boxes = {}
    with open(pub_det_file, 'r') as f:
        lines = f.readlines()
    for line in lines:
        if len(line) == 0:
            continue
        elements = line.strip().split(',')
        frame_id = int(elements[0])
        x1, y1, w, h = elements[2:6]
        x1, y1, w, h = float(x1), float(y1), float(w), float(h)
        x2 = x1 + w - 1
        y2 = y1 + h - 1
        if frame_id not in frame_boxes:
            frame_boxes[frame_id] = []
        frame_boxes[frame_id].append([x1, y1, x2, y2])
    for frame, boxes in frame_boxes.items():
        frame_boxes[frame] = np.array(boxes)
    ids = {}
    num_filter_box = 0
    with open(res_file, 'r') as f:
        lines = list(f.readlines())
    with open(res_file, 'w') as f:
        for line in lines:
            if len(line) == 0:
                continue
            elements = line.strip().split(',')
            frame_id, obj_id = elements[:2]
            frame_id = int(frame_id)
            obj_id = int(obj_id)
            x1, y1, w, h = elements[2:6]
            x1, y1, w, h = float(x1), float(y1), float(w), float(h)
            x2 = x1 + w - 1
            y2 = y1 + h - 1
            if obj_id not in ids:
                # track initialization.
                if frame_id not in frame_boxes:
                    num_filter_box += 1
                    print("filter init box {} {}".format(frame_id, obj_id))
                    continue
                pub_dt_boxes = frame_boxes[frame_id]
                dt_box = np.array([[x1, y1, x2, y2]])
                if filter_iou:
                    # NOTE: `bbox_iou` is not defined or imported in this file; it must be
                    # supplied elsewhere for the filter_iou=True path to run.
                    max_iou = bbox_iou(dt_box, pub_dt_boxes).max()
                    if max_iou < 0.5:
                        num_filter_box += 1
                        print("filter init box {} {}".format(frame_id, obj_id))
                        continue
                else:
                    pub_dt_centers = (pub_dt_boxes[:, :2] + pub_dt_boxes[:, 2:4]) * 0.5
                    x_inside = (dt_box[0, 0] <= pub_dt_centers[:, 0]) & (dt_box[0, 2] >= pub_dt_centers[:, 0])
                    y_inside = (dt_box[0, 1] <= pub_dt_centers[:, 1]) & (dt_box[0, 3] >= pub_dt_centers[:, 1])
                    center_inside: np.ndarray = x_inside & y_inside
                    if not center_inside.any():
                        num_filter_box += 1
                        print("filter init box {} {}".format(frame_id, obj_id))
                        continue
                print("save init track {} {}".format(frame_id, obj_id))
                ids[obj_id] = True
            f.write(line)
    print("totally {} boxes are filtered.".format(num_filter_box))

class Detector(object):
    def __init__(self, args, model=None, seq_num=2):
        self.args = args
        self.detr = model
        self.seq_num = seq_num
        img_list = os.listdir(os.path.join(self.args.mot_path, self.seq_num, 'img1'))
        img_list = [os.path.join(self.args.mot_path, self.seq_num, 'img1', _) for _ in img_list
                    if ('jpg' in _) or ('png' in _)]
        self.img_list = sorted(img_list)
        self.img_len = len(self.img_list)
        self.tr_tracker = MOTR()

        '''
        common settings
        '''
        self.img_height = 800
        self.img_width = 1536
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        self.save_path = os.path.join(self.args.output_dir, 'results/{}'.format(seq_num))
        os.makedirs(self.save_path, exist_ok=True)
        self.predict_path = os.path.join(self.args.output_dir, 'preds', self.seq_num)
        os.makedirs(self.predict_path, exist_ok=True)
        if os.path.exists(os.path.join(self.predict_path, 'gt.txt')):
            os.remove(os.path.join(self.predict_path, 'gt.txt'))

    def load_img_from_file(self, f_path):
        label_path = f_path.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt')
        cur_img = cv2.imread(f_path)
        cur_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2RGB)
        targets = load_label(label_path, cur_img.shape[:2]) if os.path.exists(label_path) else None
        return cur_img, targets

    def init_img(self, img):
        ori_img = img.copy()
        self.seq_h, self.seq_w = img.shape[:2]
        # Scale the shorter side to img_height, capping the longer side at img_width.
        scale = self.img_height / min(self.seq_h, self.seq_w)
        if max(self.seq_h, self.seq_w) * scale > self.img_width:
            scale = self.img_width / max(self.seq_h, self.seq_w)
        target_h = int(self.seq_h * scale)
        target_w = int(self.seq_w * scale)
        img = cv2.resize(img, (target_w, target_h))
        img = F.normalize(F.to_tensor(img), self.mean, self.std)
        img = img.unsqueeze(0)
        return img, ori_img
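
    # Worked example for init_img (assumed 1080x1920 input): scale = 800 / 1080 ≈ 0.741;
    # since 1920 * 0.741 ≈ 1422 <= 1536 the cap is not hit, and the frame is resized
    # to roughly 800 x 1422 before normalization.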

    @staticmethod
    def filter_dt_by_score(dt_instances: Instances, prob_threshold: float) -> Instances:
        keep = dt_instances.scores > prob_threshold
        return dt_instances[keep]

    @staticmethod
    def filter_dt_by_area(dt_instances: Instances, area_threshold: float) -> Instances:
        wh = dt_instances.boxes[:, 2:4] - dt_instances.boxes[:, 0:2]
        areas = wh[:, 0] * wh[:, 1]
        keep = areas > area_threshold
        return dt_instances[keep]

    @staticmethod
    def write_results(txt_path, frame_id, bbox_xyxy, identities):
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
        with open(txt_path, 'a') as f:
            for xyxy, track_id in zip(bbox_xyxy, identities):
                # Check for None first so a missing id cannot raise TypeError on `<`.
                if track_id is None or track_id < 0:
                    continue
                x1, y1, x2, y2 = xyxy
                w, h = x2 - x1, y2 - y1
                line = save_format.format(frame=int(frame_id), id=int(track_id), x1=x1, y1=y1, w=w, h=h)
                f.write(line)

    def eval_seq(self):
        data_root = os.path.join(self.args.mot_path)
        result_filename = os.path.join(self.predict_path, 'gt.txt')
        evaluator = Evaluator(data_root, self.seq_num)
        accs = evaluator.eval_file(result_filename)
        return accs

    @staticmethod
    def visualize_img_with_bbox(img_path, img, dt_instances: Instances, ref_pts=None, gt_boxes=None):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        if dt_instances.has('scores'):
            img_show = draw_bboxes(img, np.concatenate([dt_instances.boxes, dt_instances.scores.reshape(-1, 1)], axis=-1), dt_instances.obj_idxes)
        else:
            img_show = draw_bboxes(img, dt_instances.boxes, dt_instances.obj_idxes)
        # if ref_pts is not None:
        #     img_show = draw_points(img_show, ref_pts)
        # if gt_boxes is not None:
        #     img_show = draw_bboxes(img_show, gt_boxes, identities=np.ones((len(gt_boxes),)) * -1)
        cv2.imwrite(img_path, img_show)

    def detect(self, prob_threshold=0.2, area_threshold=100, vis=False):
        total_dts = 0
        track_instances = None
        max_id = 0
        # we only consider the val split (the second half of the images)
        for i in tqdm(range(int(self.img_len / 2), self.img_len)):
            # for i in tqdm(range(0, self.img_len)):
            img, targets = self.load_img_from_file(self.img_list[i])
            cur_img, ori_img = self.init_img(img)

            # track_instances = None
            if track_instances is not None:
                track_instances.remove('boxes')
                track_instances.remove('labels')
            res = self.detr.inference_single_image(cur_img.cuda().float(), (self.seq_h, self.seq_w), track_instances)
            track_instances = res['track_instances']
            max_id = max(max_id, track_instances.obj_idxes.max().item())

            print("ref points.shape={}".format(res['ref_pts'].shape))
            all_ref_pts = tensor_to_numpy(res['ref_pts'][0, :, :2])
            dt_instances = track_instances.to(torch.device('cpu'))

            # filter det instances by score and area.
            dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
            dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)
            total_dts += len(dt_instances)

            if vis:
                # for visualization
                cur_vis_img_path = os.path.join(self.save_path, 'frame_{:0>8d}.jpg'.format(i))
                gt_boxes = None
                self.visualize_img_with_bbox(cur_vis_img_path, ori_img, dt_instances, ref_pts=all_ref_pts, gt_boxes=gt_boxes)

            tracker_outputs = self.tr_tracker.update(dt_instances)
            self.write_results(txt_path=os.path.join(self.predict_path, 'gt.txt'),
                               frame_id=(i + 1),
                               bbox_xyxy=tracker_outputs[:, :4],
                               identities=tracker_outputs[:, 5])
        print("totally {} dts max_id={}".format(total_dts, max_id))

if __name__ == '__main__':
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # load model and weights
    detr, _, _ = build_model(args)
    checkpoint = torch.load(args.resume, map_location='cpu')
    detr = load_model(detr, args.resume)
    detr = detr.cuda()
    detr.eval()

    # seq_nums = ['ADL-Rundle-6', 'ETH-Bahnhof', 'KITTI-13', 'PETS09-S2L1', 'TUD-Stadtmitte', 'ADL-Rundle-8',
    #             'KITTI-17', 'ETH-Pedcross2', 'ETH-Sunnyday', 'TUD-Campus', 'Venice-2']
    seq_nums = ['MOT17-02-SDP',
                'MOT17-04-SDP',
                'MOT17-05-SDP',
                'MOT17-09-SDP',
                'MOT17-10-SDP',
                'MOT17-11-SDP',
                'MOT17-13-SDP']
    accs = []
    seqs = []
    for seq_num in seq_nums:
        print("solve {}".format(seq_num))
        det = Detector(args, model=detr, seq_num=seq_num)
        det.detect(vis=False)
        accs.append(det.eval_seq())
        seqs.append(seq_num)

    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names
    )
    print(strsummary)
    with open("eval_log.txt", 'a') as f:
        print(strsummary, file=f)
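
# Hedged example invocation (only --resume, --mot_path and --output_dir are read in
# this file; the full flag set comes from main.get_args_parser, so the names and
# paths below are assumptions, not the project's documented CLI):
#   python3 eval.py --resume checkpoint.pth --mot_path /data/MOT17/images/train --output_dir exps/eval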