You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rsna_configs.py 1.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from .base_config import BaseConfig, PhaseType
  2. from typing import Dict, List, Optional
  3. from torch import nn
  4. from torchvision import transforms
  5. from ..utils.random_brightness_augmentation import ClippedBrightnessAugment
  6. from ..data.content_loaders.rsna_loader import RSNALoader
  7. from ..model_evaluation.binary_evaluator import BinaryEvaluator
  8. class RSNAConfigs(BaseConfig):
  9. def __init__(self, try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
  10. super().__init__(try_name, try_num, f'dataset_metadata/RSNA/DataSeparation_R{input_size}', input_size, phase_type, RSNALoader, BinaryEvaluator)
  11. # replaced configs!
  12. self.batch_size = 64
  13. self.max_epochs = 200
  14. self.label_map_dict: Dict[str, int] = {
  15. 'healthy': 0, 'pneumonia': 1
  16. }
  17. self.augmentations_dict = {
  18. 'x': nn.Sequential(
  19. transforms.RandomRotation(45),
  20. transforms.RandomAffine(0, shear=0.4),
  21. transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4)),
  22. ClippedBrightnessAugment(0.5, 1.5, 0, 1)),
  23. 'infection': nn.Sequential(
  24. transforms.RandomRotation(45),
  25. transforms.RandomAffine(0, shear=0.4),
  26. transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),
  27. 'interpretations': nn.Sequential(
  28. transforms.RandomRotation(45),
  29. transforms.RandomAffine(0, shear=0.4),
  30. transforms.RandomResizedCrop(self.input_size, scale=(0.6, 1.4))),
  31. }
  32. self.title_of_reference_metric_to_choose_best_epoch = 'BAcc'
  33. self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
  34. self.keep_best_and_last_epochs_only = True
  35. self.infection_map_size = self.input_size
  36. self.receptive_field_radii: List[int] = []
  37. self.infections_bb_dir: Optional[str] = 'data/RSNA/infection_bounding_boxs.tsv'