base_config.py
from typing import TYPE_CHECKING, Union, Dict, Tuple, Type, List, Callable, Iterator, Optional
import os
import random
import numpy as np
from sys import stderr
from enum import Enum
import torch
from ..interpreting.interpreter_maker import InterpretationType
from ..data.batch_choosing.class_balanced_shuffled_sequential import ClassBalancedShuffledSequentialBatchChooser
from ..data.batch_choosing.sequential import SequentialBatchChooser
from ..utils.hooker import HookPair

if TYPE_CHECKING:
    from ..model_evaluation.evaluator import Evaluator
    from ..data.batch_choosing.batch_chooser import BatchChooser
    from ..data.content_loaders.content_loader import ContentLoader

class PhaseType(Enum):
    TRAIN = 'train'
    EVAL = 'eval'
    INTERPRET = 'interpret'
    EVALINTERPRET = 'evalinterpret'

class BaseConfig:

    def __init__(self,
                 try_name: str, try_num: int, data_separation: str,
                 input_size: int,
                 phase_type: PhaseType,
                 content_loader_cls: Type['ContentLoader'],
                 evaluator_cls: Type['Evaluator']) -> None:

        # -------------------- ESSENTIAL CONFIG --------------------
        self.data_separation = data_separation
        self.try_name = try_name
        self.try_num = try_num
        self.phase_type = phase_type

        # classes
        self.evaluator_cls = evaluator_cls
        self.content_loader_cls = content_loader_cls
        self.train_batch_chooser_cls: Type['BatchChooser'] = ClassBalancedShuffledSequentialBatchChooser
        self.eval_batch_chooser_cls: Type['BatchChooser'] = SequentialBatchChooser

        # setups
        self.input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = input_size
        self.batch_size = 64
        self.random_seed: int = 17
        self.samples_dir: Union[str, None] = None

        # saving and loading directories
        self.save_dir: Union[str, None] = None
        self.report_dir: Union[str, None] = None
        self.final_model_dir: Union[str, None] = None
        self.epoch: Union[str, None] = None

        # device setting
        self.dev_name = None
        self.device = None
        # -------------------- ESSENTIAL CONFIG --------------------

        # -------------------- TRAINING CONFIG --------------------
        self.pretrained_model_file = None
        self.big_batch_size: int = 1
        self.max_epochs: int = 10
        self.iters_per_epoch: Union[int, None] = None
        self.val_iters_per_epoch: Union[int, None] = None

        # cleaning the log while training
        self.keep_best_and_last_epochs_only = False

        # auto-stopping the training: stop after this many epochs without improvement
        self.n_unsuccessful_epochs_to_stop: Union[None, int] = 100

        self.different_augmentation_per_batch = True
        self.augmentations_dict: Union[Dict[str, torch.nn.Module], None] = None
        self.title_of_reference_metric_to_choose_best_epoch = 'Loss'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '<='

        # for the dataloader; None means use all labels, a list of ints restricts
        # training to the given mapped labels
        self.mapped_labels_to_use: Union[List[int], None] = None

        # hooking
        self.hooks_by_phase: Dict[PhaseType, List[HookPair]] = {
            phasetype: [] for phasetype in PhaseType
        }
        # -------------------- TRAINING CONFIG --------------------

        # ------------------ OPTIMIZATION CONFIG ------------------
        self.init_lr: float = 1e-4
        self.lr_decay: float = 1e-6
        self.freezing_regexes: Union[None, List[str]] = None
        # get_adam_creator is defined at module level below; it is resolved at call time
        self.optimizer_creator: Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] = \
            get_adam_creator(self.init_lr, self.lr_decay)
        # ------------------ OPTIMIZATION CONFIG ------------------

        # ----------------- INTERPRETATION CONFIG -----------------
        self.save_by_file_name = False
        self.interpretation_method = InterpretationType.GuidedBackprop
        self.class_label_for_interpretation: Union[int, None] = None
        self.interpret_predictions_vs_gt: bool = True

        # saving interpretations
        self.skip_overlay = False
        self.skip_raw = False

        # FOR INTERPRETATION EVALUATION
        # a threshold below which all interpretation values are cut
        self.cut_threshold: float = 0
        # whether the threshold is applied to the un-normalized interpretation map
        self.global_threshold: bool = False
        # whether to use a dynamic threshold
        self.dynamic_threshold: bool = False
        self.prediction_key_in_model_output_dict = 'positive_class_probability'
        # used when interpretation masks, as continuous objects, are compared with the ground truth
        self.acceptable_min_intersection_threshold: float = 0.5
        # the key of the interpretation to evaluate; None means the first one available
        self.interpretation_tag_to_evaluate = None
        self.n_interpretation_samples: Optional[int] = None
        # ----------------- INTERPRETATION CONFIG -----------------

    def __repr__(self):
        """Represents the config as a string."""
        return str(self.__dict__)

    def _update_based_on_args(self, args) -> None:
        """Overrides config attributes with any non-None values of matching names in args."""
        args_dict = args.__dict__
        for arg in args_dict:
            if arg in self.__dict__ and args_dict[arg] is not None:
                setattr(self, arg, args_dict[arg])
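    # For example (illustrative values), Namespace(batch_size=32, unknown_flag=True)
    # would set self.batch_size to 32 and silently ignore unknown_flag, since only
    # attributes already defined on the config are overridden.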

    def update(self, args):
        """Updates the config from the parsed command-line args."""
        self._update_based_on_args(args)
        self.dev_name, self.device = _get_device(args)
        self._set_random_seeds()
        self._set_save_dir()
        self._update_report_dir()

    def _set_random_seeds(self):
        # set the seed for hash-based operations in python
        # (note: setting PYTHONHASHSEED at runtime only affects subprocesses,
        # not the hashing of the already-running interpreter)
        os.environ['PYTHONHASHSEED'] = '0'
        torch.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        np.random.seed(self.random_seed)
        random.seed(self.random_seed)
        if self.phase_type != PhaseType.TRAIN:
            # full determinism outside of training
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            # favor speed over exact reproducibility while training
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

    def _set_save_dir(self):
        if self.save_dir is None:
            self.save_dir = f"../Results/{self.try_num}_{self.try_name}"
        os.makedirs(self.save_dir, exist_ok=True)

    def _update_report_dir(self):
        if self.report_dir is None:
            self.report_dir = self.save_dir + '/' + self.phase_type.value

    def get_sample_group_specific_report_dir(self, samples_specification, extra_subdir=None) -> str:
        if extra_subdir is None:
            extra_subdir = ''
        else:
            extra_subdir = '/' + extra_subdir

        # making sure the version information is kept
        if self.final_model_dir is not None:
            report_dir = '%s/%s%s' % (
                self.report_dir,
                self.final_model_dir[self.final_model_dir.rfind('/') + 1:self.final_model_dir.rfind('.')],
                extra_subdir)
        else:
            epoch = self.epoch
            if epoch is None:
                epoch = 'best'
            report_dir = '%s/epoch,%s%s' % (
                self.report_dir, epoch, extra_subdir)

        # adding sample info
        if samples_specification not in ['train', 'test', 'val'] and not os.path.isfile(samples_specification):
            if ':' in samples_specification:
                base_info, _ = tuple(samples_specification.split(':'))
            else:
                base_info = samples_specification
            if not os.path.isdir(base_info):
                report_dir = report_dir + '/' + \
                    samples_specification.replace(':', '_').replace('../', '')
        else:
            report_dir = report_dir + '/' + samples_specification

        return self.get_unique_save_dir(report_dir)

    @staticmethod
    def get_unique_save_dir(sd):
        if os.path.exists(sd):
            report_num = 2
            base_sd = sd
            while os.path.exists(sd):
                sd = base_sd + '_%d' % report_num
                report_num += 1
        return sd
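    # For example, if '../Results/1_baseline/eval' already exists, this returns
    # '../Results/1_baseline/eval_2', then '..._3' on the next call, and so on
    # (paths are illustrative).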

    def get_save_dir_for_sample(self, save_dir, sample_name):
        if self.save_by_file_name and '/' in sample_name:
            return save_dir + '/' + sample_name[sample_name.rfind('/') + 1:]
        else:
            return save_dir + '/' + sample_name


def _get_device(args):
    """Resolves (dev_name, device) from args.device: 'cuda' means all available GPUs
    (dev_name becomes None), 'cuda:N' means a specific GPU (falling back to CPU when
    N is out of range), and anything else means CPU."""
    dev_name = args.device
    if 'cuda' in dev_name:
        gpu_num = 0
        if ':' in dev_name:
            gpu_num = int(dev_name.split(':')[1])
        if gpu_num >= torch.cuda.device_count():
            print('No %s, using CPU instead!' % dev_name, file=stderr)
            dev_name = 'cpu'
    if dev_name == 'cuda':
        dev_name = None
        the_device = torch.device('cuda')
        print('@ Running on all %d GPUs @' % torch.cuda.device_count(), flush=True)
    elif 'cuda:' in dev_name:
        the_device = torch.device(dev_name)
        print('@ Running on GPU:%d @' % int(dev_name.split(':')[1]), flush=True)
    else:
        dev_name = 'cpu'
        the_device = torch.device(dev_name)
        print('@ Running on CPU @', flush=True)
    return dev_name, the_device
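
# Illustrative resolutions on a machine with one visible GPU:
#   args.device == 'cuda'    -> (None, torch.device('cuda'))       # all GPUs
#   args.device == 'cuda:0'  -> ('cuda:0', torch.device('cuda:0'))
#   args.device == 'cuda:3'  -> ('cpu', torch.device('cpu'))       # out of range, falls back
#   args.device == 'cpu'     -> ('cpu', torch.device('cpu'))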


def get_adam_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Returns a factory that builds an Adam optimizer; note that lr_decay is
    passed as weight_decay (an L2 penalty), not as a learning-rate schedule."""
    def create_adam(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, lr=init_lr, weight_decay=lr_decay)
    return create_adam


def get_sgd_creator(init_lr: float = 1e-4, lr_decay: float = 1e-6) \
        -> Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]:
    """Returns a factory that builds an SGD optimizer; note that lr_decay is
    passed as weight_decay (an L2 penalty), not as a learning-rate schedule."""
    def create_sgd(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return torch.optim.SGD(params, lr=init_lr, weight_decay=lr_decay)
    return create_sgd
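

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of wiring the config up, assuming hypothetical
# `MyContentLoader` and `MyEvaluator` classes and an argparse namespace with a
# `device` attribute (the only field `update` strictly requires):
#
#     from argparse import Namespace
#
#     config = BaseConfig(
#         try_name='baseline', try_num=1, data_separation='default',
#         input_size=224, phase_type=PhaseType.TRAIN,
#         content_loader_cls=MyContentLoader, evaluator_cls=MyEvaluator)
#     config.optimizer_creator = get_sgd_creator(init_lr=1e-3)  # swap Adam for SGD
#     config.update(Namespace(device='cuda:0', batch_size=32))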