You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

COCO2YOLO.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import json
  2. import os
  3. import argparse
  4. parser = argparse.ArgumentParser(description='Test yolo data.')
  5. parser.add_argument('-j', help='JSON file', dest='json', required=True)
  6. parser.add_argument('-o', help='path to output folder', dest='out',required=True)
  7. args = parser.parse_args()
  8. json_file = args.json
  9. output = args.out
  10. class COCO2YOLO:
  11. def __init__(self):
  12. self._check_file_and_dir(json_file, output)
  13. self.labels = json.load(open(json_file, 'r', encoding='utf-8'))
  14. self.coco_id_name_map = self._categories()
  15. self.coco_name_list = list(self.coco_id_name_map.values())
  16. print("total images", len(self.labels['images']))
  17. print("total categories", len(self.labels['categories']))
  18. print("total labels", len(self.labels['annotations']))
  19. def _check_file_and_dir(self, file_path, dir_path):
  20. if not os.path.exists(file_path):
  21. raise ValueError("file not found")
  22. if not os.path.exists(dir_path):
  23. os.makedirs(dir_path)
  24. def _categories(self):
  25. categories = {}
  26. for cls in self.labels['categories']:
  27. categories[cls['id']] = cls['name']
  28. return categories
  29. def _load_images_info(self):
  30. images_info = {}
  31. for image in self.labels['images']:
  32. id = image['id']
  33. file_name = image['file_name']
  34. if file_name.find('\\') > -1:
  35. file_name = file_name[file_name.index('\\')+1:]
  36. w = image['width']
  37. h = image['height']
  38. images_info[id] = (file_name, w, h)
  39. return images_info
  40. def _bbox_2_yolo(self, bbox, img_w, img_h):
  41. x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
  42. centerx = bbox[0] + w / 2
  43. centery = bbox[1] + h / 2
  44. dw = 1 / img_w
  45. dh = 1 / img_h
  46. centerx *= dw
  47. w *= dw
  48. centery *= dh
  49. h *= dh
  50. return centerx, centery, w, h
  51. def _convert_anno(self, images_info):
  52. anno_dict = dict()
  53. for anno in self.labels['annotations']:
  54. bbox = anno['bbox']
  55. image_id = anno['image_id']
  56. category_id = anno['category_id']
  57. image_info = images_info.get(image_id)
  58. image_name = image_info[0]
  59. img_w = image_info[1]
  60. img_h = image_info[2]
  61. yolo_box = self._bbox_2_yolo(bbox, img_w, img_h)
  62. anno_info = (image_name, category_id, yolo_box)
  63. anno_infos = anno_dict.get(image_id)
  64. if not anno_infos:
  65. anno_dict[image_id] = [anno_info]
  66. else:
  67. anno_infos.append(anno_info)
  68. anno_dict[image_id] = anno_infos
  69. return anno_dict
  70. def save_classes(self):
  71. sorted_classes = list(map(lambda x: x['name'], sorted(self.labels['categories'], key=lambda x: x['id'])))
  72. print('coco names', sorted_classes)
  73. with open('coco.names', 'w', encoding='utf-8') as f:
  74. for cls in sorted_classes:
  75. f.write(cls + '\n')
  76. f.close()
  77. def coco2yolo(self):
  78. print("loading image info...")
  79. images_info = self._load_images_info()
  80. print("loading done, total images", len(images_info))
  81. print("start converting...")
  82. anno_dict = self._convert_anno(images_info)
  83. print("converting done, total labels", len(anno_dict))
  84. print("saving txt file...")
  85. self._save_txt(anno_dict)
  86. print("saving done")
  87. def _save_txt(self, anno_dict):
  88. for k, v in anno_dict.items():
  89. file_name = v[0][0].split(".")[0] + ".txt"
  90. with open(os.path.join(output, file_name), 'w', encoding='utf-8') as f:
  91. print(k, v)
  92. for obj in v:
  93. cat_name = self.coco_id_name_map.get(obj[1])
  94. category_id = self.coco_name_list.index(cat_name)
  95. box = ['{:.6f}'.format(x) for x in obj[2]]
  96. box = ' '.join(box)
  97. line = str(category_id) + ' ' + box
  98. f.write(line + '\n')
  99. if __name__ == '__main__':
  100. c2y = COCO2YOLO()
  101. c2y.coco2yolo()