You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

image_patcher.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import csv
  2. import json
  3. import os
  4. import os.path as os_path
  5. import random
  6. import re
  7. from math import ceil
  8. from os import listdir
  9. from os.path import isfile, join
  10. import cv2
  11. import tifffile
  12. import zarr as ZarrObject
  13. from tqdm import tqdm
  14. from config import Config
  15. from database_crawlers.web_stain_sample import ThyroidCancerLevel, WebStainImage
  16. from utils import show_and_wait
  17. class ThyroidFragmentFilters:
  18. @staticmethod
  19. def func_laplacian_threshold(threshold=Config.laplacian_threshold):
  20. def wrapper(image_nd_array):
  21. res = ThyroidFragmentFilters._empty_frag_with_laplacian_threshold(image_nd_array, threshold)
  22. return res
  23. return wrapper
  24. @staticmethod
  25. def _empty_frag_with_laplacian_threshold(image_nd_array, threshold=Config.laplacian_threshold,
  26. return_variance=False):
  27. gray = cv2.cvtColor(image_nd_array, cv2.COLOR_BGR2GRAY)
  28. gray = cv2.GaussianBlur(gray, (3, 3), 0)
  29. laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=3, )
  30. std = cv2.meanStdDev(laplacian)[1][0][0]
  31. variance = std ** 2
  32. if return_variance:
  33. return variance >= threshold, variance
  34. return variance >= threshold
  35. class ImageAndSlidePatcher:
  36. @classmethod
  37. def _check_magnification_from_description(cls, tiff_address):
  38. try:
  39. tif_file_obj = tifffile.TiffFile(tiff_address)
  40. image_description = tif_file_obj.pages.keyframe.tags["ImageDescription"].value
  41. app_mag = int(re.findall("(AppMag = [0-9]+)", image_description)[0].split(" = ")[-1])
  42. return app_mag
  43. except Exception as e:
  44. return None
  45. @classmethod
  46. def _zarr_loader(cls, tiff_address, key=0):
  47. image_zarr = tifffile.imread(tiff_address, aszarr=True, key=key, )
  48. zarr = ZarrObject.open(image_zarr, mode='r')
  49. return zarr
  50. @classmethod
  51. def _jpeg_loader(cls, jpeg_address):
  52. im = cv2.imread(jpeg_address)
  53. return im
  54. @classmethod
  55. def _json_key_loader(cls, json_file_address, key=None):
  56. with open(json_file_address, 'rb') as file:
  57. json_dict = json.loads(file.read())
  58. if key:
  59. return json_dict[key]
  60. return json_dict
  61. @classmethod
  62. def _get_extension_from_path(cls, file_path):
  63. return os_path.splitext(file_path)[-1]
  64. @classmethod
  65. def _get_file_name_from_path(cls, file_path):
  66. return ".".join(os_path.split(file_path)[-1].split(".")[:-1])
  67. @classmethod
  68. def _get_number_of_initial_frags(cls, zarr_object, frag_size=512, frag_overlap=0.1):
  69. zarr_shape = zarr_object.shape
  70. step_size = int(frag_size * (1 - frag_overlap))
  71. overlap_size = frag_size - step_size
  72. w_range = list(range(0, ceil((zarr_shape[0] - overlap_size) / step_size) * step_size, step_size))
  73. h_range = list(range(0, ceil((zarr_shape[1] - overlap_size) / step_size) * step_size, step_size))
  74. return len(w_range) * len(h_range)
  75. @classmethod
  76. def _generate_raw_fragments_from_image_array_or_zarr(cls, image_object, frag_size=512, frag_overlap=0.1,
  77. shuffle=True):
  78. def frag_picker(w_pos, h_pos):
  79. end_w, end_h = min(zarr_shape[0], w_pos + frag_size), min(zarr_shape[1], h_pos + frag_size)
  80. start_w, start_h = end_w - frag_size, end_h - frag_size
  81. return image_object[start_w:end_w, start_h: end_h], (start_w, start_h)
  82. if image_object is None:
  83. return None
  84. zarr_shape = image_object.shape
  85. step_size = int(frag_size * (1 - frag_overlap))
  86. overlap_size = frag_size - step_size
  87. w_range = list(range(0, ceil((zarr_shape[0] - overlap_size) / step_size) * step_size, step_size))
  88. h_range = list(range(0, ceil((zarr_shape[1] - overlap_size) / step_size) * step_size, step_size))
  89. if shuffle:
  90. pos_list = [None] * len(w_range) * len(h_range)
  91. index = 0
  92. for w in w_range:
  93. for h in h_range:
  94. pos_list[index] = (w, h)
  95. index += 1
  96. random.shuffle(pos_list)
  97. for w, h in pos_list:
  98. yield frag_picker(w, h)
  99. else:
  100. for w in w_range:
  101. for h in h_range:
  102. yield frag_picker(w, h)
  103. @classmethod
  104. def _filter_frag_from_generator(cls, frag_generator, filter_func_list, return_all_with_condition=False,
  105. all_frag_count=None, output_file=None):
  106. for next_test_item, frag_pos in tqdm(frag_generator, total=all_frag_count, file=output_file,
  107. postfix="Filtering", position=0):
  108. condition = True
  109. for function in filter_func_list:
  110. condition &= function(next_test_item)
  111. if return_all_with_condition:
  112. yield next_test_item, frag_pos, condition
  113. elif condition:
  114. # show_and_wait(frag)
  115. yield next_test_item, frag_pos
  116. @classmethod
  117. def _get_json_and_image_address_of_directory(cls, directory_path, ignore_json=False):
  118. image_formats = [".jpeg", ".tiff", ".jpg"]
  119. json_format = ".json"
  120. files = [f for f in listdir(directory_path) if isfile(join(directory_path, f))]
  121. files.sort()
  122. pairs = {}
  123. for file_path in files:
  124. file_path = join(directory_path, file_path)
  125. file_name = cls._get_file_name_from_path(file_path)
  126. pairs[file_name] = pairs.get(file_name, [None, None])
  127. if cls._get_extension_from_path(file_path) in image_formats:
  128. pairs[file_name][1] = file_path
  129. elif cls._get_extension_from_path(file_path) == json_format:
  130. pairs[file_name][0] = file_path
  131. if ignore_json:
  132. return [value for key, value in pairs.values() if value is not None]
  133. return [(key, value) for key, value in pairs.values() if key is not None and value is not None]
  134. @staticmethod
  135. def create_patch_dir_and_initialize_csv(database_path):
  136. data_dir = os.path.join(database_path, "data")
  137. patch_dir = os.path.join(database_path, "patches")
  138. if not os.path.isdir(patch_dir):
  139. os.mkdir(patch_dir)
  140. label_csv_path = os.path.join(patch_dir, "patch_labels.csv")
  141. csv_file = open(label_csv_path, "a+")
  142. csv_writer = csv.writer(csv_file)
  143. csv_file.seek(0)
  144. if len(csv_file.read(100)) <= 0:
  145. csv_writer.writerow(WebStainImage.sorted_json_keys())
  146. return data_dir, patch_dir, csv_writer, csv_file
  147. @classmethod
  148. def save_image_patches_and_update_csv(cls, thyroid_type, thyroid_desired_classes, csv_writer, web_details,
  149. image_path, slide_patch_dir, slide_id):
  150. csv_writer.writerow(list(web_details.values()))
  151. if cls._get_extension_from_path(image_path) in [".tiff", ".tif", ".svs"]:
  152. zarr_object = cls._zarr_loader(image_path)
  153. generator = cls._generate_raw_fragments_from_image_array_or_zarr(zarr_object)
  154. total_counts = cls._get_number_of_initial_frags(zarr_object=zarr_object)
  155. else:
  156. jpeg_image = cls._jpeg_loader(image_path)
  157. jpeg_image = cls.ask_image_scale_and_rescale(jpeg_image)
  158. generator = cls._generate_raw_fragments_from_image_array_or_zarr(jpeg_image)
  159. total_counts = cls._get_number_of_initial_frags(zarr_object=jpeg_image)
  160. if generator is None:
  161. return
  162. if not os.path.isdir(slide_patch_dir):
  163. os.mkdir(slide_patch_dir)
  164. filters = [ThyroidFragmentFilters.func_laplacian_threshold(Config.laplacian_threshold)]
  165. fragment_id = 0
  166. slide_progress_file_path = os.path.join(slide_patch_dir, "progress.txt")
  167. with open(slide_progress_file_path, "w") as file:
  168. for fragment, frag_pos in cls._filter_frag_from_generator(generator, filters, all_frag_count=total_counts,
  169. output_file=file):
  170. fragment_file_path = os.path.join(slide_patch_dir, f"{slide_id}-{fragment_id}.jpeg")
  171. cv2.imwrite(fragment_file_path, fragment)
  172. fragment_id += 1
  173. return fragment_id, total_counts
  174. @classmethod
  175. def save_patches_in_folders(cls, database_directory, dataset_dir=None):
  176. thyroid_desired_classes = [ThyroidCancerLevel.MALIGNANT, ThyroidCancerLevel.BENIGN]
  177. datasets_dirs = os.listdir(database_directory) if dataset_dir is None else [dataset_dir]
  178. list_dir = [os.path.join(database_directory, o) for o in datasets_dirs
  179. if os.path.isdir(os.path.join(database_directory, o, "data"))]
  180. for database_path in list_dir:
  181. print("database path: ", database_path)
  182. data_dir, patch_dir, csv_writer, csv_file = cls.create_patch_dir_and_initialize_csv(database_path)
  183. for json_path, image_path in cls._get_json_and_image_address_of_directory(data_dir):
  184. print("image path: ", image_path)
  185. file_name = cls._get_file_name_from_path(image_path)
  186. slide_id = str(hash(file_name))
  187. slide_patch_dir = os.path.join(patch_dir, slide_id)
  188. if os.path.isdir(slide_patch_dir):
  189. """
  190. it has already been patched
  191. """
  192. continue
  193. web_details = cls._json_key_loader(json_path)
  194. web_details["image_id"] = slide_id
  195. web_label = web_details["image_web_label"]
  196. thyroid_type = ThyroidCancerLevel.get_thyroid_level_from_diagnosis_label(web_label)
  197. web_details["image_class_label"] = thyroid_type.value[1]
  198. cls.save_image_patches_and_update_csv(thyroid_type, thyroid_desired_classes, csv_writer, web_details,
  199. image_path, slide_patch_dir, slide_id)
  200. csv_file.close()
  201. @classmethod
  202. def save_papsociaty_patch(cls, database_path):
  203. thyroid_desired_classes = [ThyroidCancerLevel.MALIGNANT, ThyroidCancerLevel.BENIGN]
  204. print("database path: ", database_path)
  205. for folder in Config.class_names:
  206. group_path = os.path.join(database_path, "data", folder)
  207. data_dir, patch_dir, csv_writer, csv_file = cls.create_patch_dir_and_initialize_csv(database_path)
  208. for image_path in cls._get_json_and_image_address_of_directory(group_path, ignore_json=True):
  209. print("image path: ", image_path)
  210. file_name = cls._get_file_name_from_path(image_path)
  211. slide_id = str(hash(file_name))
  212. slide_patch_dir = os.path.join(patch_dir, slide_id)
  213. if os.path.isdir(slide_patch_dir):
  214. """
  215. it has already been patched
  216. """
  217. continue
  218. web_label = folder + "-" + file_name
  219. thyroid_type = ThyroidCancerLevel.get_thyroid_level_from_diagnosis_label(web_label)
  220. web_details = {"database_name": "PapSociety",
  221. "image_id": slide_id,
  222. "image_web_label": web_label,
  223. "image_class_label": thyroid_type.value[1],
  224. "report": None,
  225. "stain_type": "UNKNOWN",
  226. "is_wsi": False}
  227. cls.save_image_patches_and_update_csv(thyroid_type, thyroid_desired_classes, csv_writer, web_details,
  228. image_path, slide_patch_dir, slide_id)
  229. csv_file.close()
  230. @classmethod
  231. def ask_image_scale_and_rescale(cls, image):
  232. # small: S, Medium: M, Large:L
  233. show_and_wait(image)
  234. res = input("how much plus pointer fill a cell(float, i:ignore, else repeat): ")
  235. try:
  236. if res == "i":
  237. return None
  238. elif re.match("[0-9]+(.[0-9]*)?", res):
  239. scale = 1 / float(res)
  240. return cv2.resize(image, (0, 0), fx=scale, fy=scale)
  241. else:
  242. return cls.ask_image_scale_and_rescale(image)
  243. except Exception as e:
  244. print(e)
  245. return cls.ask_image_scale_and_rescale(image)
  246. if __name__ == '__main__':
  247. random.seed(1)
  248. database_directory = "./"
  249. # ImageAndSlidePatcher.save_patches_in_folders(database_directory, dataset_dir=["stanford_tissue_microarray"])
  250. # ImageAndSlidePatcher.save_papsociaty_patch(os.path.join(database_directory, "papsociaty"))