1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- from .base_config import BaseConfig, PhaseType
- from typing import Dict, List, Optional
- from torch import nn
- from torchvision import transforms
- from ..utils.random_brightness_augmentation import ClippedBrightnessAugment
- from ..data.content_loaders.rsna_loader import RSNALoader
- from ..model_evaluation.binary_evaluator import BinaryEvaluator
-
class RSNAConfigs(BaseConfig):
    """Configuration for experiments on the RSNA dataset.

    Wires ``RSNALoader`` and ``BinaryEvaluator`` into ``BaseConfig`` and
    overrides a handful of its defaults (batch size, epoch budget, label map,
    augmentations, best-epoch selection, infection-map settings).
    """

    def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        """Build the RSNA configuration.

        Args:
            try_name: Experiment name, forwarded to ``BaseConfig``.
            try_num: Experiment index, forwarded to ``BaseConfig``.
            input_size: Forwarded to ``BaseConfig``; it is also embedded in the
                data-separation metadata path and used as the crop output size
                and as ``infection_map_size``.
            phase_type: Forwarded to ``BaseConfig``.
        """
        # The data-separation metadata lives in a folder keyed by the input resolution.
        metadata_path = f'dataset_metadata/RSNA/DataSeparation_R{input_size}'
        super().__init__(
            try_name, try_num, metadata_path, input_size, phase_type,
            RSNALoader, BinaryEvaluator,
        )

        # ---- Overrides of BaseConfig defaults ----
        self.batch_size = 64
        self.max_epochs = 200

        # Binary task: 'healthy' -> 0, 'pneumonia' -> 1.
        self.label_map_dict: Dict[str, int] = {'healthy': 0, 'pneumonia': 1}

        def geometric_augmentations():
            # Returns fresh transform instances on every call, so no module
            # object is shared between the pipelines below.
            # NOTE(review): the crop's upper scale bound (1.4) exceeds the full
            # image area; confirm this is intended given torchvision's
            # RandomResizedCrop sampling semantics.
            return [
                transforms.RandomRotation(45),
                transforms.RandomAffine(0, shear=0.4),
                transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)),
            ]

        # Only the 'x' pipeline additionally applies ClippedBrightnessAugment
        # (presumably a [0.5, 1.5] brightness range clipped to [0, 1] — confirm
        # against its definition); the other entries get geometry only.
        # NOTE(review): each pipeline draws its own random parameters; if
        # 'x', 'infection' and 'interpretations' must stay spatially aligned,
        # verify that the caller synchronizes the RNG.
        self.augmentations_dict = {
            'x': nn.Sequential(
                *geometric_augmentations(),
                ClippedBrightnessAugment(0.5, 1.5, 0, 1),
            ),
            'infection': nn.Sequential(*geometric_augmentations()),
            'interpretations': nn.Sequential(*geometric_augmentations()),
        }

        # Best-epoch selection: an epoch counts as an improvement when its
        # validation 'BAcc' (presumably balanced accuracy) is >= the best so far.
        self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # Infection maps share the model's input resolution.
        self.infection_map_size = self.input_size
        # No receptive-field radii are configured for this dataset.
        self.receptive_field_radii: List[int] = []

        # Despite the `_dir` suffix, this holds the path of a single .tsv file.
        self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'
|