from typing import Dict, List, TYPE_CHECKING, Optional, Tuple
from os import path, listdir
from sys import stderr
from functools import partial

import cv2
import numpy as np
import pandas as pd

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
    from ...configs.rsna_configs import RSNAConfigs


class RSNALoader(ContentLoader):
    """Content loader that serves images, labels and infection maps.

    Sample metadata is kept in ``self.samples_info`` (a pandas DataFrame with
    at least the columns ``sample`` and ``label``). Infection annotations come
    either from a bounding-box file (``conf.infections_bb_dir``) or from
    per-sample mask files referenced by the ``Interpretation_dir`` column.
    """

    def __init__(self, conf: 'RSNAConfigs', data_specification: str):
        """Load the sample table and, if configured, the bounding boxes.

        Args:
            conf: configuration object; this class reads ``input_size``,
                ``infection_map_size``, ``data_separation``,
                ``label_map_dict``, ``infections_bb_dir`` and
                ``receptive_field_radii`` from it.
            data_specification: what to load, see :meth:`load_samples`.
        """
        super().__init__(conf, data_specification)
        self.conf = conf
        self.img_size = conf.input_size
        self.infection_map_size = conf.infection_map_size

        self.samples_info = None
        self.load_samples(data_specification)
        self.loaded_images = None

        # Maps an image id (see _get_id_by_path) to a pair
        # (list of (row, col) box starts, list of (row, col) box ends).
        # Stays None when no bounding-box file is configured.
        self._infections_bbs: Optional[Dict[
            str,
            Tuple[List[Tuple[float, float]], List[Tuple[float, float]]]
        ]] = None
        self._load_infections()

    @staticmethod
    def _strip_parent_prefix(value):
        """Remove one leading ``'../'`` from a string; pass anything else through.

        Non-string cells (e.g. NaN for a missing ``Interpretation_dir``) and
        strings without the prefix are returned unchanged.
        """
        if isinstance(value, str) and value.startswith('../'):
            return value[len('../'):]
        return value

    def _read_separation_file(self, file_path: str) -> pd.DataFrame:
        """Read a tab-separated sample table and normalise its path columns."""
        samples_info = pd.read_csv(file_path, sep='\t', header=0)
        samples_info['sample'] = \
            samples_info['sample'].apply(self._strip_parent_prefix)
        # The mask column is only needed when no bounding-box file is used,
        # so its absence is not an error at load time.
        if 'Interpretation_dir' in samples_info.columns:
            samples_info['Interpretation_dir'] = \
                samples_info['Interpretation_dir'].apply(self._strip_parent_prefix)
        return samples_info

    def load_samples(self, data_specification):
        """Populate ``self.samples_info`` from ``data_specification``.

        ``data_specification`` may be:

        * ``'train'``, ``'val'`` or ``'test'``: reads
          ``<conf.data_separation>/<name>.txt`` (tab separated);
        * the path of a tab-separated file with the same layout;
        * the path of a directory: every entry is listed as a sample with
          label ``-1`` and no label mapping is applied.

        An optional ``':<filter>'`` suffix keeps only samples whose path
        contains ``'/<filter>/'``.

        Raises:
            ValueError: if ``data_specification`` matches none of the above.
        """
        if ':' in data_specification:
            data_specification, filter_str = data_specification.split(':')
        else:
            filter_str = None
        needle = None if filter_str is None else '/%s/' % filter_str

        if data_specification in ['train', 'val', 'test']:
            self.samples_info = self._read_separation_file(
                '%s/%s.txt' % (self.conf.data_separation, data_specification))
        elif path.isfile(data_specification):
            self.samples_info = self._read_separation_file(data_specification)
        elif path.isdir(data_specification):
            samples_names = [data_specification + '/' + x
                             for x in listdir(data_specification)]
            print('%d samples discovered in %s'
                  % (len(samples_names), data_specification))

            if needle is not None:
                samples_names = [s for s in samples_names if needle in s]
                print('%d samples remained after filtering' % len(samples_names))

            n_samples = len(samples_names)
            self.samples_info = pd.DataFrame({
                'label': np.full(n_samples, -1, dtype=int),
                'sample': np.asarray(samples_names, dtype=object),
                'view': np.full(n_samples, 'Unknown', dtype=object),
                'dataset': np.full(n_samples, 'unknown', dtype=object),
                'Interpretation_dir': np.full(n_samples, '', dtype=object),
            })
            # Directory listings carry no labels, so mapping is skipped.
            return
        else:
            print('Please implement this part!', flush=True)
            raise ValueError(
                'Unsupported data specification: %r' % (data_specification,))

        org_cnt = len(self.samples_info)

        # filtering by filter string!
        if needle is not None:
            keep = np.array(
                [needle in s for s in self.samples_info['sample'].values],
                dtype=bool)
            self.samples_info = self.samples_info[keep].copy()
            print('%d of %d samples remained after filtering for %s!'
                  % (len(self.samples_info), org_cnt, data_specification),
                  flush=True)
            org_cnt = len(self.samples_info)

        # label mapping/ordering! Labels missing from the map become -1.
        label_mapping_dict = self.conf.label_map_dict
        self.samples_info['label'] = [
            label_mapping_dict.get(lbl.lower(), -1)
            for lbl in self.samples_info['label'].values]

        # filtering unmapped labels
        self.samples_info = \
            self.samples_info[self.samples_info['label'].values != -1]
        print('%d out of %d samples remained after label mapping!'
              % (len(self.samples_info), org_cnt), flush=True)

        # counting the labels!
        u_labels, labels_cnt = np.unique(
            self.samples_info['label'].values, return_counts=True)
        print('[%s]' % ', '.join(
            ['class-%s: %d' % (str(u_label), cnt)
             for u_label, cnt in zip(u_labels, labels_cnt)]), flush=True)

    def _load_infections(self) -> None:
        """Read bounding boxes from ``conf.infections_bb_dir`` if it is set.

        The file is tab separated with columns ``img_dir`` and ``img_bb``.
        ``img_bb`` holds comma-separated boxes, each written as four
        ``-``-separated numbers ``r1-c1-r2-c2``; every number is clipped to
        [0, 1]. A cell without any ``-`` means "no boxes". When the config
        entry is None nothing is loaded and masks are expected to come from
        the ``Interpretation_dir`` column instead.
        """
        if self.conf.infections_bb_dir is None:
            return

        bbs_info = pd.read_csv(self.conf.infections_bb_dir, sep='\t', header=0)

        infections = {}
        for img_dir, raw_bbs in zip(bbs_info['img_dir'].values,
                                    bbs_info['img_bb'].values):
            starts, ends = [], []
            if '-' in str(raw_bbs):
                for raw_bb in raw_bbs.split(','):
                    # NOTE(review): '-' is the field separator, so negative
                    # or exponent-formatted numbers cannot be parsed here.
                    r1, c1, r2, c2 = (np.clip(float(x), 0, 1)
                                      for x in raw_bb.split('-'))
                    starts.append((r1, c1))
                    ends.append((r2, c2))
            infections[self._get_id_by_path(img_dir)] = (starts, ends)

        self._infections_bbs = infections

    def get_samples_names(self):
        """Return the ``sample`` column (image paths) as an array."""
        return self.samples_info['sample'].values

    def get_samples_labels(self):
        """Return the ``label`` column as an array."""
        return self.samples_info['label'].values

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove the samples for which ``drop_mask`` is True.

        Args:
            drop_mask: boolean array with one entry per current sample.
        """
        self.samples_info = self.samples_info[np.logical_not(drop_mask)]

    def get_placeholder_name_to_fill_function_dict(self):
        """Map each placeholder name to the method that fills its batch.

        Infection-related placeholders are served from the bounding-box file
        when one was loaded, otherwise from per-sample mask files. One extra
        ``ex_infection_<r>`` entry is added per radius in
        ``conf.receptive_field_radii``; those require the bounding-box file.
        """
        if self._infections_bbs is not None:
            infection_filler = partial(
                self.get_batch_extended_bbs_map_from_bbs_file, 0)
            has_bb_filler = self.get_batch_has_bb_mask_from_bb_file
        else:
            infection_filler = self.get_batch_bbs_map_from_mask_file
            has_bb_filler = self.get_batch_has_bb_mask_from_mask_file

        ret_val = {
            'x': self.get_batch_scaled_images,
            'y': self.get_batch_label,
            'infection': infection_filler,
            'interpretations': infection_filler,
            'has_bb': has_bb_filler,
        }

        # if infection radii is not empty, it is expected to have a dictionary
        if len(self.conf.receptive_field_radii) > 0:
            assert self._infections_bbs is not None, \
                "When having receptive field radii, " \
                "you should provide bounding boxes as a file in config.infections_bb_dir"
            ret_val.update({
                f'ex_infection_{r}': partial(
                    self.get_batch_extended_bbs_map_from_bbs_file, r)
                for r in self.conf.receptive_field_radii
            })

        return ret_val

    def get_batch_has_bb_mask_from_mask_file(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Tell, per sample, whether an infection annotation is available.

        Samples with label 0 always count as annotated; otherwise a sample is
        annotated when its ``Interpretation_dir`` is not NaN.

        Returns:
            Boolean array with the same shape as ``samples_inds``.
        """
        assert 'Interpretation_dir' in self.samples_info.columns, \
            'If you have not specified infections_bb_dir in your config, ' \
            'you need to have Interpretation_dir column in your data separation'

        def has_inf(si):
            row = self.samples_info.iloc[si]
            if row['label'] == 0:
                return True
            return str(row['Interpretation_dir']) != 'nan'

        # otypes lets vectorize handle an empty batch as well.
        return np.vectorize(has_inf, otypes=[bool])(samples_inds)

    def get_batch_has_bb_mask_from_bb_file(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Tell, per sample, whether bounding boxes are available.

        Samples with label 0 always count as annotated; otherwise a sample is
        annotated when the bounding-box file lists at least one box for it.

        Returns:
            Boolean array with the same shape as ``samples_inds``.
        """

        def has_inf(si):
            row = self.samples_info.iloc[si]
            if row['label'] == 0:
                return True
            sample_key = self._get_id_by_path(row['sample'])
            return sample_key in self._infections_bbs and \
                len(self._infections_bbs[sample_key][0]) > 0

        # otypes lets vectorize handle an empty batch as well.
        return np.vectorize(has_inf, otypes=[bool])(samples_inds)

    def get_batch_bbs_map_from_mask_file(self, samples_inds: np.ndarray)\
            -> np.ndarray:
        """Build infection maps for a batch from per-sample mask files.

        Per sample the map is all 0 for label 0, all -1 when no mask file is
        referenced, and otherwise the loaded mask, resized to
        ``infection_map_size`` and re-binarised if its shape differs.

        Returns:
            Array of shape (batch, 1, infection_map_size, infection_map_size).
        """
        map_size = self.infection_map_size

        def read_interpretation(im_ind):
            row = self.samples_info.iloc[im_ind]

            # for healthy, return full 0
            if row['label'] == 0:
                return np.full((1, map_size, map_size), 0, dtype=float)

            int_dir = row['Interpretation_dir']

            # if it does not exist, return full -1
            if str(int_dir) == 'nan':
                return np.full((1, map_size, map_size), -1, dtype=float)

            if 'npy' in int_dir:
                interpretation = np.load(int_dir)
            else:
                # otherwise np.load is expected to return an archive keyed 'arr_0'
                interpretation = np.load(int_dir)['arr_0']
            interpretation = interpretation.astype(float)

            if interpretation.shape != (map_size, map_size):
                # scale to 0..255, resize, then threshold at the midpoint
                resized = cv2.resize(np.round(255 * interpretation, 0),
                                     dsize=(map_size, map_size))
                interpretation = (resized >= 128).astype(float)

            return interpretation[np.newaxis, :, :]

        batch_interpretations = np.stack(
            tuple([read_interpretation(si) for si in samples_inds]), axis=0)
        return batch_interpretations

    @staticmethod
    def _get_id_by_path(path: str) -> str:
        """Return the part of ``path`` starting at ``'png_images/'``.

        Raises:
            ValueError: if ``path`` does not contain ``'png_images/'``.
        """
        # The parameter shadows the imported os.path module inside this
        # method only; the name is kept for keyword-argument compatibility.
        return path[path.index('png_images/'):]

    def get_batch_extended_bbs_map_from_bbs_file(self, radius, samples_inds: np.ndarray) -> np.ndarray:
        """Build infection maps for a batch from the bounding-box file.

        Per sample the map is all 0 for label 0, the output of
        ``generate_bb_map`` (called with ``radius``) when boxes are known,
        and all -1 otherwise.

        Args:
            radius: forwarded to ``generate_bb_map``.
            samples_inds: positional indices into ``self.samples_info``.

        Returns:
            Array of shape (batch, 1, infection_map_size, infection_map_size).
        """
        map_size = self.infection_map_size

        def make_map(im_ind):
            row = self.samples_info.iloc[im_ind]

            if row['label'] == 0:
                return np.full((1, map_size, map_size), 0, dtype=float)

            sample_key = self._get_id_by_path(row['sample'])
            start_points, end_points = \
                self._infections_bbs.get(sample_key, ([], []))

            if len(start_points) == 0:
                # unknown sample or no boxes listed for it
                return np.full((1, map_size, map_size), -1, dtype=float)

            mask = generate_bb_map(start_points, end_points,
                                   (map_size, map_size), radius)
            return mask[np.newaxis, :, :]

        batch_interpretations = np.stack(
            tuple([make_map(si) for si in samples_inds]), axis=0)
        return batch_interpretations

    def get_batch_scaled_images(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Read the batch's images, resize them and divide by 255.

        Only the first channel of a multi-channel image is used.

        Returns:
            float32 array of shape (batch, 1, img_size, img_size).

        Raises:
            FileNotFoundError: if an image cannot be read.
        """

        def read_img(im_ind):
            img_path = self.samples_info.iloc[im_ind]['sample']
            im = cv2.imread(img_path)

            # cv2.imread signals failure by returning None
            if im is None:
                print(img_path + ' is missing!', file=stderr)
                raise FileNotFoundError('Missing image: %s' % img_path)

            ret_val = im[:, :, 0] if im.ndim == 3 else im

            if ret_val.shape != (self.img_size, self.img_size):
                ret_val = cv2.resize(
                    ret_val, dsize=(self.img_size, self.img_size))

            return ret_val[np.newaxis, :, :]

        batch_imgs = np.stack(
            tuple([read_img(si) for si in samples_inds]), axis=0)
        batch_imgs = batch_imgs.astype(np.float32) / 255
        return batch_imgs

    def get_batch_label(self, samples_inds: np.ndarray, ) \
            -> np.ndarray:
        """Return the labels of the given samples as a float32 array."""
        return self.samples_info['label'].iloc[samples_inds].values.astype(np.float32)

    def get_batch_true_interpretations(self, samples_inds: np.ndarray) \
            -> np.ndarray:
        """Build image-sized infection masks for a batch from mask files.

        Per sample the mask is all 0 for label 0, all NaN when no mask file
        is referenced, and otherwise the loaded mask, resized to ``img_size``
        and re-binarised if its shape differs.

        Returns:
            Array of shape (batch, 1, img_size, img_size).
        """
        img_size = self.img_size

        def read_interpretation(im_ind):
            row = self.samples_info.iloc[im_ind]

            # for healthy, return full 0
            if row['label'] == 0:
                return np.full((1, img_size, img_size), 0, dtype=float)

            int_dir = row['Interpretation_dir']

            if str(int_dir) == 'nan':
                return np.full((1, img_size, img_size), np.nan, dtype=float)

            if 'npy' in int_dir:
                interpretation = np.load(int_dir)
            else:
                # otherwise np.load is expected to return an archive keyed 'arr_0'
                interpretation = np.load(int_dir)['arr_0'].astype(int)

            if interpretation.shape != (img_size, img_size):
                # scale to 0..255, resize, then threshold at the midpoint
                scaled = np.round(255 * interpretation, 0).astype(np.uint8)
                resized = cv2.resize(scaled, dsize=(img_size, img_size))
                interpretation = (resized >= 128).astype(float)

            return interpretation[np.newaxis, :, :]

        batch_interpretations = np.stack(
            tuple([read_interpretation(si) for si in samples_inds]), axis=0)
        return batch_interpretations