123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329 |
from functools import partial
from os import path, listdir
from sys import stderr
from typing import Dict, List, TYPE_CHECKING, Optional, Tuple

import cv2
import numpy as np
import pandas as pd

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
    from ...configs.rsna_configs import RSNAConfigs


class RSNALoader(ContentLoader):
    """Content loader serving RSNA images, labels and infection maps.

    Sample metadata lives in ``self.samples_info``, a pandas DataFrame with at
    least the ``sample`` (image path) and ``label`` columns.  Infection
    bounding boxes are read from ``conf.infections_bb_dir`` when it is set;
    otherwise the per-sample ``Interpretation_dir`` column is used to locate
    mask files.
    """

    # Everything from this marker onwards in a path is used as the sample id.
    _ID_MARKER = 'png_images/'
    # Prefix removed from the paths stored in data-separation files.
    _PARENT_PREFIX = '../'

    def __init__(self, conf: 'RSNAConfigs', data_specification: str):
        """Load the samples described by ``data_specification`` and, if
        configured, their infection bounding boxes.

        Args:
            conf: configuration object; this class reads ``input_size``,
                ``infection_map_size``, ``data_separation``,
                ``label_map_dict``, ``infections_bb_dir`` and
                ``receptive_field_radii`` from it.
            data_specification: see :meth:`load_samples`.
        """
        super().__init__(conf, data_specification)
        self.conf = conf
        self.img_size = conf.input_size
        self.infection_map_size = conf.infection_map_size
        self.samples_info = None

        self.load_samples(data_specification)
        self.loaded_images = None

        # sample id -> (list of box start points, list of box end points)
        self._infections_bbs: Optional[
            Dict[str, Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]]] = None
        self._load_infections()

    # ------------------------------------------------------------------
    # sample discovery
    # ------------------------------------------------------------------
    @classmethod
    def _strip_parent_prefix(cls, value):
        """Return ``value`` without a leading ``'../'``.

        Non-string values (e.g. NaN for an empty cell) and strings that do not
        start with the prefix are returned unchanged.
        """
        if isinstance(value, str) and value.startswith(cls._PARENT_PREFIX):
            return value[len(cls._PARENT_PREFIX):]
        return value

    @classmethod
    def _read_separation_file(cls, file_path: str) -> pd.DataFrame:
        """Read a tab-separated data-separation file (with header) and strip
        the ``'../'`` prefix from its path columns."""
        samples_info = pd.read_csv(file_path, sep='\t', header=0)
        samples_info['sample'] = samples_info['sample'].apply(cls._strip_parent_prefix)
        # The mask column is only needed when no bounding-box file is configured.
        if 'Interpretation_dir' in samples_info.columns:
            samples_info['Interpretation_dir'] = \
                samples_info['Interpretation_dir'].apply(cls._strip_parent_prefix)
        return samples_info

    @staticmethod
    def _discover_samples(directory: str, filter_str: Optional[str]) -> pd.DataFrame:
        """Build a samples table from the entries of ``directory``.

        Every entry gets label ``-1``, view ``'Unknown'``, dataset
        ``'unknown'`` and an empty ``Interpretation_dir``.
        """
        samples_names = [directory + '/' + entry for entry in listdir(directory)]
        print('%d samples discovered in %s' % (len(samples_names), directory))

        if filter_str is not None:
            needle = '/%s/' % filter_str
            samples_names = [name for name in samples_names if needle in name]
            print('%d samples remained after filtering' % len(samples_names))

        count = len(samples_names)
        return pd.DataFrame({
            'label': np.full(count, -1, dtype=int),
            'sample': np.asarray(samples_names, dtype=object),
            'view': np.full(count, 'Unknown', dtype=object),
            'dataset': np.full(count, 'unknown', dtype=object),
            'Interpretation_dir': np.full(count, '', dtype=object),
        })

    def load_samples(self, data_specification):
        """Populate ``self.samples_info`` from ``data_specification``.

        ``data_specification`` is one of:

        * ``'train'``, ``'val'`` or ``'test'``: reads
          ``<conf.data_separation>/<name>.txt`` (tab-separated, with header);
        * the path of a tab-separated file with the same layout;
        * the path of a directory: each entry becomes an unlabeled sample
          (label ``-1``) and no label mapping is applied.

        An optional ``':<filter>'`` suffix keeps only the samples whose path
        contains ``'/<filter>/'``.

        For the first two forms, labels are mapped case-insensitively through
        ``conf.label_map_dict``; samples whose label has no mapping are
        dropped and the per-class counts are printed.

        Raises:
            ValueError: if ``data_specification`` matches none of the forms.
        """
        # Only the first ':' separates the specification from the filter.
        data_specification, has_filter, filter_part = data_specification.partition(':')
        filter_str = filter_part if has_filter else None

        if data_specification in ['train', 'val', 'test']:
            self.samples_info = self._read_separation_file(
                '%s/%s.txt' % (self.conf.data_separation, data_specification))
        elif path.isfile(data_specification):
            self.samples_info = self._read_separation_file(data_specification)
        elif path.isdir(data_specification):
            self.samples_info = self._discover_samples(data_specification, filter_str)
            return
        else:
            raise ValueError(
                "Unsupported data specification %r: expected 'train', 'val', 'test', "
                "a tab-separated file or a directory" % data_specification)

        org_cnt = len(self.samples_info)

        # filtering by filter string!
        if filter_str is not None:
            needle = '/%s/' % filter_str
            keep = np.array(
                [needle in sample for sample in self.samples_info['sample'].values],
                dtype=bool)
            self.samples_info = self.samples_info[keep]

        # Reported whether or not a filter was given.
        print(
            '%d of %d samples remained after filtering for %s!' % (len(self.samples_info), org_cnt, data_specification),
            flush=True)
        org_cnt = len(self.samples_info)

        # label mapping/ordering!
        label_mapping_dict = self.conf.label_map_dict
        mapped_labels = [label_mapping_dict.get(str(raw_label).lower(), -1)
                         for raw_label in self.samples_info['label'].values]
        # assign() returns a new frame, so the filtered slice above is never mutated.
        self.samples_info = self.samples_info.assign(label=mapped_labels)

        # filtering unmapped labels
        self.samples_info = self.samples_info[self.samples_info['label'].values != -1]
        print('%d out of %d samples remained after label mapping!' %
              (len(self.samples_info), org_cnt), flush=True)

        # counting the labels!
        u_labels, labels_cnt = np.unique(self.samples_info['label'].values, return_counts=True)
        print('[%s]' % ', '.join(['class-%s: %d' % (str(u_label), label_cnt)
                                  for u_label, label_cnt in zip(u_labels, labels_cnt)]),
              flush=True)

    # ------------------------------------------------------------------
    # bounding boxes
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_bbs(raw_bbs) -> List[tuple]:
        """Parse one ``img_bb`` cell into a list of boxes.

        A cell is a comma-separated list of boxes, each box being
        dash-separated numbers which are clipped to ``[0, 1]``.  A cell
        without any ``'-'`` (e.g. an empty/NaN cell) yields no boxes.
        """
        text = str(raw_bbs)
        if '-' not in text:
            return []
        return [tuple(np.clip(float(value), 0, 1) for value in box.split('-'))
                for box in text.split(',')]

    def _load_infections(self) -> None:
        """Read the bounding boxes listed in ``conf.infections_bb_dir``.

        The file is tab-separated with ``img_dir`` and ``img_bb`` columns.
        ``self._infections_bbs`` maps each sample id (see
        :meth:`_get_id_by_path`) to a ``(starts, ends)`` pair, where
        ``starts`` holds the first two values of every box and ``ends`` the
        last two.  When ``conf.infections_bb_dir`` is ``None`` nothing is
        loaded and the data separation is expected to provide the
        ``Interpretation_dir`` column instead.
        """
        if self.conf.infections_bb_dir is None:
            return

        bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0)

        infections_bbs = {}
        for img_dir, raw_bbs in zip(bbs_info['img_dir'].values, bbs_info['img_bb'].values):
            boxes = self._parse_bbs(raw_bbs)
            # reformatting to tuple of list of starts and list of ends
            starts = [(r1, c1) for r1, c1, _, _ in boxes]
            ends = [(r2, c2) for _, _, r2, c2 in boxes]
            infections_bbs[self._get_id_by_path(img_dir)] = (starts, ends)

        self._infections_bbs = infections_bbs

    @staticmethod
    def _get_id_by_path(path: str) -> str:
        """Return the part of ``path`` starting at ``'png_images/'``.

        Raises:
            ValueError: if ``path`` does not contain ``'png_images/'``.
        """
        marker_pos = path.find(RSNALoader._ID_MARKER)
        if marker_pos < 0:
            raise ValueError("%r does not contain %r" % (path, RSNALoader._ID_MARKER))
        return path[marker_pos:]

    # ------------------------------------------------------------------
    # simple accessors
    # ------------------------------------------------------------------
    def get_samples_names(self):
        """Return the ``sample`` column as an array."""
        return self.samples_info['sample'].values

    def get_samples_labels(self):
        """Return the ``label`` column as an array."""
        return self.samples_info['label'].values

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove the samples whose entry in ``drop_mask`` is True.

        ``drop_mask`` is a boolean array with one entry per sample.
        """
        self.samples_info = self.samples_info[np.logical_not(drop_mask)]

    def get_placeholder_name_to_fill_function_dict(self):
        """Map placeholder names to the batch-filling methods of this loader.

        ``'infection'``, ``'interpretations'`` and ``'has_bb'`` are served
        from the bounding-box file when one was loaded, otherwise from the
        mask files.  For every radius ``r`` in ``conf.receptive_field_radii``
        an ``'ex_infection_<r>'`` entry is added; this requires the
        bounding-box file.
        """
        has_bbs_file = self._infections_bbs is not None

        if has_bbs_file:
            infection_filler = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
            interpretation_filler = partial(self.get_batch_extended_bbs_map_from_bbs_file, 0)
            has_bb_filler = self.get_batch_has_bb_mask_from_bb_file
        else:
            infection_filler = self.get_batch_bbs_map_from_mask_file
            interpretation_filler = self.get_batch_bbs_map_from_mask_file
            has_bb_filler = self.get_batch_has_bb_mask_from_mask_file

        ret_val = {
            'x': self.get_batch_scaled_images,
            'y': self.get_batch_label,
            'infection': infection_filler,
            'interpretations': interpretation_filler,
            'has_bb': has_bb_filler,
        }

        # if infection radii is not empty, it is expected to have a dictionary
        if len(self.conf.receptive_field_radii) > 0:
            assert self._infections_bbs is not None, \
                "When having receptive field radii, " \
                "you should provide bounding boxes as a file in config.infections_bb_dir"

            ret_val.update({
                f'ex_infection_{r}': partial(self.get_batch_extended_bbs_map_from_bbs_file, r)
                for r in self.conf.receptive_field_radii
            })

        return ret_val

    # ------------------------------------------------------------------
    # shared helpers for the batch fillers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_missing_interpretation(int_dir) -> bool:
        """True when ``int_dir`` is NaN/None or an empty string."""
        return int_dir is None or str(int_dir) in ('nan', '')

    @staticmethod
    def _load_mask_file(int_dir: str) -> np.ndarray:
        """Load a mask from a ``.npy`` file, or ``arr_0`` of an archive."""
        if 'npy' in int_dir:
            return np.load(int_dir)
        # Close the archive as soon as the array has been read.
        with np.load(int_dir) as archive:
            return archive['arr_0']

    @staticmethod
    def _stack_batch(items: List[np.ndarray], item_shape: Tuple[int, ...], dtype) -> np.ndarray:
        """Stack ``items`` along a new first axis.

        An empty batch yields an empty ``(0, *item_shape)`` array instead of
        the error ``np.stack`` raises for an empty sequence.
        """
        if len(items) == 0:
            return np.empty((0,) + tuple(item_shape), dtype=dtype)
        return np.stack(items, axis=0)

    # ------------------------------------------------------------------
    # batch fillers
    # ------------------------------------------------------------------
    def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Return, per sample, whether an infection map is available.

        Samples with label 0 always count as available; others need a
        non-missing ``Interpretation_dir``.
        """
        assert 'Interpretation_dir' in self.samples_info.columns, \
            'If you have not specified infections_bb_dirin your config, ' \
            'you need to have Interpretation_dir column in your data separation'

        def has_inf(si):
            sample = self.samples_info.iloc[si]
            if sample['label'] == 0:
                return True
            return not self._is_missing_interpretation(sample['Interpretation_dir'])

        # otypes keeps the call valid for an empty index array.
        return np.vectorize(has_inf, otypes=[bool])(samples_inds)

    def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Return, per sample, whether bounding boxes are available.

        Samples with label 0 always count as available; others need at least
        one box in the loaded bounding-box file.
        """
        def has_inf(si):
            sample = self.samples_info.iloc[si]
            if sample['label'] == 0:
                return True
            sample_key = self._get_id_by_path(sample['sample'])
            starts, _ = self._infections_bbs.get(sample_key, ((), ()))
            return len(starts) > 0

        # otypes keeps the call valid for an empty index array.
        return np.vectorize(has_inf, otypes=[bool])(samples_inds)

    def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Return infection maps of shape ``(batch, 1, S, S)`` read from the
        mask files, with ``S = infection_map_size``.

        Label-0 samples get an all-zero map, samples without a mask file an
        all ``-1`` map.  Masks of another shape are scaled by 255, resized
        and thresholded at 128.
        """
        map_size = (self.infection_map_size, self.infection_map_size)
        map_shape = (1,) + map_size

        def read_interpretation(im_ind):
            sample = self.samples_info.iloc[im_ind]

            # for healthy, return full 0
            if sample['label'] == 0:
                return np.full(map_shape, 0, dtype=float)

            int_dir = sample['Interpretation_dir']
            # if it does not exist, return full -1
            if self._is_missing_interpretation(int_dir):
                return np.full(map_shape, -1, dtype=float)

            interpretation = self._load_mask_file(int_dir).astype(float)

            if interpretation.shape != map_size:
                resized = cv2.resize(np.round(255 * interpretation, 0), dsize=map_size)
                interpretation = (resized >= 128).astype(float)
            return interpretation[np.newaxis, :, :]

        maps = [read_interpretation(si) for si in samples_inds]
        return self._stack_batch(maps, map_shape, float)

    def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray:
        """Return infection maps of shape ``(batch, 1, S, S)`` generated from
        the bounding-box file, with ``S = infection_map_size``.

        Label-0 samples get an all-zero map; samples without any box get an
        all ``-1`` map; otherwise the map comes from ``generate_bb_map``
        called with the boxes, the map size and ``radius``.
        """
        map_size = (self.infection_map_size, self.infection_map_size)
        map_shape = (1,) + map_size

        def make_map(im_ind):
            sample = self.samples_info.iloc[im_ind]

            if sample['label'] == 0:
                return np.full(map_shape, 0, dtype=float)

            sample_key = self._get_id_by_path(sample['sample'])
            start_points, end_points = self._infections_bbs.get(sample_key, ((), ()))

            if len(start_points) == 0:
                return np.full(map_shape, -1, dtype=float)

            mask = generate_bb_map(start_points, end_points, map_size, radius)
            return mask[np.newaxis, :, :]

        maps = [make_map(si) for si in samples_inds]
        return self._stack_batch(maps, map_shape, float)

    def get_batch_scaled_images(self,
                                samples_inds: np.ndarray) \
            -> np.ndarray:
        """Return the images of the batch as float32 in ``[0, 1]`` with shape
        ``(batch, 1, img_size, img_size)``.

        Raises:
            OSError: if an image cannot be read.
        """
        target_size = (self.img_size, self.img_size)

        def read_img(im_ind):
            img_path = self.samples_info.iloc[im_ind]['sample']
            im = cv2.imread(img_path)
            if im is None:
                print(img_path + ' is missing!', file=stderr)
                raise OSError('Missing image: %s' % img_path)

            # Multi-channel images are reduced to their first channel.
            ret_val = im[:, :, 0] if im.ndim == 3 else im

            if ret_val.shape != target_size:
                ret_val = cv2.resize(ret_val, dsize=target_size)
            return ret_val[np.newaxis, :, :]

        images = [read_img(si) for si in samples_inds]
        batch_imgs = self._stack_batch(images, (1,) + target_size, np.uint8)
        return batch_imgs.astype(np.float32) / 255

    def get_batch_label(self,
                        samples_inds: np.ndarray,
                        ) \
            -> np.ndarray:
        """Return the labels of the batch as a float32 array."""
        return self.samples_info['label'].iloc[samples_inds].values.astype(np.float32)

    def get_batch_true_interpretations(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Return image-sized masks of shape ``(batch, 1, img_size, img_size)``.

        Label-0 samples get an all-zero mask, samples without a mask file an
        all-NaN mask.  Masks of another shape are scaled by 255, resized and
        thresholded at 128.
        """
        mask_size = (self.img_size, self.img_size)
        mask_shape = (1,) + mask_size

        def read_interpretation(im_ind):
            sample = self.samples_info.iloc[im_ind]

            # for healthy, return full 0
            if sample['label'] == 0:
                return np.full(mask_shape, 0, dtype=float)

            int_dir = sample['Interpretation_dir']
            if self._is_missing_interpretation(int_dir):
                return np.full(mask_shape, np.nan, dtype=float)

            interpretation = self._load_mask_file(int_dir)
            if 'npy' not in int_dir:
                # NOTE(review): archive masks are truncated to integers here,
                # unlike in get_batch_bbs_map_from_mask_file - confirm intent.
                interpretation = interpretation.astype(int)

            if interpretation.shape != mask_size:
                scaled = np.round(255 * interpretation, 0).astype(np.uint8)
                interpretation = cv2.resize(scaled, dsize=mask_size) >= 128
            # Always hand back floats, whichever branch produced the mask.
            return interpretation.astype(float)[np.newaxis, :, :]

        masks = [read_interpretation(si) for si in samples_inds]
        return self._stack_batch(masks, mask_shape, float)
|