You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rsna_loader.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from typing import Dict, List, TYPE_CHECKING, Optional, Tuple
  2. from os import path, listdir
  3. from sys import stderr
  4. from functools import partial
  5. import cv2
  6. import numpy as np
  7. import pandas as pd
  8. from .content_loader import ContentLoader
  9. from ...utils.bb_generator import generate_bb_map
  10. if TYPE_CHECKING:
  11. from ...configs.rsna_configs import RSNAConfigs
  12. class RSNALoader(ContentLoader):
  def __init__(self, conf: 'RSNAConfigs', data_specification: str):
      """Build the loader: read the sample table, then the infection boxes.

      Args:
          conf: configuration object. This class reads ``input_size``,
              ``infection_map_size``, ``data_separation``,
              ``label_map_dict``, ``infections_bb_dir`` and
              ``receptive_field_radii`` from it.
          data_specification: forwarded to the parent constructor and to
              ``load_samples``.
      """
      super().__init__(conf, data_specification)
      self.conf = conf
      self.img_size = conf.input_size
      self.infection_map_size = conf.infection_map_size

      # Table with one row per sample; filled by load_samples().
      self.samples_info = None
      self.load_samples(data_specification)

      self.loaded_images = None

      # Image id (see _get_id_by_path) -> (box start points, box end points).
      # Stays None unless conf.infections_bb_dir is set.
      self._infections_bbs: Optional[
          Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]
      ] = None
      self._load_infections()
  def load_samples(self, data_specification):
      """Populate ``self.samples_info`` from a split name, a file or a directory.

      ``data_specification`` may end with ``:<filter>``; in that case only
      samples whose path contains ``/<filter>/`` are kept. The part before
      the colon is interpreted as one of:

      * ``'train'``, ``'val'`` or ``'test'``: the tab-separated file
        ``<conf.data_separation>/<name>.txt`` is read;
      * a path to an existing file: that file is read as a tab-separated
        table;
      * a path to an existing directory: every directory entry becomes a
        sample with label ``-1``; no label mapping is applied in this case.

      For the two table forms a leading ``'../'`` is removed from the
      ``sample`` and ``Interpretation_dir`` entries, labels are lower-cased
      and mapped through ``conf.label_map_dict``, and samples whose label
      has no mapping are dropped. Progress counts are printed to stdout.

      Raises:
          ValueError: if the specification matches none of the forms above.
      """
      if ':' in data_specification:
          data_specification, filter_str = data_specification.split(':')
      else:
          filter_str = None

      def passes_filter(sample_paths):
          # A plain comprehension is used instead of np.vectorize because
          # np.vectorize raises on empty input when otypes is not given.
          token = '/%s/' % filter_str
          return np.array([token in p for p in sample_paths], dtype=bool)

      def strip_parent_prefix(value):
          # Missing cells are read by pandas as NaN (a float); those, and
          # paths without the prefix, are returned unchanged.
          if isinstance(value, str) and value.startswith('../'):
              return value[len('../'):]
          return value

      def read_separation_table(table_path):
          table = pd.read_csv(table_path, sep='\t', header=0)
          table['sample'] = table['sample'].apply(strip_parent_prefix)
          # The interpretation column is optional in a data separation file.
          if 'Interpretation_dir' in table.columns:
              table['Interpretation_dir'] = \
                  table['Interpretation_dir'].apply(strip_parent_prefix)
          return table

      if data_specification in ['train', 'val', 'test']:
          self.samples_info = read_separation_table(
              '%s/%s.txt' % (self.conf.data_separation, data_specification))
      elif path.isfile(data_specification):
          self.samples_info = read_separation_table(data_specification)
      elif path.isdir(data_specification):
          samples_names = np.asarray(
              [data_specification + '/' + x for x in listdir(data_specification)])
          print('%d samples discovered in %s' % (len(samples_names), data_specification))

          if filter_str is not None:
              samples_names = samples_names[passes_filter(samples_names)]
              print('%d samples remained after filtering' % len(samples_names))

          n_samples = len(samples_names)
          # The builtin `object` replaces the np.object alias, which was
          # removed in NumPy 1.24.
          self.samples_info = pd.DataFrame({
              'label': np.full(n_samples, -1, dtype=int),
              'sample': samples_names,
              'view': np.full(n_samples, 'Unknown', dtype=object),
              'dataset': np.full(n_samples, 'unknown', dtype=object),
              'Interpretation_dir': np.full(n_samples, '', dtype=object),
          })
          return
      else:
          raise ValueError(
              "Unsupported data specification %r: expected 'train', 'val', 'test', "
              "an existing file or an existing directory" % (data_specification,))

      org_cnt = len(self.samples_info)

      # filtering by filter string!
      if filter_str is not None:
          self.samples_info = self.samples_info[
              passes_filter(self.samples_info['sample'].values)]
          print(
              '%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification),
              flush=True)
          org_cnt = len(self.samples_info)

      # label mapping/ordering! Work on a copy so the column assignment does
      # not write into a slice of the unfiltered table.
      label_mapping_dict = self.conf.label_map_dict
      mapped_labels = [label_mapping_dict.get(l.lower(), -1)
                       for l in self.samples_info['label'].values]
      self.samples_info = self.samples_info.copy()
      self.samples_info['label'] = mapped_labels

      # filtering unmapped labels
      self.samples_info = self.samples_info[self.samples_info['label'].values != -1]
      print('%d out of %d samples remained after label mapping!' %
            (len(self.samples_info), org_cnt), flush=True)

      # counting the labels!
      u_labels, labels_cnt = np.unique(self.samples_info['label'].values, return_counts=True)
      print('[%s]' % ', '.join(['class-%s: %d' % (str(u), c) for u, c in zip(u_labels, labels_cnt)]),
            flush=True)
  83. def _load_infections(self) -> None:
  84. """
  85. If a file address is given as infections_bb_dir in configs
  86. (containing imgs and their bounding boxes as a str)
  87. the information related to bounding boxes will be taken,
  88. otherwise it is expected from the data separation to have a column
  89. indicating the address of the infection mask
  90. """
  91. if self.conf.infections_bb_dir is not None:
  92. bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0)
  93. imgs_dirs = np.vectorize(self._get_id_by_path)(bbs_info['img_dir'].values)
  94. imgs_bbs = [
  95. [] if '-' not in str(im_bbs) else [
  96. tuple([np.clip(float(x), 0, 1) for x in im_bb.split('-')])
  97. for im_bb in im_bbs.split(',')
  98. ]
  99. for im_bbs in bbs_info['img_bb'].values]
  100. # reformatting to tuple of list of starts and list of ends
  101. imgs_bbs_starts = [[
  102. (r1, c1) for r1, c1, _, _ in img_bbs]
  103. for img_bbs in imgs_bbs]
  104. imgs_bbs_ends = [[
  105. (r2, c2) for _, _, r2, c2 in img_bbs]
  106. for img_bbs in imgs_bbs]
  107. imgs_bbs = [
  108. (imgs_bbs_starts[i], imgs_bbs_ends[i])
  109. for i in range(len(imgs_bbs_starts))]
  110. self._infections_bbs = dict(zip(imgs_dirs, imgs_bbs))
  def get_samples_names(self):
      """Return the ``sample`` column of ``self.samples_info`` as an array."""
      sample_column = self.samples_info['sample']
      return sample_column.values
  def get_samples_labels(self):
      """Return the ``label`` column of ``self.samples_info`` as an array."""
      label_column = self.samples_info['label']
      return label_column.values
  def drop_samples(self, drop_mask: np.ndarray) -> None:
      """Remove the samples flagged in ``drop_mask`` from ``self.samples_info``.

      Args:
          drop_mask: boolean array with one entry per row of
              ``self.samples_info``; rows whose entry is True are dropped
              and the remaining rows are kept.
      """
      keep_mask = np.logical_not(drop_mask)
      self.samples_info = self.samples_info[keep_mask]
  def get_placeholder_name_to_fill_function_dict(self):
      """Return a dict mapping placeholder names to the callables that fill them.

      ``'x'`` and ``'y'`` always map to the image and label batch getters.
      ``'infection'``, ``'interpretations'`` and ``'has_bb'`` are served
      from the bounding-box file when ``self._infections_bbs`` is set and
      from the per-sample mask files otherwise. For every radius ``r`` in
      ``conf.receptive_field_radii`` an ``'ex_infection_<r>'`` entry is
      added; this requires the bounding-box file to have been loaded.
      """
      bbs_loaded = self._infections_bbs is not None

      if bbs_loaded:
          # radius 0 is bound as the first positional argument
          infection_filler = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
          interpretations_filler = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
          has_bb_filler = self.get_batch_has_bb_mask_from_bb_file
      else:
          infection_filler = self.get_batch_bbs_map_from_mask_file
          interpretations_filler = self.get_batch_bbs_map_from_mask_file
          has_bb_filler = self.get_batch_has_bb_mask_from_mask_file

      fillers = {
          'x': self.get_batch_scaled_images,
          'y': self.get_batch_label,
          'infection': infection_filler,
          'interpretations': interpretations_filler,
          'has_bb': has_bb_filler,
      }

      # extended maps can only be generated from explicit bounding boxes
      radii = self.conf.receptive_field_radii
      if len(radii) > 0:
          assert self._infections_bbs is not None, \
              "When having receptive field radii, " \
              "you should provide bounding boxes as a file in config.infections_bb_dir"
          for r in radii:
              fillers[f'ex_infection_{r}'] = partial(
                  self.get_batch_extended_bbs_map_from_bbs_file, r)

      return fillers
  def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) -> np.ndarray:
      """Tell, per requested sample, whether infection information is available.

      A sample counts as available when its label is 0 or when its
      ``Interpretation_dir`` entry is not NaN.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          Boolean array with the same shape as ``samples_inds``.
      """
      assert 'Interpretation_dir' in self.samples_info.columns, \
          'If you have not specified infections_bb_dir in your config, ' \
          'you need to have Interpretation_dir column in your data separation'

      def has_inf(si):
          row = self.samples_info.iloc[si]
          if row['label'] == 0:
              return True
          # missing paths are read as NaN, whose string form is 'nan'
          return str(row['Interpretation_dir']) != 'nan'

      # An explicit otypes lets np.vectorize accept an empty batch; without
      # it the output type is probed on the first element and size-0 input
      # raises.
      return np.vectorize(has_inf, otypes=[bool])(samples_inds)
  def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) -> np.ndarray:
      """Tell, per requested sample, whether bounding-box information is available.

      A sample counts as available when its label is 0 or when
      ``self._infections_bbs`` holds at least one box for its id.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          Boolean array with the same shape as ``samples_inds``.
      """
      def has_inf(si):
          row = self.samples_info.iloc[si]
          if row['label'] == 0:
              return True
          sample_key = self._get_id_by_path(row['sample'])
          return sample_key in self._infections_bbs and \
              len(self._infections_bbs[sample_key][0]) > 0

      # An explicit otypes lets np.vectorize accept an empty batch; without
      # it the output type is probed on the first element and size-0 input
      # raises.
      return np.vectorize(has_inf, otypes=[bool])(samples_inds)
  def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray) -> np.ndarray:
      """Build infection maps for a batch from the per-sample mask files.

      For each requested sample:

      * label 0 gives an all-zero map;
      * a NaN ``Interpretation_dir`` gives a map filled with -1;
      * otherwise the mask is loaded from ``Interpretation_dir`` (directly
        when the path contains ``'npy'``, else from the ``'arr_0'`` entry
        of the archive), and, if its shape differs from the target, it is
        scaled by 255, resized with ``cv2.resize`` and thresholded at 128.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          Array of shape ``(len(samples_inds), 1, S, S)`` with
          ``S = self.infection_map_size``.
      """
      map_shape = (self.infection_map_size, self.infection_map_size)

      def read_interpretation(im_ind):
          row = self.samples_info.iloc[im_ind]

          # for healthy, return full 0
          # (builtin float replaces np.float, removed in NumPy 1.24)
          if row['label'] == 0:
              return np.full((1,) + map_shape, 0, dtype=float)

          int_dir = row['Interpretation_dir']

          # if it does not exist, return full -1
          if str(int_dir) == 'nan':
              return np.full((1,) + map_shape, -1, dtype=float)

          if 'npy' in int_dir:
              interpretation = np.load(int_dir)
          else:
              # close the archive once the array has been read
              with np.load(int_dir) as archive:
                  interpretation = archive['arr_0']

          interpretation = interpretation.astype(float)

          if interpretation.shape != map_shape:
              interpretation = (cv2.resize(np.round(255 * interpretation, 0),
                                           dsize=map_shape) >= 128).astype(float)

          return interpretation[np.newaxis, :, :]

      batch_interpretations = np.stack(tuple([read_interpretation(si)
                                              for si in samples_inds]), axis=0)
      return batch_interpretations
  185. @staticmethod
  186. def _get_id_by_path(path: str) -> str:
  187. return path[path.index('png_images/'):]
  def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray:
      """Build infection maps for a batch from the loaded bounding boxes.

      For each requested sample:

      * label 0 gives an all-zero map;
      * a sample with at least one box in ``self._infections_bbs`` gets the
        map returned by ``generate_bb_map`` for its start points, end
        points, the target map size and ``radius``;
      * any other sample gives a map filled with -1.

      Args:
          radius: forwarded unchanged to ``generate_bb_map``.
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          Array whose first two dimensions are ``(len(samples_inds), 1)``;
          the filled maps are ``S x S`` with
          ``S = self.infection_map_size``.
      """
      map_shape = (self.infection_map_size, self.infection_map_size)

      def make_map(im_ind):
          row = self.samples_info.iloc[im_ind]

          # builtin float replaces np.float, removed in NumPy 1.24
          if row['label'] == 0:
              return np.full((1,) + map_shape, 0, dtype=float)

          sample_key = self._get_id_by_path(row['sample'])

          if sample_key in self._infections_bbs and \
                  len(self._infections_bbs[sample_key][0]) > 0:
              start_points, end_points = self._infections_bbs[sample_key]
              mask = generate_bb_map(start_points, end_points, map_shape, radius)
              return mask[np.newaxis, :, :]

          # no boxes known for this sample
          return np.full((1,) + map_shape, -1, dtype=float)

      batch_interpretations = np.stack(tuple([make_map(si)
                                              for si in samples_inds]), axis=0)
      return batch_interpretations
  def get_batch_scaled_images(self,
                              samples_inds: np.ndarray) \
          -> np.ndarray:
      """Read the requested images and return them as one float32 batch.

      Each image is read with ``cv2.imread`` from the path in the
      ``sample`` column. For a 3-dimensional result only channel 0 is
      kept. Images that are not ``img_size`` x ``img_size`` are resized
      with ``cv2.resize``. The stacked batch is divided by 255.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          float32 array of shape
          ``(len(samples_inds), 1, img_size, img_size)``.

      Raises:
          Exception: if ``cv2.imread`` returns None for a sample; the path
              is printed to stderr first.
      """
      target_shape = (self.img_size, self.img_size)
      planes = []

      for sample_ind in samples_inds:
          img = cv2.imread(self.samples_info.iloc[sample_ind]['sample'])

          if img is None:
              print(self.samples_info.iloc[sample_ind]['sample'] + ' is missing!', file=stderr)
              raise Exception('Missing image')

          # multi-channel read: keep the first channel only
          plane = img[:, :, 0] if len(list(img.shape)) == 3 else img

          if plane.shape != target_shape:
              plane = cv2.resize(plane, dsize=target_shape)

          planes.append(plane[np.newaxis, :, :])

      stacked = np.stack(tuple(planes), axis=0)
      return stacked.astype(np.float32) / 255
  def get_batch_label(self, samples_inds: np.ndarray) -> np.ndarray:
      """Return the labels of the requested samples as a float32 array.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.
      """
      label_column = self.samples_info['label']
      batch_labels = label_column.iloc[samples_inds]
      return batch_labels.values.astype(np.float32)
  def get_batch_true_interpretations(self, samples_inds: np.ndarray) -> np.ndarray:
      """Return interpretation masks for a batch, sized ``img_size`` x ``img_size``.

      For each requested sample:

      * label 0 gives an all-zero mask;
      * a NaN ``Interpretation_dir`` gives a mask filled with NaN;
      * otherwise the mask is loaded from ``Interpretation_dir`` (directly
        when the path contains ``'npy'``, else from the ``'arr_0'`` entry
        of the archive, cast to int), and, if its shape differs from the
        target, it is scaled by 255, cast to uint8, resized with
        ``cv2.resize`` and thresholded at 128.

      Args:
          samples_inds: positional row indices into ``self.samples_info``.

      Returns:
          Array of shape ``(len(samples_inds), 1, img_size, img_size)``.
      """
      mask_shape = (self.img_size, self.img_size)

      def read_interpretation(im_ind):
          row = self.samples_info.iloc[im_ind]

          # for healthy, return full 0
          # (builtin float replaces np.float, removed in NumPy 1.24)
          if row['label'] == 0:
              return np.full((1,) + mask_shape, 0, dtype=float)

          int_dir = row['Interpretation_dir']

          # no mask available: mark every pixel as unknown
          if str(int_dir) == 'nan':
              return np.full((1,) + mask_shape, np.nan, dtype=float)

          if 'npy' in int_dir:
              interpretation = np.load(int_dir)
          else:
              # close the archive once the array has been read
              with np.load(int_dir) as archive:
                  interpretation = archive['arr_0'].astype(int)

          if interpretation.shape != mask_shape:
              interpretation = (cv2.resize(np.round(255 * interpretation, 0).astype(np.uint8),
                                           dsize=mask_shape) >= 128).astype(float)

          return interpretation[np.newaxis, :, :]

      batch_interpretations = np.stack(tuple([read_interpretation(si)
                                              for si in samples_inds]), axis=0)
      return batch_interpretations