You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

COCO2YOLO.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import json
  2. import os
  3. import argparse
  4. parser = argparse.ArgumentParser(description='Test yolo data.')
  5. parser.add_argument('-j', help='JSON file', dest='json', required=True)
  6. parser.add_argument('-o', help='path to output folder', dest='out',required=True)
  7. args = parser.parse_args()
  8. json_file = args.json
  9. output = args.out
  10. class COCO2YOLO:
  11. def __init__(self):
  12. self._check_file_and_dir(json_file, output)
  13. self.labels = json.load(open(json_file, 'r', encoding='utf-8'))
  14. self.coco_id_name_map = self._categories()
  15. self.coco_name_list = list(self.coco_id_name_map.values())
  16. print("total images", len(self.labels['images']))
  17. print("total categories", len(self.labels['categories']))
  18. print("total labels", len(self.labels['annotations']))
  19. def _check_file_and_dir(self, file_path, dir_path):
  20. if not os.path.exists(file_path):
  21. raise ValueError("file not found")
  22. if not os.path.exists(dir_path):
  23. os.makedirs(dir_path)
  24. def _categories(self):
  25. categories = {}
  26. for cls in self.labels['categories']:
  27. categories[cls['id']] = cls['name']
  28. return categories
  29. def _load_images_info(self):
  30. images_info = {}
  31. for image in self.labels['images']:
  32. id = image['id']
  33. file_name = image['file_name']
  34. w = image['width']
  35. h = image['height']
  36. images_info[id] = (file_name, w, h)
  37. return images_info
  38. def _bbox_2_yolo(self, bbox, img_w, img_h):
  39. x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
  40. centerx = bbox[0] + w / 2
  41. centery = bbox[1] + h / 2
  42. dw = 1 / img_w
  43. dh = 1 / img_h
  44. centerx *= dw
  45. w *= dw
  46. centery *= dh
  47. h *= dh
  48. return centerx, centery, w, h
  49. def _convert_anno(self, images_info):
  50. anno_dict = dict()
  51. for anno in self.labels['annotations']:
  52. bbox = anno['bbox']
  53. image_id = anno['image_id']
  54. category_id = anno['category_id']
  55. image_info = images_info.get(image_id)
  56. image_name = image_info[0]
  57. img_w = image_info[1]
  58. img_h = image_info[2]
  59. yolo_box = self._bbox_2_yolo(bbox, img_w, img_h)
  60. anno_info = (image_name, category_id, yolo_box)
  61. anno_infos = anno_dict.get(image_id)
  62. if not anno_infos:
  63. anno_dict[image_id] = [anno_info]
  64. else:
  65. anno_infos.append(anno_info)
  66. anno_dict[image_id] = anno_infos
  67. return anno_dict
  68. def save_classes(self):
  69. sorted_classes = list(map(lambda x: x['name'], sorted(self.labels['categories'], key=lambda x: x['id'])))
  70. print('coco names', sorted_classes)
  71. with open('coco.names', 'w', encoding='utf-8') as f:
  72. for cls in sorted_classes:
  73. f.write(cls + '\n')
  74. f.close()
  75. def coco2yolo(self):
  76. print("loading image info...")
  77. images_info = self._load_images_info()
  78. print("loading done, total images", len(images_info))
  79. print("start converting...")
  80. anno_dict = self._convert_anno(images_info)
  81. print("converting done, total labels", len(anno_dict))
  82. print("saving txt file...")
  83. self._save_txt(anno_dict)
  84. print("saving done")
  85. def _save_txt(self, anno_dict):
  86. for k, v in anno_dict.items():
  87. file_name = v[0][0].split(".")[0] + ".txt"
  88. with open(os.path.join(output, file_name), 'w', encoding='utf-8') as f:
  89. print(k, v)
  90. for obj in v:
  91. cat_name = self.coco_id_name_map.get(obj[1])
  92. category_id = self.coco_name_list.index(cat_name)
  93. box = ['{:.6f}'.format(x) for x in obj[2]]
  94. box = ' '.join(box)
  95. line = str(category_id) + ' ' + box
  96. f.write(line + '\n')
  97. if __name__ == '__main__':
  98. c2y = COCO2YOLO()
  99. c2y.coco2yolo()