You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

COCO2YOLO.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import json
  2. import os
  3. import argparse
  # Command-line interface: the COCO annotation JSON (-j) and the output
  # folder for the generated label files (-o) are both required.
  # Example (val split):
  #   -j /home/user/datasets/coco/annotations/instances_val2017.json
  #   -o /home/user/datasets/coco/labels/val2017/
  parser = argparse.ArgumentParser(description='Test yolo data.')
  parser.add_argument('-j', help='JSON file', dest='json', required=True)
  parser.add_argument('-o', help='path to output folder', dest='out', required=True)
  args = parser.parse_args()

  # Module-level settings that the COCO2YOLO class reads.
  json_file = args.json
  output = args.out
  14. """
  15. Convert a COCO instance dataset to YOLO format.
  16. """
  class COCO2YOLO:
      """Convert a COCO instance-annotation JSON file into YOLO label files.

      One ``.txt`` file is written per annotated image. Each line has the form
      ``<class_index> <center_x> <center_y> <width> <height>``, with the box
      values divided by the image width/height.

      Class indices are the positions of the category ids in ascending id
      order, which is the same order ``save_classes`` writes to ``coco.names``.
      """

      def __init__(self, json_path=None, out_dir=None):
          """Load the annotation file and prepare the category lookups.

          Args:
              json_path: Path to the COCO JSON file. When omitted, the
                  module-level ``json_file`` value is used.
              out_dir: Folder for the label files; created if missing. When
                  omitted, the module-level ``output`` value is used.

          Raises:
              ValueError: If the annotation file does not exist.
          """
          # Fall back to the module-level CLI values so COCO2YOLO() keeps working.
          if json_path is None:
              json_path = json_file
          if out_dir is None:
              out_dir = output
          self.json_file = json_path
          self.output_dir = out_dir

          self._check_file_and_dir(self.json_file, self.output_dir)
          # Use a context manager so the file handle is closed deterministically.
          with open(self.json_file, 'r', encoding='utf-8') as f:
              self.labels = json.load(f)

          self.coco_id_name_map = self._categories()
          self.coco_name_list = list(self.coco_id_name_map.values())
          # category id -> zero-based class index (position in ascending id order).
          self._class_index = {
              cat_id: idx for idx, cat_id in enumerate(self.coco_id_name_map)
          }

          print("total images", len(self.labels['images']))
          print("total categories", len(self.labels['categories']))
          print("total labels", len(self.labels['annotations']))

      def _check_file_and_dir(self, file_path, dir_path):
          """Ensure ``file_path`` exists and create ``dir_path`` if needed.

          Raises:
              ValueError: If ``file_path`` does not exist.
          """
          if not os.path.exists(file_path):
              raise ValueError("file not found")
          os.makedirs(dir_path, exist_ok=True)

      def _categories(self):
          """Return a ``{category_id: name}`` dict ordered by ascending id.

          Sorting here keeps the label class indices aligned with the order
          in which ``save_classes`` writes the names.
          """
          ordered = sorted(self.labels['categories'], key=lambda c: c['id'])
          return {cat['id']: cat['name'] for cat in ordered}

      def _load_images_info(self):
          """Return ``{image_id: (file_name, width, height)}`` for all images."""
          return {
              image['id']: (image['file_name'], image['width'], image['height'])
              for image in self.labels['images']
          }

      def _bbox_2_yolo(self, bbox, img_w, img_h):
          """Convert a ``[left, top, width, height]`` box to normalized YOLO form.

          Args:
              bbox: Sequence whose first four items are the top-left x, top-left
                  y, box width and box height.
              img_w: Image width used to normalize x values.
              img_h: Image height used to normalize y values.

          Returns:
              Tuple ``(center_x, center_y, width, height)``, each divided by the
              corresponding image dimension.
          """
          left, top, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
          center_x = (left + w / 2) / img_w
          center_y = (top + h / 2) / img_h
          return center_x, center_y, w / img_w, h / img_h

      def _convert_anno(self, images_info):
          """Group converted annotations by image id.

          Annotations whose ``image_id`` is missing from ``images_info`` are
          skipped (and counted in a printed notice) instead of crashing.

          Args:
              images_info: Mapping produced by ``_load_images_info``.

          Returns:
              Dict ``{image_id: [(file_name, category_id, yolo_box), ...]}``.
          """
          anno_dict = {}
          skipped = 0
          for anno in self.labels['annotations']:
              image_id = anno['image_id']
              image_info = images_info.get(image_id)
              if image_info is None:
                  skipped += 1
                  continue
              image_name, img_w, img_h = image_info
              yolo_box = self._bbox_2_yolo(anno['bbox'], img_w, img_h)
              anno_dict.setdefault(image_id, []).append(
                  (image_name, anno['category_id'], yolo_box)
              )
          if skipped:
              print("skipped annotations with unknown image_id:", skipped)
          return anno_dict

      def save_classes(self):
          """Write the category names, ordered by id, to ``coco.names``.

          The file is created in the current working directory, one name per
          line; line N corresponds to class index N in the label files.
          """
          sorted_classes = list(self.coco_name_list)
          print('coco names', sorted_classes)
          with open('coco.names', 'w', encoding='utf-8') as f:
              f.writelines(name + '\n' for name in sorted_classes)

      def coco2yolo(self):
          """Run the full conversion: load image info, convert, write files."""
          print("loading image info...")
          images_info = self._load_images_info()
          print("loading done, total images", len(images_info))
          print("start converting...")
          anno_dict = self._convert_anno(images_info)
          print("converting done, total labels", len(anno_dict))
          print("saving txt file...")
          self._save_txt(anno_dict)
          print("saving done")

      def _save_txt(self, anno_dict):
          """Write one YOLO label file per image into the output folder.

          The label file name is the image file name with its final extension
          replaced by ``.txt`` (so ``a.b.jpg`` becomes ``a.b.txt``).

          Args:
              anno_dict: Mapping produced by ``_convert_anno``.

          Raises:
              ValueError: If an annotation uses a category id that is not
                  listed in the dataset's categories.
          """
          for annos in anno_dict.values():
              # Strip only the last extension; splitting on the first dot would
              # truncate names that contain additional dots.
              stem = os.path.splitext(annos[0][0])[0]
              txt_path = os.path.join(self.output_dir, stem + ".txt")
              parent = os.path.dirname(txt_path)
              if parent:
                  # file_name may carry a sub-folder; make sure it exists.
                  os.makedirs(parent, exist_ok=True)

              lines = []
              for _, category_id, box in annos:
                  class_idx = self._class_index.get(category_id)
                  if class_idx is None:
                      raise ValueError(
                          "unknown category_id: {}".format(category_id)
                      )
                  coords = ' '.join('{:.6f}'.format(v) for v in box)
                  lines.append('{} {}\n'.format(class_idx, coords))

              with open(txt_path, 'w', encoding='utf-8') as f:
                  f.writelines(lines)
  if __name__ == '__main__':
      # Script entry point: build the converter from the parsed command-line
      # settings and write the YOLO label files.
      c2y = COCO2YOLO()
      c2y.coco2yolo()