You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evaluate_image_patcher_and_visualize.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import concurrent.futures
  2. import math
  3. import cv2
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from config import Config
  7. from database_crawlers.image_patcher.image_patcher import ImageAndSlidePatcher, ThyroidFragmentFilters
  8. from utils import check_if_generator_is_empty
  def imul(a, b):
      """Multiply ``a`` by ``b`` and round the product up to an integer.

      Used throughout this file to scale pixel sizes and offsets by a
      resolution factor; rounding up means a partial pixel still counts as one.

      Returns:
          int: ``math.ceil(a * b)``.
      """
      product = a * b
      rounded_up = math.ceil(product)
      return rounded_up
  def calculate_acc_and_sensitivity(image_path, zarr_loader_mask, zarr_loader, frag_generator, scaled_masked_image,
                                    generated_mask_scale, laplacian_threshold, slide_patch_size,
                                    save_generated_image=True):
      """Compare the Laplacian fragment filter with a ground-truth mask for one slide.

      Fragments from ``frag_generator`` go through
      ``ThyroidFragmentFilters.func_laplacian_threshold(laplacian_threshold)``.
      Each fragment's filter verdict ("condition") is compared with the mask
      region at the same position. That region counts as masked when the mean of
      its first channel exceeds ``256 * 0.3``. A downscaled copy of every
      processed fragment is pasted into ``scaled_masked_image``, and fragments
      rejected by the filter are darkened. Fragments are processed on a
      ``ThreadPoolExecutor`` with ``Config.workers`` threads.

      Args:
          image_path: slide path; used only to derive the output preview path.
          zarr_loader_mask: array-like ground-truth mask. It is sliced as
              ``[rows, cols]`` and then ``[:, :, 0]``, so it is presumably
              H x W x C — TODO confirm.
          zarr_loader: array-like slide. Its ``shape`` is read, and it is passed to
              ``_get_number_of_initial_frags`` when ``slide_patch_size`` is falsy.
          frag_generator: generator of raw fragments for this slide.
          scaled_masked_image: numpy canvas the previews are pasted into
              (mutated in place).
          generated_mask_scale: scale factor from slide pixels to canvas pixels.
          laplacian_threshold: threshold handed to the Laplacian filter.
          slide_patch_size: maximum number of fragments to process; a falsy
              value means no cap.
          save_generated_image: if True, write the canvas to
              ``<image_path without extension>_generated_mask.jpg``.

      Returns:
          dict: integer counts under "TP", "FP", "TN" and "FN".
      """
      def process_frag(args):
          """Classify one filtered fragment against the mask and paste its preview.

          ``args`` is one ``(fragment, position, condition)`` item from the filter
          generator. Returns the confusion-matrix key ("TP"/"FP"/"TN"/"FN"), or
          None when the fragment is None. Counting is left to the submitting thread.
          """
          next_test_item, frag_pos, condition = args
          # Guard before touching .shape; the old check came after .shape was used and never fired.
          if next_test_item is None:
              return None
          frag_shape = next_test_item.shape

          # Ground truth: cut the matching region from the (differently sized) mask,
          # resize it back by 1 / mask_scale and threshold its first channel.
          mask_frag_shape = [imul(frag_shape[i], mask_scale) for i in range(2)]
          mask_frag_pos = [imul(frag_pos[i], mask_scale) for i in range(2)]
          mask_item = zarr_loader_mask[mask_frag_pos[0]:mask_frag_pos[0] + mask_frag_shape[0],
                                       mask_frag_pos[1]:mask_frag_pos[1] + mask_frag_shape[1]]
          mask_item = cv2.resize(mask_item, dsize=(0, 0), fx=1 / mask_scale, fy=1 / mask_scale)
          masked = mask_item[:, :, 0].mean() > 256 * .3
          if condition:
              label = "TP" if masked else "FP"
          else:
              label = "FN" if masked else "TN"

          # cv2.resize takes dsize as (width, height) == (cols, rows). Passing
          # (rows, cols) transposed non-square fragments.
          scaled_frag = cv2.resize(next_test_item[:, :, :3],
                                   dsize=(imul(frag_shape[1], generated_mask_scale),
                                          imul(frag_shape[0], generated_mask_scale)),
                                   interpolation=cv2.INTER_CUBIC)
          if not condition:
              # Background (filter-rejected) patches are darkened in the preview.
              # uint8 rather than int8: pixel data is unsigned.
              scaled_frag = (scaled_frag * 0.3).astype(np.uint8)
          row, col = (imul(frag_pos[i], generated_mask_scale) for i in range(2))
          rows, cols = scaled_frag.shape[:2]
          try:
              scaled_masked_image[row:row + rows, col:col + cols] = scaled_frag
          except ValueError as e:
              # Best effort: a fragment near the border can overhang the canvas; numpy
              # clips the slice and the shapes stop broadcasting.
              print(e)
          return label

      mask_scale = zarr_loader_mask.shape[0] / zarr_loader.shape[0]
      filter_func_list = [ThyroidFragmentFilters.func_laplacian_threshold(laplacian_threshold)]
      background_dict = {"TP": 0, "FP": 0, "TN": 0, "FN": 0}
      total_frags = slide_patch_size if slide_patch_size else ImageAndSlidePatcher._get_number_of_initial_frags(
          zarr_loader)
      frag_filtered = ImageAndSlidePatcher._filter_frag_from_generator(frag_generator, filter_func_list,
                                                                       return_all_with_condition=True,
                                                                       all_frag_count=total_frags)

      def drain(pending):
          """Wait for ``pending`` futures, re-raise their errors and tally their labels."""
          # Tallying here, on one thread, avoids the lost updates that
          # ``dict[key] += 1`` from several worker threads can cause.
          for future in concurrent.futures.as_completed(pending):
              label = future.result()
              if label is not None:
                  background_dict[label] += 1

      with concurrent.futures.ThreadPoolExecutor(max_workers=Config.workers) as executor:
          futures = []
          patch_count = 0
          for args in frag_filtered:
              patch_count += 1
              futures.append(executor.submit(process_frag, args))
              if len(futures) >= Config.workers or patch_count == slide_patch_size:
                  drain(futures)
                  futures = []
              if patch_count == slide_patch_size:
                  break
          # Final partial batch: previously never .result()-ed, so its errors were swallowed.
          drain(futures)

      if save_generated_image:
          masked_image_path = ".".join(image_path.split(".")[:-1]) + "_generated_mask.jpg"
          cv2.imwrite(masked_image_path, scaled_masked_image)
      return background_dict
  def score_calculator(accuracy, specificity, acc_w=0.75):
      """Return the weighted objective ``accuracy * acc_w + specificity * (1 - acc_w)``.

      ``acc_w`` is the weight on ``accuracy``; the remainder goes to the second term.

      NOTE(review): ``update_and_find_best_threshold`` passes *precision* as the
      second argument, so the parameter name does not match how it is used.
      """
      other_weight = 1 - acc_w
      weighted_accuracy = accuracy * acc_w
      weighted_other = specificity * other_weight
      return weighted_accuracy + weighted_other
  def get_zarr_loaders_and_generators(image_pairs=None, generated_mask_scale=10 / 512):
      """Open every (mask, slide) pair and build per-slide evaluation state.

      Args:
          image_pairs: iterable of ``(mask_path, slide_path)`` tuples. Defaults to
              the module-level ``image_lists``, which is only defined when this file
              runs as a script. Passing it explicitly makes the function usable after
              a plain import.
          generated_mask_scale: scale from slide pixels to preview-canvas pixels.

      Returns:
          list: one ``[mask_loader, slide_loader, fragment_generator,
          preview_canvas, generated_mask_scale]`` entry per pair, in input order.
          The preview canvas is a zero float array of shape
          ``(ceil(H * scale) + 5, ceil(W * scale) + 5, 3)``.
      """
      if image_pairs is None:
          # Backward-compatible default: the script-level list defined under __main__.
          image_pairs = image_lists
      zarr_loaders_and_generators = []
      for img_mask_path, img_path in image_pairs:
          zarr_loader_mask = ImageAndSlidePatcher._zarr_loader(img_mask_path)
          zarr_loader = ImageAndSlidePatcher._zarr_loader(img_path)
          frag_generator = ImageAndSlidePatcher._generate_raw_fragments_from_image_array_or_zarr(zarr_loader,
                                                                                               shuffle=True)
          zarr_shape = zarr_loader.shape
          # NOTE(review): the +5 looks like slack for ceil() rounding when previews are
          # pasted at the far edge — confirm.
          scaled_zarr_shape = (imul(zarr_shape[0], generated_mask_scale) + 5,
                               imul(zarr_shape[1], generated_mask_scale) + 5,
                               3)
          scaled_masked_image = np.zeros(scaled_zarr_shape)
          # A list, not a tuple: callers replace the generator (index 2) in place.
          zarr_loaders_and_generators.append([
              zarr_loader_mask, zarr_loader, frag_generator, scaled_masked_image, generated_mask_scale
          ])
      return zarr_loaders_and_generators
  def update_and_find_best_threshold(initial_thresh, learn_threshold_and_log_cf_matrix_per_patch=True):
      """Tune or evaluate the Laplacian threshold over all slides in ``image_lists``.

      Each round runs ``calculate_acc_and_sensitivity`` on up to 2000 fragments
      of every slide whose fragment generator is not yet exhausted.

      Learning mode (``learn_threshold_and_log_cf_matrix_per_patch=True``):
          Runs ``Config.n_epoch_for_image_patcher`` epochs. A round only runs while
          at least ``min_active_slides`` (6) slides still have fragments;
          otherwise the epoch ends. After each round, ``score_calculator(acc,
          precision)`` is computed from that round's confusion matrix. The first
          round only sets the baseline score. On later rounds the threshold moves
          by the current step in the current direction. When the score does not
          improve, the direction flips and the step shrinks by ``decay_const``
          first. History charts are saved after every threshold update.

      Evaluation mode (``False``):
          A single epoch that keeps going until every generator is exhausted. It
          accumulates the confusion matrix overall and per slide, and saves each
          slide's preview image.

      Returns:
          The final Laplacian threshold.
      """
      initial_threshold_jump_size_const = 120
      decay_const = 0.85
      # Minimum number of slides that must still have fragments for a learning round.
      # With the six slides listed under __main__, an epoch ends as soon as any slide runs out.
      min_active_slides = 6
      threshold_jump_size = initial_threshold_jump_size_const
      threshold_jump_increase = 1  # current search direction: +1 or -1
      threshold_score = None
      laplacian_threshold = initial_thresh
      threshold_history = []
      score_history = []
      learn = learn_threshold_and_log_cf_matrix_per_patch
      n_epochs = Config.n_epoch_for_image_patcher if learn else 1
      for _ in range(n_epochs):
          print("New Epoch")
          zarr_loaders_and_generators = get_zarr_loaders_and_generators()
          whole_background_dict_per_slide = [{} for _ in zarr_loaders_and_generators]
          whole_background_dict = {}
          while any(item is not None for item in zarr_loaders_and_generators):
              none_empty_generators = [i for i, item in enumerate(zarr_loaders_and_generators) if item is not None]
              if learn:
                  # In learning mode the confusion matrix covers only the current round.
                  whole_background_dict = {}
                  if len(none_empty_generators) < min_active_slides:
                      break
              for slide_pick in none_empty_generators:
                  img_path = image_lists[slide_pick][1]
                  (zarr_loader_mask, zarr_loader, frag_generator,
                   generated_scaled_mask_image, generated_mask_scale) = zarr_loaders_and_generators[slide_pick]
                  group_dict = calculate_acc_and_sensitivity(img_path,
                                                             zarr_loader_mask,
                                                             zarr_loader,
                                                             frag_generator,
                                                             generated_scaled_mask_image,
                                                             generated_mask_scale,
                                                             laplacian_threshold,
                                                             slide_patch_size=2000,
                                                             save_generated_image=not learn)
                  # Only this slide's generator was consumed, so re-check just that one.
                  # Re-checking every slide after each one re-wrapped all generators and
                  # could None-out a slide still waiting in this loop.
                  # NOTE(review): assumes check_if_generator_is_empty returns a generator
                  # that still yields every remaining item, or a falsy value when exhausted.
                  generator = check_if_generator_is_empty(frag_generator)
                  if generator:
                      zarr_loaders_and_generators[slide_pick][2] = generator
                  else:
                      zarr_loaders_and_generators[slide_pick] = None
                  slide_dict = whole_background_dict_per_slide[slide_pick]
                  for key, value in group_dict.items():
                      whole_background_dict[key] = whole_background_dict.get(key, 0) + value
                      slide_dict[key] = slide_dict.get(key, 0) + value
              if learn:
                  eps = .000001  # avoids division by zero on empty tables
                  total_preds = sum(whole_background_dict.values()) + eps
                  acc = (whole_background_dict["TP"] + whole_background_dict["TN"]) / total_preds
                  positive_preds = whole_background_dict["TP"] + whole_background_dict["FP"] + eps
                  precision = whole_background_dict["TP"] / positive_preds
                  next_score = score_calculator(acc, precision)
                  if threshold_score is None:
                      threshold_score = next_score
                  else:
                      threshold_history.append(laplacian_threshold)
                      score_history.append(next_score)
                      if next_score <= threshold_score:
                          # No improvement: reverse the search direction and shrink the step.
                          threshold_jump_increase *= -1
                          threshold_jump_size *= decay_const
                      threshold_score = next_score
                      laplacian_threshold += threshold_jump_increase * threshold_jump_size
                      save_threshold_and_score_chart(threshold_history, score_history)
                  print(
                      f"acc:{round(acc, 3)},precision:{round(precision, 3)},score:{round(threshold_score, 3)},"
                      f"table:{whole_background_dict}," +
                      f"thresh:{laplacian_threshold},jump_size:{threshold_jump_size}")
              else:
                  print(f"table:{whole_background_dict},table_per_slide:{whole_background_dict_per_slide}," +
                        f"threshold:{laplacian_threshold},jump_size:{threshold_jump_size}")
      return laplacian_threshold
  def save_threshold_and_score_chart(threshold_history, score_history,
                                     threshold_fig_path="laplacian_threshold_history_chart.jpeg",
                                     score_fig_path="laplacian_threshold_score_history_chart.jpeg"):
      """Plot the threshold and score histories and save each as its own image.

      Args:
          threshold_history: Laplacian threshold value per recorded batch.
          score_history: objective score per recorded batch.
          threshold_fig_path: output path of the threshold chart. The default keeps
              the original file name, relative to the working directory.
          score_fig_path: output path of the score chart. The default keeps the
              original file name, relative to the working directory.
      """
      charts = (
          (threshold_history, 'Laplacian threshold', threshold_fig_path),
          (score_history, 'Objective function - Score', score_fig_path),
      )
      for values, y_label, fig_save_path in charts:
          # A dedicated figure keeps this independent of pyplot's current-figure state;
          # it is always closed, even if saving fails.
          fig, ax = plt.subplots()
          try:
              ax.plot(range(len(values)), values)
              ax.set_xlabel('Batch')
              ax.set_ylabel(y_label)
              fig.savefig(fig_save_path)
          finally:
              plt.close(fig)
  if __name__ == '__main__':
      # (ground-truth mask, slide) path pairs read by the functions above via the
      # module-level name ``image_lists``. The quoted triple next to each pair is
      # carried over from the original notes; its meaning is not stated here
      # (presumably a per-class percentage breakdown — TODO confirm).
      # Pairs that were commented out in the original list:
      #   TCGA-EL-A3ZQ-01A-01-TS1.344610D2-AB50-41C6-916E-FF0F08940BF1__2000  "('0', '100', '0')"
      #   TCGA-J8-A42S-01A-01-TSA.7B80CBEB-7B85-417E-AA0C-11C79DE40250__0     "('0', '40', '60')"
      image_lists = [
          ("./TCGA-BJ-A3F0-01A-01-TSA.728CE583-95BE-462B-AFDF-FC0B228DF3DE__3_masked.tiff",  # "('0', '100', '0')"
           "./TCGA-BJ-A3F0-01A-01-TSA.728CE583-95BE-462B-AFDF-FC0B228DF3DE__3.svs"),
          ("./TCGA-DJ-A1QG-01A-01-TSA.04c62c21-dd45-49ea-a74f-53822defe097__2000_masked.tiff",  # "('0', '100', '0')"
           "./TCGA-DJ-A1QG-01A-01-TSA.04c62c21-dd45-49ea-a74f-53822defe097__2000.svs"),
          ("./TCGA-ET-A39N-01A-01-TSA.C38FCE19-9558-4035-9F0B-AD05B9BE321D___198_masked.tiff",  # "('45', '55', '0')"
           "./TCGA-ET-A39N-01A-01-TSA.C38FCE19-9558-4035-9F0B-AD05B9BE321D___198.svs"),
          ("./TCGA-ET-A39O-01A-01-TSA.3829C900-7597-4EA9-AFC7-AA238221CE69_7000_masked.tiff",  # "('0', '90', '10')"
           "./TCGA-ET-A39O-01A-01-TSA.3829C900-7597-4EA9-AFC7-AA238221CE69_7000.svs"),
          ("./TCGA-EL-A4K7-11A-01-TS1.C08B59AA-87DF-4ABB-8B70-25FEF9893C7F__70_masked.tiff",  # "('100', '0', '0')"
           "./TCGA-EL-A4K7-11A-01-TS1.C08B59AA-87DF-4ABB-8B70-25FEF9893C7F__70.svs"),
          ("./TCGA-EL-A3TB-11A-01-TS1.6E0966C9-1552-4B30-9008-8ACF737CA8C3__2000_masked.tiff",  # "('100', '0', '0')"
           "./TCGA-EL-A3TB-11A-01-TS1.6E0966C9-1552-4B30-9008-8ACF737CA8C3__2000.svs"),
      ]

      start_threshold = 500
      # Stage 1: search for a threshold with the adaptive step procedure.
      learned_threshold = update_and_find_best_threshold(start_threshold,
                                                         learn_threshold_and_log_cf_matrix_per_patch=True)
      # Stage 2: a single evaluation pass with the learned threshold.
      update_and_find_best_threshold(learned_threshold,
                                     learn_threshold_and_log_cf_matrix_per_patch=False)

      # Output recorded from an earlier run (start 500, jump size 120, decay 0.85):
      # table:{'TP': 15018, 'FP': 412, 'TN': 66898, 'FN': 2389},
      # table_per_slide:[
      #     {'TP': 460, 'FP': 0, 'TN': 19618, 'FN': 1426},
      #     {'TP': 4624, 'FP': 126, 'TN': 14100, 'FN': 226},
      #     {'TP': 1138, 'FP': 4, 'TN': 6671, 'FN': 492},
      #     {'TP': 7615, 'FP': 92, 'TN': 20871, 'FN': 234},
      #     {'TP': 78, 'FP': 18, 'TN': 1880, 'FN': 4},
      #     {'TP': 1103, 'FP': 172, 'TN': 3758, 'FN': 7}
      # ]
      # threshold:298.86314585743395,jump_size:120