from typing import Dict, List, Optional

from torch import nn
from torchvision import transforms

from .base_config import BaseConfig, PhaseType
from ..utils.random_brightness_augmentation import ClippedBrightnessAugment
from ..data.content_loaders.rsna_loader import RSNALoader
from ..model_evaluation.binary_evaluator import BinaryEvaluator


def _geometric_augmentation(input_size, *extra_transforms: nn.Module) -> nn.Sequential:
    """Build the random rotation / shear / resized-crop pipeline shared by RSNA inputs.

    A new ``nn.Sequential`` (with new transform instances) is created on every
    call, so each caller gets its own unshared modules.

    Args:
        input_size: Output size passed to ``transforms.RandomResizedCrop``.
        *extra_transforms: Modules appended after the geometric transforms.

    Returns:
        ``nn.Sequential`` of the geometric transforms followed by ``extra_transforms``.
    """
    return nn.Sequential(
        transforms.RandomRotation(45),
        transforms.RandomAffine(0, shear=0.4),
        # NOTE(review): torchvision's RandomResizedCrop rejects sampled crops that
        # do not fit inside the image, so the part of this scale range above 1.0
        # appears to be effectively unused — confirm the intended range.
        transforms.RandomResizedCrop(input_size, scale=(0.6, 1.4)),
        *extra_transforms,
    )


class RSNAConfigs(BaseConfig):
    """Configuration for healthy-vs-pneumonia classification on the RSNA dataset.

    Wires the RSNA content loader and the binary evaluator into ``BaseConfig``
    and overrides several training settings.
    """

    def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        """Create the RSNA configuration.

        Args:
            try_name: Experiment name, forwarded to ``BaseConfig``.
            try_num: Experiment run number, forwarded to ``BaseConfig``.
            input_size: Input resolution; also selects the data-separation
                directory ``dataset_metadata/RSNA/DataSeparation_R{input_size}``.
            phase_type: Phase to configure, forwarded to ``BaseConfig``.
        """
        super().__init__(
            try_name,
            try_num,
            f'dataset_metadata/RSNA/DataSeparation_R{input_size}',
            input_size,
            phase_type,
            RSNALoader,
            BinaryEvaluator,
        )

        # replaced configs! (values below override those set by BaseConfig.__init__)
        self.batch_size = 64
        self.max_epochs = 200

        # Class name -> integer label for the two RSNA classes.
        self.label_map_dict: Dict[str, int] = {
            'healthy': 0,
            'pneumonia': 1,
        }

        # Per-key augmentation pipelines. All three share the same geometric
        # transforms; only 'x' additionally gets the clipped brightness change.
        # Each key receives its own module instances.
        self.augmentations_dict = {
            'x': _geometric_augmentation(
                self.input_size, ClippedBrightnessAugment(0.5, 1.5, 0, 1)),
            'infection': _geometric_augmentation(self.input_size),
            'interpretations': _geometric_augmentation(self.input_size),
        }

        # Validation metric used to pick the best epoch (presumably balanced
        # accuracy as reported by BinaryEvaluator — confirm).
        self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
        # Comparison operator used to decide whether the reference metric improved.
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # Infection maps use the same size as the model input.
        self.infection_map_size = self.input_size
        self.receptive_field_radii: List[int] = []
        # Despite the "_dir" name this points to a single .tsv file; the path
        # string (including its spelling) is kept verbatim.
        self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'