- """
- The Model base class
- """
- from collections import OrderedDict
- import inspect
- import typing
- from typing import Any, Dict, List, Optional, Tuple
- import re
- from abc import ABC
-
- import torch
- from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d
-
-
- ModelIO = Dict[str, torch.Tensor]
-
-
- class Model(torch.nn.Module, ABC):
- """
- The Model base class
- """
- _frozen_bns_list: List[torch.nn.Module] = []
-
-     def init_weights_from_other_model(self, pretrained_model_dir=None):
-         """ Initializes the weights from another model saved at the given path.
-         By default, all the parameters with the same name and shape are copied,
-         but this behavior can be overridden in subclasses.
-         If no path is given, no pretrained weights are loaded. """
-
-         if pretrained_model_dir is None:
-             print('The model was not preinitialized.', flush=True)
-             return
-
-         other_model_dict = torch.load(pretrained_model_dir)
-         own_state = self.state_dict()
-
-         cnt = 0
-         skip_cnt = 0
-
-         for name, param in other_model_dict.items():
-
-             if isinstance(param, torch.nn.Parameter):
-                 # backwards compatibility for serialized parameters
-                 param = param.data
-
-             # copy only the parameters that exist in this model with the same shape
-             if name in own_state and own_state[name].data.shape == param.shape:
-                 own_state[name].copy_(param)
-                 cnt += 1
-
-             else:
-                 skip_cnt += 1
-
-         print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, {skip_cnt} of the given set were skipped.', flush=True)
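-
-     # Example (hypothetical, not part of the base class): initializing a subclass
-     # from a previously saved state dict. The checkpoint path below is an
-     # assumption for illustration only; only parameters whose name and shape
-     # match are copied.
-     #
-     #     model = SomeModelSubclass()
-     #     model.init_weights_from_other_model('runs/exp1/checkpoint.pt')
-     #     # prints e.g. "120 out of 130 parameters were loaded successfully, ..."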
-
-     def freeze_parameters(self, regex_list: Optional[List[str]] = None) -> None:
-         """
-         An auxiliary function for freezing the parameters of the model
-         using a list of regular expressions over parameter names.
-         :param regex_list: All the parameters fully matching at least one of the given regexes will be frozen.
-             None means no parameter will be frozen.
-         :return: None
-         """
-
-         # One solution for batchnorms is registering a pre-forward hook
-         # to call eval every time before running them.
-         # But that causes a problem: we couldn't dynamically change the frozen batchnorms during training,
-         # e.g. in layerwise unfreezing.
-         # So it's better to set a dynamic attribute that keeps their list, rather than having an init.
-         # Init causes other problems in multi-inheritance.
-
-         if regex_list is None:
-             print('No parameter was frozen!', flush=True)
-             return
-
-         regex_list = [re.compile(the_pattern) for the_pattern in regex_list]
-
-         def check_freeze_conditions(param_name):
-             for the_regex in regex_list:
-                 if the_regex.fullmatch(param_name) is not None:
-                     return True
-             return False
-
-         frozen_cnt = 0
-         total_cnt = 0
-
-         frozen_modules_names = dict()
-
-         for n, p in self.named_parameters():
-             total_cnt += 1
-             if check_freeze_conditions(n):
-                 p.requires_grad = False
-                 frozen_cnt += 1
-                 # remember the owning module's name so its batchnorm layers can be kept in eval mode
-                 frozen_modules_names[n[:n.rfind('.')]] = True
-
-         self._frozen_bns_list = []
-
-         for module_name, module in self.named_modules():
-             if (module_name in frozen_modules_names) and \
-                     isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
-                 self._frozen_bns_list.append(module)
-
-         print('********************* NOTICE *********************')
-         print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt))
-         print(
-             '%d BatchNorm layers will be frozen by being kept in the eval mode, IF YOU HAVE NOT OVERRIDDEN IT IN YOUR MAIN MODEL!' % len(
-                 self._frozen_bns_list))
-         print('***************************************************', flush=True)
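-
-     # Example (hypothetical usage sketch): freezing everything under a module named
-     # "backbone" in a subclass and training only the remaining parameters. The
-     # "backbone" prefix is an illustrative assumption; pass regexes that fully match
-     # your own parameter names.
-     #
-     #     model.freeze_parameters([r'backbone\..*'])
-     #     trainable = [p for p in model.parameters() if p.requires_grad]
-     #     optimizer = torch.optim.SGD(trainable, lr=1e-3)
-     #     model.train()  # the overridden train() below keeps the frozen BatchNorms in eval mode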
-
-     def train(self, mode: bool = True):
-         """ Switches the training mode as usual, but keeps the frozen
-         BatchNorm layers (if any) in eval mode so that their running
-         statistics are not updated. """
-         super(Model, self).train(mode)
-         for bn in self._frozen_bns_list:
-             bn.train(False)
-         return self
-
-     @property
-     def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
-         r""" Returns a dictionary from additional `kwargs` names to their optionality """
-         return OrderedDict({})
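-
-     # Example (hypothetical): a subclass whose forward accepts **kwargs can expose the
-     # extra keyword arguments it understands by overriding this property. The names
-     # below and the reading of the boolean as "is optional" are illustrative
-     # assumptions, not something the base class enforces.
-     #
-     #     @property
-     #     def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
-     #         return OrderedDict(attention_mask=True, lengths=False)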
-
-     def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]:
-         """Returns the names of the arguments of the forward method, the default values of the ones that have them,
-         and the additional kwargs accepted through **kwargs.
-         The list of defaults may be shorter than the list of names; it covers only the trailing arguments.
-
-         Returns:
-             Tuple[List[str], List[Any], Dict[str, Any]]: The list of argument names, the default values of the
-             trailing arguments that have them, and the dictionary of additional kwargs.
-         """
-
-         model_argspec = inspect.getfullargspec(self.forward)
-
-         model_forward_args = list(model_argspec.args)
-
-         # skipping the self arg
-         model_forward_args = model_forward_args[1:]
-
-         args_default_values = model_argspec.defaults
-         if args_default_values is None:
-             args_default_values = []
-         else:
-             args_default_values = list(args_default_values)
-
-         # additional kwargs are only relevant if forward accepts **kwargs
-         additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {}
-
-         return model_forward_args, args_default_values, additional_kwargs
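-
-     # Example (hypothetical): for a subclass defined with
-     #
-     #     def forward(self, x, mask=None, **kwargs): ...
-     #
-     # this method would return (['x', 'mask'], [None], self.additional_kwargs):
-     # the positional argument names without self, the trailing default values,
-     # and the declared extra kwargs (an empty dict when forward has no **kwargs).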
-