123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional
- import os
- import random
- import numpy as np
- from sys import stderr
- from enum import Enum
-
- import torch
-
- from ..interpreting.interpreter_maker import InterpretationType
- from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser
- from ..data.batch_choosing.sequential import SequentialBatchChooser
- from ..utils.hooker import HookPair
- if TYPE_CHECKING:
- from ..model_evaluation.evaluator import Evaluator
- from ..data.batch_choosing.batch_chooser import BatchChooser
- from ..data.content_loaders.content_loader import ContentLoader
-
-
class PhaseType(Enum):
    """Phase a run executes.

    Each member's value is its lowercase name; ``BaseConfig`` uses that value
    as the default report sub-directory and, for every phase other than
    ``TRAIN``, switches cuDNN to deterministic mode.
    """

    TRAIN = "train"
    EVAL = "eval"
    INTERPRET = "interpret"
    EVALINTERPRET = "evalinterpret"
-
-
class BaseConfig:
    """Run configuration shared by the training, evaluation and interpretation phases.

    ``__init__`` fills every attribute with a default value. :meth:`update` then
    overrides each attribute whose name matches a non-``None`` entry of the
    parsed arguments and finishes the setup (device, random seeds, output
    directories). ``repr(config)`` shows the attribute dict in assignment order.
    """

    def __init__(self,
                 try_name: str, try_num: int, data_separation: str,
                 input_size: int,
                 phase_type: PhaseType,
                 content_loader_cls: Type['ContentLoader'],
                 evaluator_cls: Type['Evaluator']) -> None:
        """Initialise the configuration with its default values.

        Args:
            try_name: name of the run; part of the default ``save_dir``.
            try_num: number of the run; part of the default ``save_dir``.
            data_separation: stored as-is on the config.
            input_size: model input size.
            phase_type: the phase this run executes; its value names the
                default report sub-directory.
            content_loader_cls: content loader class to use.
            evaluator_cls: evaluator class to use.
        """

        # -------------------- ESSENTIAL CONFIG --------------------

        self.data_separation = data_separation
        self.try_name = try_name
        self.try_num = try_num
        self.phase_type = phase_type

        # classes

        self.evaluator_cls = evaluator_cls
        self.content_loader_cls = content_loader_cls
        self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
        self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser

        # Setups

        self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
        self.batch_size = 64
        self.random_seed: int = 17  # seeds python, numpy and torch in _set_random_seeds
        self.samples_dir: Union[str, None] = None

        # Saving and loading directories!

        # None -> "../Results/<try_num>_<try_name>" (set in update()).
        self.save_dir: Union[str, None] = None
        # None -> "<save_dir>/<phase_type.value>" (set in update()).
        self.report_dir: Union[str, None] = None
        # When set, its file name (without extension) tags the report directories.
        self.final_model_dir: Union[str, None] = None
        # Used in report directory names when final_model_dir is None; None means 'best'.
        self.epoch: Union[str, None] = None

        # Device setting (resolved from args.device in update())

        self.dev_name = None
        self.device = None

        # -------------------- ESSENTIAL CONFIG --------------------

        # -------------------- TRAINING CONFIG --------------------

        self.pretrained_model_file = None

        self.big_batch_size: int = 1
        self.max_epochs: int = 10
        self.iters_per_epoch: Union[int, None] = None
        self.val_iters_per_epoch: Union[int, None] = None

        # cleaning log while training
        self.keep_best_and_last_epochs_only = False

        # auto-stopping the train
        self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100

        self.different_augmentation_per_batch = True
        self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None

        self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

        # for dataloader
        self.mapped_labels_to_use: Union[List[int], None] = None  # None means all, a list of ints= the ones to consider

        # hooking: one (initially empty) list of hooks per phase
        self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
            phasetype: [] for phasetype in PhaseType
        }

        # -------------------- TRAINING CONFIG --------------------

        # ------------------ OPTIMIZATION CONFIG ------------------

        self.init_lr: float = 1e-4
        self.lr_decay: float = 1e-6
        self.freezing_regexes: Union[None, List[str]] = None
        # NOTE(review): the creator captures init_lr / lr_decay by value right here,
        # so later changes to self.init_lr or self.lr_decay (e.g. through update(args))
        # do NOT reach the optimizer unless optimizer_creator is rebuilt as well.
        self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = get_adam_creator(self.init_lr, self.lr_decay)

        # ------------------ OPTIMIZATION CONFIG ------------------

        # ----------------- INTERPRETATION CONFIG -----------------

        # When True, per-sample output dirs use only the last '/'-component of the sample name.
        self.save_by_file_name = False
        self.interpretation_method = InterpretationType.GuidedBackprop
        self.class_label_for_interpretation: Union[int, None] = None
        self.interpret_predictions_vs_gt: bool = True

        # saving interpretations:

        self.skip_overlay = False
        self.skip_raw = False

        # FOR INTERPRETATION EVALUATION

        # A threshold to cut all interpretations lower than
        self.cut_threshold: float = 0

        # whether the threshold is applied to un-normalized interpretation map
        self.global_threshold: bool = False

        # whether to use dynamic threshold
        self.dynamic_threshold: bool = False

        self.prediction_key_in_model_output_dict = 'positive_class_probability'

        # for when interpretation masks as continuous objects are compared with GT
        self.acceptable_min_intersection_threshold: float = 0.5

        # the key for interpretation to be evaluated, None = the first one available
        self.interpretation_tag_to_evaluate = None

        self.n_interpretation_samples: Optional[int] = None

        # ----------------- INTERPRETATION CONFIG -----------------

    def __repr__(self):
        """
        Represent config as string
        """
        return str(self.__dict__)

    def _update_based_on_args(self, args) -> None:
        """Copy every non-``None`` argument whose name is an existing config attribute."""
        for name, value in vars(args).items():
            if value is not None and name in self.__dict__:
                setattr(self, name, value)

    def update(self, args):
        """Override the config from parsed arguments and finish the setup.

        After copying matching, non-``None`` attributes from ``args``, this
        resolves the device from ``args.device``, seeds the RNGs, creates the
        save directory and derives ``report_dir``.
        """

        # dirty piece of code for repetition: the arguments are applied twice.
        # NOTE(review): within this class a second pass is a no-op; kept in case
        # subclasses rely on it — confirm before removing.
        self._update_based_on_args(args)
        self._update_based_on_args(args)

        self.dev_name, self.device = _get_device(args)
        self._set_random_seeds()
        self._set_save_dir()

        self._update_report_dir()

    def _set_random_seeds(self):
        """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

        Every phase other than ``TRAIN`` asks cuDNN for deterministic algorithms
        with benchmarking off; training turns benchmark mode on instead.
        """
        # Set the seed for hash based operations in python.
        # NOTE: PYTHONHASHSEED is only read when an interpreter starts, so this
        # affects Python processes launched afterwards, not the current one.
        os.environ['PYTHONHASHSEED'] = '0'

        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        np.random.seed(self.random_seed)
        random.seed(self.random_seed)

        if self.phase_type != PhaseType.TRAIN:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

    def _set_save_dir(self):
        """Default ``save_dir`` to ``../Results/<try_num>_<try_name>`` and create it."""
        if self.save_dir is None:
            self.save_dir = f"../Results/{self.try_num}_{self.try_name}"
        os.makedirs(self.save_dir, exist_ok=True)

    def _update_report_dir(self):
        """Default ``report_dir`` to ``<save_dir>/<phase_type.value>`` when unset."""
        if self.report_dir is None:
            self.report_dir = self.save_dir + '/' + self.phase_type.value

    def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str:
        """Build a not-yet-existing report directory for a group of samples.

        The path is ``<report_dir>/<model tag><extra>[/<samples tag>]``, where:

        * the model tag is the file name of ``final_model_dir`` without its
          extension when that is set, otherwise ``epoch,<epoch>`` (``epoch,best``
          when ``epoch`` is ``None``);
        * ``<extra>`` is ``/<extra_subdir>`` when ``extra_subdir`` is given;
        * a samples tag is appended only when ``samples_specification`` is not
          ``'train'``, ``'test'`` or ``'val'`` and is not an existing file.

        If the resulting path already exists, a numeric suffix is added
        (see :meth:`get_unique_save_dir`).
        """
        extra_subdir = '' if extra_subdir is None else '/' + extra_subdir

        # Making sure of keeping version information!
        if self.final_model_dir is not None:
            # basename/splitext (instead of rfind('/') / rfind('.')) so a model
            # file without an extension, or with dots only in its directories,
            # still yields its full file name.
            model_tag = os.path.splitext(os.path.basename(self.final_model_dir))[0]
            report_dir = '%s/%s%s' % (self.report_dir, model_tag, extra_subdir)
        else:
            epoch = 'best' if self.epoch is None else self.epoch
            report_dir = '%s/epoch,%s%s' % (self.report_dir, epoch, extra_subdir)

        # adding sample info
        if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification):
            # Only the part before the first ':' is checked against the file system;
            # partition() also copes with specifications containing several ':'.
            base_info = samples_specification.partition(':')[0]

            if os.path.isdir(base_info):
                # NOTE(review): directory specifications are appended verbatim
                # (including '../' and ':') while all others are sanitised below;
                # this looks inverted — confirm the intent.
                report_dir = report_dir + '/' + samples_specification
            else:
                report_dir = report_dir + '/' + \
                    samples_specification.replace(':', '_').replace('../', '')

        return self.get_unique_save_dir(report_dir)

    @staticmethod
    def get_unique_save_dir(sd):
        """Return ``sd`` if it does not exist, else the first free ``sd_2``, ``sd_3``, ..."""
        candidate = sd
        report_num = 2
        while os.path.exists(candidate):
            candidate = '%s_%d' % (sd, report_num)
            report_num += 1
        return candidate

    def get_save_dir_for_sample(self, save_dir, sample_name):
        """Return the directory under ``save_dir`` where a sample's outputs go.

        With ``save_by_file_name`` enabled only the last ``/``-separated part of
        ``sample_name`` is used; otherwise the whole sample name is appended.
        """
        if self.save_by_file_name and '/' in sample_name:
            return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:]
        return save_dir + '/' + sample_name
-
-
- def _get_device(args):
- """ cuda num, -1 means parallel, -2 means cpu, no cuda means cpu"""
- dev_name = args.device
-
- if 'cuda' in dev_name:
- gpu_num = 0
- if ':' in dev_name:
- gpu_num = int(dev_name.split(':')[1])
-
- if gpu_num >= torch.cuda.device_count():
- print('No %s, using CPU instead!' % dev_name, file=stderr)
- dev_name = 'cpu'
- print('@ Running on CPU @', flush=True)
-
- if dev_name == 'cuda':
- dev_name = None
- the_device = torch.device('cuda')
- print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True)
-
- elif 'cuda:' in dev_name:
- the_device = torch.device(dev_name)
- print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True)
-
- else:
- dev_name = 'cpu'
- the_device = torch.device(dev_name)
- print('@ Running on CPU @', flush=True)
-
- return dev_name, the_device
-
-
def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Return a factory that builds a ``torch.optim.Adam`` optimizer.

    The hyper-parameters are fixed when this function is called; the returned
    callable only receives the parameters to optimize.

    Args:
        init_lr: learning rate, passed to Adam as ``lr``.
        lr_decay: despite its name, passed to Adam as ``weight_decay``.
    """
    adam_options = dict(lr=init_lr, weight_decay=lr_decay)

    def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, **adam_options)

    return create_adam
-
-
def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Return a factory that builds a ``torch.optim.SGD`` optimizer.

    The hyper-parameters are fixed when this function is called; the returned
    callable only receives the parameters to optimize.

    Args:
        init_lr: learning rate, passed to SGD as ``lr``.
        lr_decay: despite its name, passed to SGD as ``weight_decay``.
    """
    sgd_options = dict(lr=init_lr, weight_decay=lr_decay)

    def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.SGD(params, **sgd_options)

    return create_sgd
|