Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpolation.py 5.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import numpy as np
  2. import os
  3. import glob
  4. import motmetrics as mm
  5. from yolox.evaluators.evaluation import Evaluator
  6. def mkdir_if_missing(d):
  7. if not os.path.exists(d):
  8. os.makedirs(d)
  9. def eval_mota(data_root, txt_path):
  10. accs = []
  11. seqs = sorted([s for s in os.listdir(data_root) if s.endswith('FRCNN')])
  12. #seqs = sorted([s for s in os.listdir(data_root)])
  13. for seq in seqs:
  14. video_out_path = os.path.join(txt_path, seq + '.txt')
  15. evaluator = Evaluator(data_root, seq, 'mot')
  16. accs.append(evaluator.eval_file(video_out_path))
  17. metrics = mm.metrics.motchallenge_metrics
  18. mh = mm.metrics.create()
  19. summary = Evaluator.get_summary(accs, seqs, metrics)
  20. strsummary = mm.io.render_summary(
  21. summary,
  22. formatters=mh.formatters,
  23. namemap=mm.io.motchallenge_metric_names
  24. )
  25. print(strsummary)
  26. def get_mota(data_root, txt_path):
  27. accs = []
  28. seqs = sorted([s for s in os.listdir(data_root) if s.endswith('FRCNN')])
  29. #seqs = sorted([s for s in os.listdir(data_root)])
  30. for seq in seqs:
  31. video_out_path = os.path.join(txt_path, seq + '.txt')
  32. evaluator = Evaluator(data_root, seq, 'mot')
  33. accs.append(evaluator.eval_file(video_out_path))
  34. metrics = mm.metrics.motchallenge_metrics
  35. mh = mm.metrics.create()
  36. summary = Evaluator.get_summary(accs, seqs, metrics)
  37. strsummary = mm.io.render_summary(
  38. summary,
  39. formatters=mh.formatters,
  40. namemap=mm.io.motchallenge_metric_names
  41. )
  42. mota = float(strsummary.split(' ')[-6][:-1])
  43. return mota
  44. def write_results_score(filename, results):
  45. save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
  46. with open(filename, 'w') as f:
  47. for i in range(results.shape[0]):
  48. frame_data = results[i]
  49. frame_id = int(frame_data[0])
  50. track_id = int(frame_data[1])
  51. x1, y1, w, h = frame_data[2:6]
  52. score = frame_data[6]
  53. line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, w=w, h=h, s=-1)
  54. f.write(line)
  55. def dti(txt_path, save_path, n_min=25, n_dti=20):
  56. seq_txts = sorted(glob.glob(os.path.join(txt_path, '*.txt')))
  57. for seq_txt in seq_txts:
  58. seq_name = seq_txt.split('/')[-1]
  59. seq_data = np.loadtxt(seq_txt, dtype=np.float64, delimiter=',')
  60. min_id = int(np.min(seq_data[:, 1]))
  61. max_id = int(np.max(seq_data[:, 1]))
  62. seq_results = np.zeros((1, 10), dtype=np.float64)
  63. for track_id in range(min_id, max_id + 1):
  64. index = (seq_data[:, 1] == track_id)
  65. tracklet = seq_data[index]
  66. tracklet_dti = tracklet
  67. if tracklet.shape[0] == 0:
  68. continue
  69. n_frame = tracklet.shape[0]
  70. n_conf = np.sum(tracklet[:, 6] > 0.5)
  71. if n_frame > n_min:
  72. frames = tracklet[:, 0]
  73. frames_dti = {}
  74. for i in range(0, n_frame):
  75. right_frame = frames[i]
  76. if i > 0:
  77. left_frame = frames[i - 1]
  78. else:
  79. left_frame = frames[i]
  80. # disconnected track interpolation
  81. if 1 < right_frame - left_frame < n_dti:
  82. num_bi = int(right_frame - left_frame - 1)
  83. right_bbox = tracklet[i, 2:6]
  84. left_bbox = tracklet[i - 1, 2:6]
  85. for j in range(1, num_bi + 1):
  86. curr_frame = j + left_frame
  87. curr_bbox = (curr_frame - left_frame) * (right_bbox - left_bbox) / \
  88. (right_frame - left_frame) + left_bbox
  89. frames_dti[curr_frame] = curr_bbox
  90. num_dti = len(frames_dti.keys())
  91. if num_dti > 0:
  92. data_dti = np.zeros((num_dti, 10), dtype=np.float64)
  93. for n in range(num_dti):
  94. data_dti[n, 0] = list(frames_dti.keys())[n]
  95. data_dti[n, 1] = track_id
  96. data_dti[n, 2:6] = frames_dti[list(frames_dti.keys())[n]]
  97. data_dti[n, 6:] = [1, -1, -1, -1]
  98. tracklet_dti = np.vstack((tracklet, data_dti))
  99. seq_results = np.vstack((seq_results, tracklet_dti))
  100. save_seq_txt = os.path.join(save_path, seq_name)
  101. seq_results = seq_results[1:]
  102. seq_results = seq_results[seq_results[:, 0].argsort()]
  103. write_results_score(save_seq_txt, seq_results)
  104. if __name__ == '__main__':
  105. data_root = '/opt/tiger/demo/ByteTrack/datasets/mot/test'
  106. txt_path = '/opt/tiger/demo/ByteTrack/YOLOX_outputs/yolox_x_mix_det/track_results'
  107. save_path = '/opt/tiger/demo/ByteTrack/YOLOX_outputs/yolox_x_mix_det/track_results_dti'
  108. mkdir_if_missing(save_path)
  109. dti(txt_path, save_path, n_min=5, n_dti=20)
  110. print('Before DTI: ')
  111. eval_mota(data_root, txt_path)
  112. print('After DTI:')
  113. eval_mota(data_root, save_path)
  114. '''
  115. mota_best = 0.0
  116. best_n_min = 0
  117. best_n_dti = 0
  118. for n_min in range(5, 50, 5):
  119. for n_dti in range(5, 30, 5):
  120. dti(txt_path, save_path, n_min, n_dti)
  121. mota = get_mota(data_root, save_path)
  122. if mota > mota_best:
  123. mota_best = mota
  124. best_n_min = n_min
  125. best_n_dti = n_dti
  126. print(mota_best, best_n_min, best_n_dti)
  127. print(mota_best, best_n_min, best_n_dti)
  128. '''