Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test.py 12KB


  1. import numpy as np
  2. import torchvision
  3. import time
  4. import math
  5. import os
  6. import copy
  7. import pdb
  8. import argparse
  9. import sys
  10. import cv2
  11. import skimage.io
  12. import skimage.transform
  13. import skimage.color
  14. import skimage
  15. import torch
  16. import model
  17. from torch.utils.data import Dataset, DataLoader
  18. from torchvision import datasets, models, transforms
  19. from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer, RGB_MEAN, RGB_STD
  20. from scipy.optimize import linear_sum_assignment
  21. # assert torch.__version__.split('.')[1] == '4'
  22. print('CUDA available: {}'.format(torch.cuda.is_available()))
  23. color_list = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 0, 255), (0, 255, 255), (255, 255, 0), (128, 0, 255),
  24. (0, 128, 255), (128, 255, 0), (0, 255, 128), (255, 128, 0), (255, 0, 128), (128, 128, 255), (128, 255, 128), (255, 128, 128), (128, 128, 0), (128, 0, 128)]
  25. class detect_rect:
  26. def __init__(self):
  27. self.curr_frame = 0
  28. self.curr_rect = np.array([0, 0, 1, 1])
  29. self.next_rect = np.array([0, 0, 1, 1])
  30. self.conf = 0
  31. self.id = 0
  32. @property
  33. def position(self):
  34. x = (self.curr_rect[0] + self.curr_rect[2])/2
  35. y = (self.curr_rect[1] + self.curr_rect[3])/2
  36. return np.array([x, y])
  37. @property
  38. def size(self):
  39. w = self.curr_rect[2] - self.curr_rect[0]
  40. h = self.curr_rect[3] - self.curr_rect[1]
  41. return np.array([w, h])
  42. class tracklet:
  43. def __init__(self, det_rect):
  44. self.id = det_rect.id
  45. self.rect_list = [det_rect]
  46. self.rect_num = 1
  47. self.last_rect = det_rect
  48. self.last_frame = det_rect.curr_frame
  49. self.no_match_frame = 0
  50. def add_rect(self, det_rect):
  51. self.rect_list.append(det_rect)
  52. self.rect_num = self.rect_num + 1
  53. self.last_rect = det_rect
  54. self.last_frame = det_rect.curr_frame
  55. @property
  56. def velocity(self):
  57. if(self.rect_num < 2):
  58. return (0, 0)
  59. elif(self.rect_num < 6):
  60. return (self.rect_list[self.rect_num - 1].position - self.rect_list[self.rect_num - 2].position) / (self.rect_list[self.rect_num - 1].curr_frame - self.rect_list[self.rect_num - 2].curr_frame)
  61. else:
  62. v1 = (self.rect_list[self.rect_num - 1].position - self.rect_list[self.rect_num - 4].position) / (self.rect_list[self.rect_num - 1].curr_frame - self.rect_list[self.rect_num - 4].curr_frame)
  63. v2 = (self.rect_list[self.rect_num - 2].position - self.rect_list[self.rect_num - 5].position) / (self.rect_list[self.rect_num - 2].curr_frame - self.rect_list[self.rect_num - 5].curr_frame)
  64. v3 = (self.rect_list[self.rect_num - 3].position - self.rect_list[self.rect_num - 6].position) / (self.rect_list[self.rect_num - 3].curr_frame - self.rect_list[self.rect_num - 6].curr_frame)
  65. return (v1 + v2 + v3) / 3
  66. def cal_iou(rect1, rect2):
  67. x1, y1, x2, y2 = rect1
  68. x3, y3, x4, y4 = rect2
  69. i_w = min(x2, x4) - max(x1, x3)
  70. i_h = min(y2, y4) - max(y1, y3)
  71. if(i_w <= 0 or i_h <= 0):
  72. return 0
  73. i_s = i_w * i_h
  74. s_1 = (x2 - x1) * (y2 - y1)
  75. s_2 = (x4 - x3) * (y4 - y3)
  76. return float(i_s) / (s_1 + s_2 - i_s)
  77. def cal_simi(det_rect1, det_rect2):
  78. return cal_iou(det_rect1.next_rect, det_rect2.curr_rect)
  79. def cal_simi_track_det(track, det_rect):
  80. if(det_rect.curr_frame <= track.last_frame):
  81. print("cal_simi_track_det error")
  82. return 0
  83. elif(det_rect.curr_frame - track.last_frame == 1):
  84. return cal_iou(track.last_rect.next_rect, det_rect.curr_rect)
  85. else:
  86. pred_rect = track.last_rect.curr_rect + np.append(track.velocity, track.velocity) * (det_rect.curr_frame - track.last_frame)
  87. return cal_iou(pred_rect, det_rect.curr_rect)
  88. def track_det_match(tracklet_list, det_rect_list, min_iou = 0.5):
  89. num1 = len(tracklet_list)
  90. num2 = len(det_rect_list)
  91. cost_mat = np.zeros((num1, num2))
  92. for i in range(num1):
  93. for j in range(num2):
  94. cost_mat[i, j] = -cal_simi_track_det(tracklet_list[i], det_rect_list[j])
  95. match_result = linear_sum_assignment(cost_mat)
  96. match_result = np.asarray(match_result)
  97. match_result = np.transpose(match_result)
  98. matches, unmatched1, unmatched2 = [], [], []
  99. for i in range(num1):
  100. if i not in match_result[:, 0]:
  101. unmatched1.append(i)
  102. for j in range(num2):
  103. if j not in match_result[:, 1]:
  104. unmatched2.append(j)
  105. for i, j in match_result:
  106. if cost_mat[i, j] > -min_iou:
  107. unmatched1.append(i)
  108. unmatched2.append(j)
  109. else:
  110. matches.append((i, j))
  111. return matches, unmatched1, unmatched2
  112. def draw_caption(image, box, caption, color):
  113. b = np.array(box).astype(int)
  114. cv2.putText(image, caption, (b[0], b[1] - 8), cv2.FONT_HERSHEY_PLAIN, 2, color, 2)
  115. def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
  116. print(cur_dataset)
  117. img_list = os.listdir(os.path.join(dataset_path, subset, cur_dataset, 'img1'))
  118. img_list = [os.path.join(dataset_path, subset, cur_dataset, 'img1', _) for _ in img_list if ('jpg' in _) or ('png' in _)]
  119. img_list = sorted(img_list)
  120. img_len = len(img_list)
  121. last_feat = None
  122. confidence_threshold = 0.4
  123. IOU_threshold = 0.5
  124. retention_threshold = 10
  125. det_list_all = []
  126. tracklet_all = []
  127. max_id = 0
  128. max_draw_len = 100
  129. draw_interval = 5
  130. img_width = 1920
  131. img_height = 1080
  132. fps = 30
  133. for i in range(img_len):
  134. det_list_all.append([])
  135. for idx in range((int(img_len / 2)), img_len + 1):
  136. i = idx - 1
  137. print('tracking: ', i)
  138. with torch.no_grad():
  139. data_path1 = img_list[min(idx, img_len - 1)]
  140. img_origin1 = skimage.io.imread(data_path1)
  141. img_h, img_w, _ = img_origin1.shape
  142. img_height, img_width = img_h, img_w
  143. resize_h, resize_w = math.ceil(img_h / 32) * 32, math.ceil(img_w / 32) * 32
  144. img1 = np.zeros((resize_h, resize_w, 3), dtype=img_origin1.dtype)
  145. img1[:img_h, :img_w, :] = img_origin1
  146. img1 = (img1.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
  147. img1 = torch.from_numpy(img1).permute(2, 0, 1).view(1, 3, resize_h, resize_w)
  148. scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)
  149. # if idx > 0:
  150. if idx > (int(img_len / 2)):
  151. idxs = np.where(scores>0.1)
  152. for j in range(idxs[0].shape[0]):
  153. bbox = transformed_anchors[idxs[0][j], :]
  154. x1 = int(bbox[0])
  155. y1 = int(bbox[1])
  156. x2 = int(bbox[2])
  157. y2 = int(bbox[3])
  158. x3 = int(bbox[4])
  159. y3 = int(bbox[5])
  160. x4 = int(bbox[6])
  161. y4 = int(bbox[7])
  162. det_conf = float(scores[idxs[0][j]])
  163. det_rect = detect_rect()
  164. det_rect.curr_frame = idx
  165. det_rect.curr_rect = np.array([x1, y1, x2, y2])
  166. det_rect.next_rect = np.array([x3, y3, x4, y4])
  167. det_rect.conf = det_conf
  168. if det_rect.conf > confidence_threshold:
  169. det_list_all[det_rect.curr_frame - 1].append(det_rect)
  170. # if i == 0:
  171. if i == int(img_len / 2):
  172. for j in range(len(det_list_all[i])):
  173. det_list_all[i][j].id = j + 1
  174. max_id = max(max_id, j + 1)
  175. track = tracklet(det_list_all[i][j])
  176. tracklet_all.append(track)
  177. continue
  178. matches, unmatched1, unmatched2 = track_det_match(tracklet_all, det_list_all[i], IOU_threshold)
  179. for j in range(len(matches)):
  180. det_list_all[i][matches[j][1]].id = tracklet_all[matches[j][0]].id
  181. det_list_all[i][matches[j][1]].id = tracklet_all[matches[j][0]].id
  182. tracklet_all[matches[j][0]].add_rect(det_list_all[i][matches[j][1]])
  183. delete_track_list = []
  184. for j in range(len(unmatched1)):
  185. tracklet_all[unmatched1[j]].no_match_frame = tracklet_all[unmatched1[j]].no_match_frame + 1
  186. if(tracklet_all[unmatched1[j]].no_match_frame >= retention_threshold):
  187. delete_track_list.append(unmatched1[j])
  188. origin_index = set([k for k in range(len(tracklet_all))])
  189. delete_index = set(delete_track_list)
  190. left_index = list(origin_index - delete_index)
  191. tracklet_all = [tracklet_all[k] for k in left_index]
  192. for j in range(len(unmatched2)):
  193. det_list_all[i][unmatched2[j]].id = max_id + 1
  194. max_id = max_id + 1
  195. track = tracklet(det_list_all[i][unmatched2[j]])
  196. tracklet_all.append(track)
  197. #**************visualize tracking result and save evaluate file****************
  198. fout_tracking = open(os.path.join(model_dir, 'results', cur_dataset + '.txt'), 'w')
  199. save_img_dir = os.path.join(model_dir, 'results', cur_dataset)
  200. if not os.path.exists(save_img_dir):
  201. os.makedirs(save_img_dir)
  202. out_video = os.path.join(model_dir, 'results', cur_dataset + '.mp4')
  203. videoWriter = cv2.VideoWriter(out_video, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (img_width, img_height))
  204. id_dict = {}
  205. for i in range((int(img_len / 2)), img_len):
  206. print('saving: ', i)
  207. img = cv2.imread(img_list[i])
  208. for j in range(len(det_list_all[i])):
  209. x1, y1, x2, y2 = det_list_all[i][j].curr_rect.astype(int)
  210. trace_id = det_list_all[i][j].id
  211. id_dict.setdefault(str(trace_id),[]).append((int((x1+x2)/2), y2))
  212. draw_trace_id = str(trace_id)
  213. draw_caption(img, (x1, y1, x2, y2), draw_trace_id, color=color_list[trace_id % len(color_list)])
  214. cv2.rectangle(img, (x1, y1), (x2, y2), color=color_list[trace_id % len(color_list)], thickness=2)
  215. trace_len = len(id_dict[str(trace_id)])
  216. trace_len_draw = min(max_draw_len, trace_len)
  217. for k in range(trace_len_draw - draw_interval):
  218. if(k % draw_interval == 0):
  219. draw_point1 = id_dict[str(trace_id)][trace_len - k - 1]
  220. draw_point2 = id_dict[str(trace_id)][trace_len - k - 1 - draw_interval]
  221. cv2.line(img, draw_point1, draw_point2, color=color_list[trace_id % len(color_list)], thickness=2)
  222. fout_tracking.write(str(i+1) + ',' + str(trace_id) + ',' + str(x1) + ',' + str(y1) + ',' + str(x2 - x1) + ',' + str(y2 - y1) + ',-1,-1,-1,-1\n')
  223. cv2.imwrite(os.path.join(save_img_dir, str(i + 1).zfill(6) + '.jpg'), img)
  224. videoWriter.write(img)
  225. # cv2.waitKey(0)
  226. fout_tracking.close()
  227. videoWriter.release()
  228. def run_from_train(model_dir, root_path):
  229. if not os.path.exists(os.path.join(model_dir, 'results')):
  230. os.makedirs(os.path.join(model_dir, 'results'))
  231. retinanet = torch.load(os.path.join(model_dir, 'model_final.pt'))
  232. use_gpu = True
  233. if use_gpu: retinanet = retinanet.cuda()
  234. retinanet.eval()
  235. for seq_num in [2, 4, 5, 9, 10, 11, 13]:
  236. run_each_dataset(model_dir, retinanet, root_path, 'train', 'MOT17-{:02d}'.format(seq_num))
  237. for seq_num in [1, 3, 6, 7, 8, 12, 14]:
  238. run_each_dataset(model_dir, retinanet, root_path, 'test', 'MOT17-{:02d}'.format(seq_num))
  239. def main(args=None):
  240. parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
  241. parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str, help='Dataset path, location of the images sequence.')
  242. parser.add_argument('--model_dir', default='./trained_model/', help='Path to model (.pt) file.')
  243. parser.add_argument('--model_path', default='./trained_model/model_final.pth', help='Path to model (.pt) file.')
  244. parser = parser.parse_args(args)
  245. if not os.path.exists(os.path.join(parser.model_dir, 'results')):
  246. os.makedirs(os.path.join(parser.model_dir, 'results'))
  247. retinanet = model.resnet50(num_classes=1, pretrained=True)
  248. # retinanet_save = torch.load(os.path.join(parser.model_dir, 'model_final.pth'))
  249. retinanet_save = torch.load(os.path.join(parser.model_path))
  250. # rename moco pre-trained keys
  251. state_dict = retinanet_save.state_dict()
  252. for k in list(state_dict.keys()):
  253. # retain only encoder up to before the embedding layer
  254. if k.startswith('module.'):
  255. # remove prefix
  256. state_dict[k[len("module."):]] = state_dict[k]
  257. # delete renamed or unused k
  258. del state_dict[k]
  259. retinanet.load_state_dict(state_dict)
  260. use_gpu = True
  261. if use_gpu: retinanet = retinanet.cuda()
  262. retinanet.eval()
  263. for seq_num in [2, 4, 5, 9, 10, 11, 13]:
  264. run_each_dataset(parser.model_dir, retinanet, parser.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
  265. # for seq_num in [1, 3, 6, 7, 8, 12, 14]:
  266. # run_each_dataset(parser.model_dir, retinanet, parser.dataset_path, 'test', 'MOT17-{:02d}'.format(seq_num))
  267. if __name__ == '__main__':
  268. main()