Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

generate_weak_labels.py 7.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Mahdi Abdollahpour, 22/12/2021, 02:26 PM, PyCharm, ByteTrack
  2. import os
  3. import time
  4. from loguru import logger
  5. # from opts import opts
  6. from os import listdir
  7. from os.path import isfile, join
  8. import cv2
  9. import numpy as np
  10. import torch
  11. from yolox.data.data_augment import ValTransform
  12. # from yolox.data.datasets import COCO_CLASSES
  13. from yolox.exp import get_exp
  14. from yolox.utils import fuse_model, get_model_info, postprocess, vis
  15. from yolox import statics
  16. COCO_CLASSES = (
  17. "person",
  18. "bicycle",
  19. "car",
  20. "motorcycle",
  21. "airplane",
  22. "bus",
  23. "train",
  24. "truck",
  25. "boat",
  26. "traffic light",
  27. "fire hydrant",
  28. "stop sign",
  29. "parking meter",
  30. "bench",
  31. "bird",
  32. "cat",
  33. "dog",
  34. "horse",
  35. "sheep",
  36. "cow",
  37. "elephant",
  38. "bear",
  39. "zebra",
  40. "giraffe",
  41. "backpack",
  42. "umbrella",
  43. "handbag",
  44. "tie",
  45. "suitcase",
  46. "frisbee",
  47. "skis",
  48. "snowboard",
  49. "sports ball",
  50. "kite",
  51. "baseball bat",
  52. "baseball glove",
  53. "skateboard",
  54. "surfboard",
  55. "tennis racket",
  56. "bottle",
  57. "wine glass",
  58. "cup",
  59. "fork",
  60. "knife",
  61. "spoon",
  62. "bowl",
  63. "banana",
  64. "apple",
  65. "sandwich",
  66. "orange",
  67. "broccoli",
  68. "carrot",
  69. "hot dog",
  70. "pizza",
  71. "donut",
  72. "cake",
  73. "chair",
  74. "couch",
  75. "potted plant",
  76. "bed",
  77. "dining table",
  78. "toilet",
  79. "tv",
  80. "laptop",
  81. "mouse",
  82. "remote",
  83. "keyboard",
  84. "cell phone",
  85. "microwave",
  86. "oven",
  87. "toaster",
  88. "sink",
  89. "refrigerator",
  90. "book",
  91. "clock",
  92. "vase",
  93. "scissors",
  94. "teddy bear",
  95. "hair drier",
  96. "toothbrush",
  97. )
# Image file extensions; not referenced by the code visible in this file.
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
# Request GPU execution; read by main() when placing the model.
use_cuda = True
# Dataset name and split; together they select the directory scanned below.
MOT = 'MOT20'
section = 'train'
# Directory handed to image_demo(): <statics.DATA_PATH>/<MOT>/<section>.
root_dir = os.path.join(statics.DATA_PATH, MOT, section)
# COCO class names that get_labels() keeps; all other detections are dropped.
classes = ['person', 'bicycle', 'car', 'motorcycle', 'truck', 'bus']
# When True, main() passes the model through fuse_model() after loading weights.
fuse = False
  105. def get_labels(bboxes, cls, scores, th, tw):
  106. id = 0
  107. labels = []
  108. # print(pred['scores'])
  109. n, _ = bboxes.shape
  110. for i in range(n):
  111. if COCO_CLASSES[int(cls[i])] not in classes:
  112. # print('Rejecting',COCO_CLASSES[int(cls[i])],scores[i])
  113. continue
  114. if use_cuda:
  115. box = bboxes[i, :].detach().cpu().numpy()
  116. else:
  117. box = bboxes[i, :].detach().numpy()
  118. ## TODO: check if matches
  119. # print(box[0], box[1], box[2], box[3], '--', th, tw)
  120. # print(box[0] / th, box[1] / tw, box[2] / th, box[3] / tw)
  121. x = box[0] / th
  122. y = box[1] / tw
  123. w = (box[2] - box[0]) / th
  124. h = (box[3] - box[1]) / tw
  125. x += w / 2
  126. y += h / 2
  127. label = [0, id, x, y, w, h]
  128. # label = [0, id, box[0], box[1], (box[2] - box[0]), (box[3] - box[1])]
  129. id += 1
  130. labels.append(label)
  131. # print(id)
  132. labels0 = np.array(labels)
  133. return labels0
  134. class Predictor(object):
  135. def __init__(
  136. self,
  137. model,
  138. exp,
  139. cls_names=COCO_CLASSES,
  140. trt_file=None,
  141. decoder=None,
  142. device="cpu",
  143. fp16=False,
  144. legacy=False,
  145. ):
  146. self.model = model
  147. self.cls_names = cls_names
  148. self.decoder = decoder
  149. self.num_classes = exp.num_classes
  150. self.confthre = 0.1
  151. self.nmsthre = 0.3
  152. self.test_size = exp.test_size
  153. self.device = device
  154. self.fp16 = fp16
  155. self.preproc = ValTransform()
  156. # if trt_file is not None:
  157. # from torch2trt import TRTModule
  158. #
  159. # model_trt = TRTModule()
  160. # model_trt.load_state_dict(torch.load(trt_file))
  161. #
  162. # x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
  163. # self.model(x)
  164. # self.model = model_trt
  165. def inference(self, img):
  166. img_info = {"id": 0}
  167. if isinstance(img, str):
  168. img_info["file_name"] = os.path.basename(img)
  169. img = cv2.imread(img)
  170. else:
  171. img_info["file_name"] = None
  172. height, width = img.shape[:2]
  173. img_info["height"] = height
  174. img_info["width"] = width
  175. img_info["raw_img"] = img
  176. ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
  177. # print(self.test_size[0] , img.shape[0], self.test_size[1] , img.shape[1])
  178. img_info["ratio"] = ratio
  179. img, _ = self.preproc(img, None, self.test_size)
  180. img = torch.from_numpy(img).unsqueeze(0)
  181. img = img.float()
  182. if self.device == "gpu":
  183. img = img.cuda()
  184. # if self.fp16:
  185. # img = img.half() # to FP16
  186. with torch.no_grad():
  187. t0 = time.time()
  188. outputs = self.model(img)
  189. if self.decoder is not None:
  190. outputs = self.decoder(outputs, dtype=outputs.type())
  191. outputs = postprocess(
  192. outputs, self.num_classes, self.confthre,
  193. self.nmsthre
  194. )
  195. # logger.info("Infer time: {:.4f}s".format(time.time() - t0))
  196. # print(img.shape)
  197. _, _, tw, th = img.shape
  198. img_info['tw'] = tw
  199. img_info['th'] = th
  200. return outputs, img_info
  201. def visual(self, output, img_info, cls_conf=0.35):
  202. ratio = img_info["ratio"]
  203. img = img_info["raw_img"]
  204. if output is None:
  205. return img
  206. output = output.cpu()
  207. bboxes = output[:, 0:4]
  208. # preprocessing: resize
  209. bboxes /= ratio
  210. cls = output[:, 6]
  211. scores = output[:, 4] * output[:, 5]
  212. vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
  213. return vis_res
  214. def image_demo(predictor, path):
  215. folders = [f for f in listdir(path)]
  216. # folders = folders[3:]
  217. for folder in folders:
  218. print(folder)
  219. images_folder = join(join(path, folder), 'img1')
  220. images = [f for f in listdir(images_folder) if isfile(join(images_folder, f))]
  221. images = [a for a in images if a.endswith('.jpg')]
  222. images.sort()
  223. for i, image_name in enumerate(images):
  224. if i % 300 == 0:
  225. print(folder, i)
  226. outputs, img_info = predictor.inference(join(images_folder, image_name))
  227. ratio = img_info["ratio"]
  228. # print(ratio)
  229. img = img_info["raw_img"]
  230. output = outputs[0]
  231. if output is None:
  232. continue
  233. output = output.cpu()
  234. bboxes = output[:, 0:4]
  235. # preprocessing: resize
  236. bboxes /= ratio
  237. cls = output[:, 6]
  238. scores = output[:, 4] * output[:, 5]
  239. # print('cls',cls)
  240. labels0 = get_labels(bboxes, cls, scores, img_info["width"], img_info["height"])
  241. # out_path = join(images_folder, 'weak_' + imm + '.npy')
  242. # print(imm)
  243. np.savetxt(join(images_folder, image_name + '_weak_' + model_name + '.txt'), labels0, delimiter=' ')
  244. def main(exp, ckpt_file):
  245. model = exp.get_model()
  246. if use_cuda:
  247. model = model.cuda()
  248. device = 'gpu'
  249. else:
  250. device = 'cpu'
  251. model.eval()
  252. logger.info("loading checkpoint")
  253. ckpt = torch.load(ckpt_file, map_location="cpu")
  254. # load the model state dict
  255. model.load_state_dict(ckpt["model"])
  256. logger.info("loaded checkpoint done.")
  257. if fuse:
  258. logger.info("\tFusing model...")
  259. model = fuse_model(model)
  260. trt_file = None
  261. decoder = None
  262. predictor = Predictor(
  263. model, exp, COCO_CLASSES, trt_file, decoder,
  264. device, False, False,
  265. )
  266. current_time = time.localtime()
  267. image_demo(predictor, root_dir)
  268. model_name = 'yolox-x'
  269. # cuda = torch.device('cuda:1')
  270. if __name__ == "__main__":
  271. # print(COCO_CLASSES)
  272. # if use_cuda:
  273. # torch.cuda.set_device(1)
  274. # with torch.cuda.device(1):
  275. # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
  276. ckpt_file = '/home/abdollahpour.ce.sharif/yolox_x.pth'
  277. exp = get_exp(None, model_name)
  278. main(exp, ckpt_file)