You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fragment_splitter.py 8.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import csv
  2. import glob
  3. import os
  4. import random
  5. from tqdm import tqdm
  6. from config import Config
  7. class CustomFragmentLoader:
  8. def __init__(self, datasets_folder_name):
  9. self._datasets_folder_name = datasets_folder_name
  10. self._database_slide_dict = {}
  11. self._load_csv_files_to_dict()
  12. def _load_csv_files_to_dict(self):
  13. databases_directory = "../../../database_crawlers/"
  14. list_dir = [os.path.join(databases_directory, o, "patches") for o in self._datasets_folder_name
  15. if os.path.isdir(os.path.join(databases_directory, o, "patches"))]
  16. for db_dir in list_dir:
  17. csv_dir = os.path.join(db_dir, "patch_labels.csv")
  18. with open(csv_dir, "r") as csv_file:
  19. csv_reader = csv.reader(csv_file)
  20. header = next(csv_reader, None)
  21. for row in csv_reader:
  22. if row:
  23. database_id = row[0]
  24. image_id = row[1]
  25. slide_frag_folder_name = [o for o in os.listdir(db_dir) if image_id.startswith(o)]
  26. if slide_frag_folder_name:
  27. slide_frag_folder_name = slide_frag_folder_name[0]
  28. else:
  29. continue
  30. slide_path = os.path.join(db_dir, slide_frag_folder_name)
  31. image_paths = glob.glob(os.path.join(slide_path, "*.jpeg"))
  32. if image_paths:
  33. d = self._database_slide_dict.get(database_id, {})
  34. d[image_id] = [image_paths] + [row[3], row[2]]
  35. self._database_slide_dict[database_id] = d
  36. def load_image_path_and_labels_and_split(self, test_percent=20, val_percent=10):
  37. train_images, val_images, test_images = [], [], []
  38. for database_name, slides_dict in self._database_slide_dict.items():
  39. image_paths_by_slide = [(len(v[0]), v[0], v[1], v[2]) for v in slides_dict.values()]
  40. random.shuffle(image_paths_by_slide)
  41. # image_paths_by_slide.sort()
  42. class_slides_dict = {}
  43. for item in image_paths_by_slide:
  44. class_name = None
  45. if database_name == "NationalCancerInstitute":
  46. normal_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[0])
  47. tumor_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[1])
  48. stormal_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[2])
  49. if stormal_percent == 0:
  50. if tumor_percent == 100:
  51. class_name = "MALIGNANT"
  52. elif normal_percent == 100:
  53. class_name = "BENIGN"
  54. else:
  55. class_name = str(tumor_percent)
  56. elif database_name == "BioAtlasThyroidSlideProvider":
  57. if "papillary" in item[3].lower():
  58. class_name = "MALIGNANT"
  59. elif "normal" in item[3].lower():
  60. class_name = "BENIGN"
  61. class_name = class_name if class_name else item[2]
  62. if class_name in Config.class_names:
  63. class_slides_dict[class_name] = class_slides_dict.get(class_name, []) + [
  64. (item[0], item[1], class_name)]
  65. # split test val train because they must not share same slide id fragment
  66. for thyroid_class, slide_frags in class_slides_dict.items():
  67. dataset_train_images, dataset_val_images, dataset_test_images = [], [], []
  68. total_counts = sum([item[0] for item in slide_frags])
  69. test_counts = total_counts * test_percent // 100
  70. val_counts = total_counts * val_percent // 100
  71. train_counts = total_counts - test_counts - val_counts
  72. for i, slide_frags_item in enumerate(slide_frags):
  73. if len(dataset_train_images) + slide_frags_item[0] <= train_counts:
  74. dataset_train_images += slide_frags_item[1]
  75. elif len(dataset_val_images) + slide_frags_item[0] <= val_counts:
  76. dataset_val_images += slide_frags_item[1]
  77. else:
  78. dataset_test_images += slide_frags_item[1]
  79. train_images += [(i, thyroid_class) for i in dataset_train_images]
  80. val_images += [(i, thyroid_class) for i in dataset_val_images]
  81. test_images += [(i, thyroid_class) for i in dataset_test_images]
  82. return train_images, val_images, test_images
  83. def national_cancer_image_and_labels_splitter_per_slide(self, test_percent=20, val_percent=10):
  84. train_images, val_images, test_images = [], [], []
  85. for database_name, slides_dict in self._database_slide_dict.items():
  86. print(database_name)
  87. image_paths_by_slide = [(len(v[0]), v[0], v[1], v[2], k) for k, v in slides_dict.items()]
  88. random.shuffle(image_paths_by_slide)
  89. # image_paths_by_slide.sort()
  90. class_slides_dict = {}
  91. for item in tqdm(image_paths_by_slide):
  92. class_name = None
  93. normal_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[0])
  94. tumor_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[1])
  95. stormal_percent = int(item[2].strip(r"(|)|\'").split("\', \'")[2])
  96. if stormal_percent == 0:
  97. if tumor_percent == 100:
  98. class_name = 100
  99. elif normal_percent == 100:
  100. class_name = 0
  101. else:
  102. class_name = tumor_percent
  103. class_name = class_name if class_name is not None else item[2]
  104. if class_name in Config.class_names:
  105. class_slides_dict[class_name] = class_slides_dict.get(class_name, []) + [
  106. (item[0], item[1], class_name, item[4])]
  107. # split test val train because they must not share same slide id fragment
  108. for thyroid_class, slide_frags in class_slides_dict.items():
  109. dataset_train_images, dataset_val_images, dataset_test_images = [], [], []
  110. total_counts = sum([item[0] for item in slide_frags])
  111. test_counts = total_counts * test_percent // 100
  112. val_counts = total_counts * val_percent // 100
  113. train_counts = total_counts - test_counts - val_counts
  114. for i, slide_frags_item in enumerate(slide_frags):
  115. items_paths = [(item_path, slide_frags_item[3]) for item_path in slide_frags_item[1]]
  116. if len(dataset_train_images) + slide_frags_item[0] <= train_counts:
  117. dataset_train_images += items_paths
  118. elif len(dataset_val_images) + slide_frags_item[0] <= val_counts:
  119. dataset_val_images += items_paths
  120. else:
  121. dataset_test_images += items_paths
  122. train_images += [(i, (thyroid_class, j)) for i, j in dataset_train_images]
  123. val_images += [(i, (thyroid_class, j)) for i, j in dataset_val_images]
  124. test_images += [(i, (thyroid_class, j)) for i, j in dataset_test_images]
  125. return train_images, val_images, test_images
  126. if __name__ == '__main__':
  127. # datasets_folder = ["national_cancer_institute"]
  128. datasets_folder = ["papsociaty"]
  129. # datasets_folder = ["stanford_tissue_microarray"]
  130. # datasets_folder = ["bio_atlas_at_jake_gittlen_laboratories"]
  131. train, val, test = CustomFragmentLoader(datasets_folder).load_image_path_and_labels_and_split(
  132. val_percent=Config.val_percent,
  133. test_percent=Config.test_percent)
  134. benign_train = [i for i in train if i[1] == "BENIGN"]
  135. mal_train = [i for i in train if i[1] == "MALIGNANT"]
  136. print(f"train: {len(train)}={len(benign_train)}+{len(mal_train)}")
  137. benign_val = [i for i in val if i[1] == "BENIGN"]
  138. mal_val = [i for i in val if i[1] == "MALIGNANT"]
  139. print(f"val: {len(val)}={len(benign_val)}+{len(mal_val)}")
  140. benign_test = [i for i in test if i[1] == "BENIGN"]
  141. mal_test = [i for i in test if i[1] == "MALIGNANT"]
  142. print(f"test: {len(test)}={len(benign_test)}+{len(mal_test)}")
  143. print(set(train) & set(test))
  144. print(set(train) & set(val))
  145. print(set(test) & set(val))
  146. print(len(set(val) & set(val)))