""" The Model base class """ from collections import OrderedDict import inspect import typing from typing import Any, Dict, List, Tuple import re from abc import ABC import torch from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d ModelIO = Dict[str, torch.Tensor] class Model(torch.nn.Module, ABC): """ The Model base class """ _frozen_bns_list: List[torch.nn.Module] = [] def init_weights_from_other_model(self, pretrained_model_dir=None): """ Initializes the weights from another model with the given address, Default is to set all the parameters with the same name and shape, but can be rewritten in subclasses. The model to load is received via the function get_other_model_to_load_from. The default value is the same model!""" if pretrained_model_dir is None: print('The model was not preinitialized.', flush=True) return other_model_dict = torch.load(pretrained_model_dir) own_state = self.state_dict() cnt = 0 skip_cnt = 0 for name, param in other_model_dict.items(): if isinstance(param, torch.nn.Parameter): # backwards compatibility for serialized parameters param = param.data if name in own_state and own_state[name].data.shape == param.shape: own_state[name].copy_(param) cnt += 1 else: skip_cnt += 1 print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, {skip_cnt} of the given set were skipped.' , flush=True) def freeze_parameters(self, regex_list: List[str] = None) -> None: """ An auxiliary function for freezing the parameters of the model using an inclusion keywords list and exclusion keywords list. :param regex_list: All the parameters matching at least one of the given regexes would be freezed. None means no parameter would be frozen. :return: None """ # One solution for batchnorms is registering a preforward hook # to call eval everytime before running them # But that causes a problem, we can't dynamically change the frozen batchnorms during training # e.g. in layerwise unfreezing. # So it's better to set a dynamic attribute that keeps their list, rather than having an init # Init causes other problems in multi-inheritance if regex_list is None: print('No parameter was freezeed!', flush=True) return regex_list = [re.compile(the_pattern) for the_pattern in regex_list] def check_freeze_conditions(param_name): for the_regex in regex_list: if the_regex.fullmatch(param_name) is not None: return True return False frozen_cnt = 0 total_cnt = 0 frozen_modules_names = dict() for n, p in self.named_parameters(): total_cnt += 1 if check_freeze_conditions(n): p.requires_grad = False frozen_cnt += 1 frozen_modules_names[n[:n.rfind('.')]] = True self._frozen_bns_list = [] for module_name, module in self.named_modules(): if (module_name in frozen_modules_names) and \ (isinstance(module, BatchNorm2d) or isinstance(module, BatchNorm1d) or isinstance(module, BatchNorm3d)): self._frozen_bns_list.append(module) print('********************* NOTICE *********************') print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt)) print( '%d BatchNorm layers will be frozen by being kept in the val mode, IF YOU HAVE NOT OVERRIDEN IT IN YOUR MAIN MODEL!' 
% len( self._frozen_bns_list)) print('***************************************************', flush=True) def train(self, mode: bool = True): super(Model, self).train(mode) for bn in self._frozen_bns_list: bn.train(False) return self @property def additional_kwargs(self) -> typing.OrderedDict[str, bool]: r""" Returns a dictionary from additional `kwargs` names to their optionality """ return OrderedDict({}) def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]: """Returns the names of the keyword arguments required in the forward function and the list of the default values for the ones having them. The second list might be shorter than the first and it represents the default values from the end of the arguments. Returns: Tuple[List[str], List[Any], Dict[str, Any]]: The list of args names and default values for the ones having them (from some point to the end) + the dictionary of kwargs """ model_argspec = inspect.getfullargspec(self.forward) model_forward_args = list(model_argspec.args) # skipping self arg model_forward_args = model_forward_args[1:] args_default_values = model_argspec.defaults if args_default_values is None: args_default_values = [] else: args_default_values = list(args_default_values) additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {} return model_forward_args, args_default_values, additional_kwargs
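

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): partial weight loading
# with `init_weights_from_other_model`. `_TinyNet` and the temporary file
# are hypothetical; only parameters whose name and shape match are copied,
# the rest are skipped and counted.
# ---------------------------------------------------------------------------
def _demo_partial_weight_loading() -> None:
    import os
    import tempfile

    class _TinyNet(Model):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> ModelIO:
            return {'out': self.fc(x)}

    donor, receiver = _TinyNet(), _TinyNet()
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'donor.pth')
        torch.save(donor.state_dict(), path)
        # prints how many parameters were loaded vs. skipped
        receiver.init_weights_from_other_model(path)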
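

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): freezing a submodule
# with `freeze_parameters` and its interaction with the `train` override.
# `_DemoNet` and the `backbone`/`head` layout are hypothetical. Note the
# fullmatch semantics: the regex must cover the whole parameter name.
# ---------------------------------------------------------------------------
def _demo_freezing() -> None:
    class _DemoNet(Model):
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Sequential(
                torch.nn.Conv2d(3, 8, 3, padding=1),
                torch.nn.BatchNorm2d(8),
                torch.nn.ReLU())
            self.head = torch.nn.Linear(8, 2)

        def forward(self, x: torch.Tensor) -> ModelIO:
            return {'logits': self.head(self.backbone(x).mean(dim=(2, 3)))}

    net = _DemoNet()
    net.freeze_parameters([r'backbone\..*'])

    # the frozen BatchNorm2d is kept in eval mode even after train(), so its
    # running statistics no longer update; the head stays trainable
    net.train()
    assert not net.backbone[1].training
    assert net.head.weight.requires_grad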
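

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): how the forward-signature
# introspection behaves. `_KwArgsNet` and its kwargs are hypothetical; the
# `additional_kwargs` property is only reported when `forward` takes **kwargs.
# ---------------------------------------------------------------------------
def _demo_forward_introspection() -> None:
    class _KwArgsNet(Model):
        @property
        def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
            # `return_features` is optional, `stage` is mandatory
            return OrderedDict(return_features=True, stage=False)

        def forward(self, x, mask=None, **kwargs) -> ModelIO:
            return {'x': x}

    net = _KwArgsNet()
    args, defaults, extra = net.get_forward_required_kws_kwargs_and_defaults()
    assert args == ['x', 'mask']  # `self` is skipped
    assert defaults == [None]     # aligned with the tail of `args`
    assert list(extra) == ['return_features', 'stage']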