Meta Byte Track

test_byte.py 5.8KB

import numpy as np
import torchvision
import time
import math
import os
import copy
import pdb
import argparse
import sys
import cv2
import skimage.io
import skimage.transform
import skimage.color
import skimage
import torch
import model
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer, RGB_MEAN, RGB_STD
from scipy.optimize import linear_sum_assignment
from tracker import BYTETracker
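
# Overview: this script runs a CTracker-style RetinaNet detector frame by frame
# over MOT17 sequences, chaining features between consecutive frames through
# `last_feat`, associates the resulting detections with BYTETracker, and writes
# one MOT-challenge-format results file per sequence.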

def write_results(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1), s=round(score, 2))
                f.write(line)

def write_results_no_score(filename, results):
    save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids in results:
            for tlwh, track_id in zip(tlwhs, track_ids):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(frame=frame_id, id=track_id, x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1))
                f.write(line)

def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
    print(cur_dataset)

    img_list = os.listdir(os.path.join(dataset_path, subset, cur_dataset, 'img1'))
    img_list = [os.path.join(dataset_path, subset, cur_dataset, 'img1', _) for _ in img_list if ('jpg' in _) or ('png' in _)]
    img_list = sorted(img_list)
    img_len = len(img_list)
    last_feat = None

    confidence_threshold = 0.6
    IOU_threshold = 0.5
    retention_threshold = 10

    det_list_all = []
    tracklet_all = []
    results = []
    max_id = 0
    max_draw_len = 100
    draw_interval = 5
    img_width = 1920
    img_height = 1080
    fps = 30

    tracker = BYTETracker()

    # Roughly the second half of each sequence is tracked; the first iteration
    # only primes `last_feat` before tracking starts.
    for idx in range((int(img_len / 2)), img_len + 1):
        i = idx - 1
        print('tracking: ', i)
        with torch.no_grad():
            data_path1 = img_list[min(idx, img_len - 1)]
            img_origin1 = skimage.io.imread(data_path1)
            img_h, img_w, _ = img_origin1.shape
            img_height, img_width = img_h, img_w
            # pad the frame to the next multiple of 32 and normalize with the
            # dataset mean/std before feeding it to the network
            resize_h, resize_w = math.ceil(img_h / 32) * 32, math.ceil(img_w / 32) * 32
            img1 = np.zeros((resize_h, resize_w, 3), dtype=img_origin1.dtype)
            img1[:img_h, :img_w, :] = img_origin1
            img1 = (img1.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
            img1 = torch.from_numpy(img1).permute(2, 0, 1).view(1, 3, resize_h, resize_w)
            scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)
            if idx > (int(img_len / 2)):
                idxs = np.where(scores > 0.1)

                # run tracking
                online_targets = tracker.update(transformed_anchors[idxs[0], :4], scores[idxs[0]])
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                results.append((idx, online_tlwhs, online_ids, online_scores))

    fout_tracking = os.path.join(model_dir, 'results', cur_dataset + '.txt')
    write_results(fout_tracking, results)

def main(args=None):
    parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str,
                        help='Dataset path, location of the image sequences.')
    parser.add_argument('--model_dir', default='./trained_model/', help='Model directory; results are written to <model_dir>/results.')
    parser.add_argument('--model_path', default='./trained_model/model_final.pth', help='Path to the trained model (.pth) file.')
    parser.add_argument('--seq_nums', default=0, type=int, help='If > 0, evaluate only this MOT17 train sequence number.')
    parser = parser.parse_args(args)

    if not os.path.exists(os.path.join(parser.model_dir, 'results')):
        os.makedirs(os.path.join(parser.model_dir, 'results'))

    retinanet = model.resnet50(num_classes=1, pretrained=True)
    # retinanet_save = torch.load(os.path.join(parser.model_dir, 'model_final.pth'))
    retinanet_save = torch.load(os.path.join(parser.model_path))

    # rename moco pre-trained keys (strip the 'module.' prefix added by nn.DataParallel)
    state_dict = retinanet_save.state_dict()
    for k in list(state_dict.keys()):
        # retain only encoder up to before the embedding layer
        if k.startswith('module.'):
            # remove prefix
            state_dict[k[len("module."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    retinanet.load_state_dict(state_dict)

    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()
    retinanet.eval()

    seq_nums = []
    if parser.seq_nums > 0:
        seq_nums.append(parser.seq_nums)
    else:
        seq_nums = [2, 4, 5, 9, 10, 11, 13]
    for seq_num in seq_nums:
        run_each_dataset(parser.model_dir, retinanet, parser.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
    # for seq_num in [1, 3, 6, 7, 8, 12, 14]:
    #     run_each_dataset(parser.model_dir, retinanet, parser.dataset_path, 'test', 'MOT17-{:02d}'.format(seq_num))


if __name__ == '__main__':
    main()
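
Usage: assuming the standard MOT17 layout under --dataset_path (e.g. train/MOT17-02/img1/000001.jpg) and a trained checkpoint at --model_path, a single-sequence run looks like, for example:

    python test_byte.py --dataset_path /path/to/MOT17/ --model_dir ./trained_model/ --model_path ./trained_model/model_final.pth --seq_nums 2

This tracks the second half of MOT17-02 and writes MOT-format results to ./trained_model/results/MOT17-02.txt; leaving --seq_nums at 0 runs all seven train sequences [2, 4, 5, 9, 10, 11, 13].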