from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional
import os
import random

import numpy as np
from sys import stderr
from enum import Enum

import torch

from ..interpreting.interpreter_maker import InterpretationType
from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser
from ..data.batch_choosing.sequential import SequentialBatchChooser
from ..utils.hooker import HookPair

if TYPE_CHECKING:
    from ..model_evaluation.evaluator import Evaluator
    from ..data.batch_choosing.batch_chooser import BatchChooser
    from ..data.content_loaders.content_loader import ContentLoader


class PhaseType(Enum):
    TRAIN = 'train'
    EVAL = 'eval'
    INTERPRET = 'interpret'
    EVALINTERPRET = 'evalinterpret'


class BaseConfig:

    def __init__(self,
                 try_name: str, try_num: int, data_separation: str,
                 input_size: int, phase_type: PhaseType,
                 content_loader_cls: Type['ContentLoader'],
                 evaluator_cls: Type['Evaluator']) -> None:

        # -------------------- ESSENTIAL CONFIG --------------------
        self.data_separation = data_separation
        self.try_name = try_name
        self.try_num = try_num
        self.phase_type = phase_type

        # classes
        self.evaluator_cls = evaluator_cls
        self.content_loader_cls = content_loader_cls
        self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
        self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser

        # setups
        self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
        self.batch_size = 64
        self.random_seed: int = 17
        self.samples_dir: Union[str, None] = None

        # saving and loading directories!
        self.save_dir: Union[str, None] = None
        self.report_dir: Union[str, None] = None
        self.final_model_dir: Union[str, None] = None
        self.epoch: Union[str, None] = None

        # device setting
        self.dev_name = None
        self.device = None
        # -------------------- ESSENTIAL CONFIG --------------------

        # -------------------- TRAINING CONFIG --------------------
        self.pretrained_model_file = None
        self.big_batch_size: int = 1
        self.max_epochs: int = 10
        self.iters_per_epoch: Union[int, None] = None
        self.val_iters_per_epoch: Union[int, None] = None

        # cleaning the log while training: keep only the best and the last epochs' checkpoints
        self.keep_best_and_last_epochs_only = False

        # auto-stopping the training after this many epochs without improvement
        self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100

        self.different_augmentation_per_batch = True
        self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None

        self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

        # for the dataloader; None means all labels, a list of ints means only the listed ones
        self.mapped_labels_to_use: Union[List[int], None] = None

        # hooking
        self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
            phasetype: [] for phasetype in PhaseType
        }
        # -------------------- TRAINING CONFIG --------------------

        # ------------------ OPTIMIZATION CONFIG ------------------
        self.init_lr: float = 1e-4
        self.lr_decay: float = 1e-6
        self.freezing_regexes: Union[None, List[str]] = None
        self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = \
            get_adam_creator(self.init_lr, self.lr_decay)
        # ------------------ OPTIMIZATION CONFIG ------------------

        # ----------------- INTERPRETATION CONFIG -----------------
        self.save_by_file_name = False
        self.interpretation_method = InterpretationType.GuidedBackprop
        self.class_label_for_interpretation: Union[int, None] = None
        self.interpret_predictions_vs_gt: bool = True

        # saving interpretations
        self.skip_overlay = False
        self.skip_raw = False

        # FOR INTERPRETATION EVALUATION
        # a threshold below which all interpretation scores are cut
        self.cut_threshold: float = 0
        # whether the threshold is applied to the un-normalized interpretation map
        self.global_threshold: bool = False
        # whether to use a dynamic threshold
        self.dynamic_threshold: bool = False

        self.prediction_key_in_model_output_dict = 'positive_class_probability'

        # for when interpretation masks as continuous objects are compared with the ground truth
        self.acceptable_min_intersection_threshold: float = 0.5

        # the key of the interpretation to be evaluated; None means the first one available
        self.interpretation_tag_to_evaluate = None

        self.n_interpretation_samples: Optional[int] = None
        # ----------------- INTERPRETATION CONFIG -----------------

    def __repr__(self):
        """Represents the config as a string."""
        return str(self.__dict__)

    def _update_based_on_args(self, args) -> None:
        """Overwrites config values with any non-None attributes of the given arguments."""
        args_dict = args.__dict__
        for arg in args_dict:
            if arg in self.__dict__ and args_dict[arg] is not None:
                setattr(self, arg, args_dict[arg])

    def update(self, args):
        """Updates the config from the parsed command-line arguments."""
        self._update_based_on_args(args)
        self.dev_name, self.device = _get_device(args)
        self._set_random_seeds()
        self._set_save_dir()
        self._update_report_dir()

    def _set_random_seeds(self):
        # Set the seed for hash-based operations in Python. Note: setting this
        # at runtime only affects subprocesses; the hash seed of the current
        # interpreter is fixed at startup.
        os.environ['PYTHONHASHSEED'] = '0'

        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        np.random.seed(self.random_seed)
        random.seed(self.random_seed)

        # cudnn determinism is only enforced outside training, where
        # reproducibility matters more than throughput
        if self.phase_type != PhaseType.TRAIN:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

    def _set_save_dir(self):
        if self.save_dir is None:
            self.save_dir = f"../Results/{self.try_num}_{self.try_name}"
        os.makedirs(self.save_dir, exist_ok=True)

    def _update_report_dir(self):
        if self.report_dir is None:
            self.report_dir = self.save_dir + '/' + self.phase_type.value

    def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str:
        if extra_subdir is None:
            extra_subdir = ''
        else:
            extra_subdir = '/' + extra_subdir

        # making sure of keeping version information!
        if self.final_model_dir is not None:
            report_dir = '%s/%s%s' % (
                self.report_dir,
                self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')],
                extra_subdir)
        else:
            epoch = self.epoch
            if epoch is None:
                epoch = 'best'
            report_dir = '%s/epoch,%s%s' % (
                self.report_dir, epoch, extra_subdir)

        # adding sample info
        if samples_specification not in ['train', 'test', 'val'] and \
                not os.path.isfile(samples_specification):
            if ':' in samples_specification:
                base_info, _ = tuple(samples_specification.split(':'))
            else:
                base_info = samples_specification

            if not os.path.isdir(base_info):
                report_dir = report_dir + '/' + \
                    samples_specification.replace(':', '_').replace('../', '')
        else:
            report_dir = report_dir + '/' + samples_specification

        return self.get_unique_save_dir(report_dir)

    @staticmethod
    def get_unique_save_dir(sd):
        # if the directory already exists, append an increasing numeric suffix
        if os.path.exists(sd):
            report_num = 2
            base_sd = sd
            while os.path.exists(sd):
                sd = base_sd + '_%d' % report_num
                report_num += 1
        return sd

    def get_save_dir_for_sample(self, save_dir, sample_name):
        if self.save_by_file_name and '/' in sample_name:
            return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:]
        else:
            return save_dir + '/' + sample_name
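
# --- Usage sketch (illustrative, not part of the library) ---
# In practice a config is specialized per experiment by subclassing BaseConfig
# and overriding the defaults. The concrete classes below (MnistContentLoader,
# BinaryEvaluator) and the constructor values are hypothetical placeholders;
# any ContentLoader/Evaluator subclass would do.
#
#     class MnistConfig(BaseConfig):
#         def __init__(self):
#             super().__init__(
#                 try_name='baseline', try_num=1,
#                 data_separation='../data/separation.csv',
#                 input_size=(1, 28, 28), phase_type=PhaseType.TRAIN,
#                 content_loader_cls=MnistContentLoader,
#                 evaluator_cls=BinaryEvaluator)
#             self.batch_size = 32
#             self.max_epochs = 50
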
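
# A quick sanity check of the mapping above (illustrative; assumes an
# argparse-style namespace with a `device` attribute):
#
#     from types import SimpleNamespace
#     _get_device(SimpleNamespace(device='cpu'))     # -> ('cpu', device(type='cpu'))
#     _get_device(SimpleNamespace(device='cuda'))    # -> (None, device(type='cuda'))
#                                                    #    if at least one GPU exists
#     _get_device(SimpleNamespace(device='cuda:1'))  # -> ('cuda:1', device('cuda:1')),
#                                                    #    or the CPU fallback if GPU 1 is absent
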

def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    # note: despite its name, lr_decay is passed to the optimizer as the
    # weight_decay (L2 regularization) coefficient, not a learning-rate schedule
    def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay)
    return create_adam


def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay)
    return create_sgd
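
# --- Optimizer-creator usage sketch (illustrative) ---
# The creators are deferred factories: the config stores a callable, and the
# training code applies it to the model's parameters once the model exists.
# A hypothetical swap from the default Adam to SGD would look like:
#
#     config = MnistConfig()   # any BaseConfig subclass (hypothetical, see above)
#     config.optimizer_creator = get_sgd_creator(init_lr=1e-2, lr_decay=1e-5)
#     ...
#     optimizer = config.optimizer_creator(model.parameters())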