# model.py
"""
The Model base class
"""
from collections import OrderedDict
import inspect
import typing
from typing import Any, Dict, List, Tuple
import re
from abc import ABC

import torch
from torch.nn import BatchNorm2d, BatchNorm1d, BatchNorm3d

ModelIO = Dict[str, torch.Tensor]

class Model(torch.nn.Module, ABC):
    """
    The Model base class
    """

    _frozen_bns_list: List[torch.nn.Module] = []

    def init_weights_from_other_model(self, pretrained_model_dir=None):
        """Initializes the weights from another model stored at the given path.
        By default, all parameters with a matching name and shape are copied,
        but this can be overridden in subclasses.
        The model to load is received via the function get_other_model_to_load_from;
        the default value is the same model!"""
        if pretrained_model_dir is None:
            print('The model was not preinitialized.', flush=True)
            return

        other_model_dict = torch.load(pretrained_model_dir)
        own_state = self.state_dict()

        cnt = 0
        skip_cnt = 0
        for name, param in other_model_dict.items():
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            if name in own_state and own_state[name].data.shape == param.shape:
                own_state[name].copy_(param)
                cnt += 1
            else:
                skip_cnt += 1

        print(f'{cnt} out of {len(own_state)} parameters were loaded successfully, '
              f'{skip_cnt} of the given set were skipped.', flush=True)
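
    # A minimal usage sketch (not from the original file): initializing a model's
    # weights from a saved checkpoint. `MyModel` and the checkpoint path are
    # hypothetical; only parameters with a matching name and shape are copied,
    # the rest are skipped.
    #
    #   model = MyModel()
    #   model.init_weights_from_other_model('checkpoints/pretrained_model.pt')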

    def freeze_parameters(self, regex_list: List[str] = None) -> None:
        """
        An auxiliary function for freezing the parameters of the model
        using a list of regular expressions.
        :param regex_list: all the parameters fully matching at least one of the given regexes will be frozen;
            None means no parameter will be frozen.
        :return: None
        """
        # One solution for batch norms is registering a pre-forward hook
        # that calls eval every time before running them.
        # But that causes a problem: we cannot dynamically change the frozen batch norms during training,
        # e.g. in layer-wise unfreezing.
        # So it is better to set a dynamic attribute that keeps their list, rather than having an __init__;
        # an __init__ causes other problems with multiple inheritance.
        if regex_list is None:
            print('No parameter was frozen!', flush=True)
            return

        regex_list = [re.compile(the_pattern) for the_pattern in regex_list]

        def check_freeze_conditions(param_name):
            for the_regex in regex_list:
                if the_regex.fullmatch(param_name) is not None:
                    return True
            return False

        frozen_cnt = 0
        total_cnt = 0
        frozen_modules_names = dict()

        for n, p in self.named_parameters():
            total_cnt += 1
            if check_freeze_conditions(n):
                p.requires_grad = False
                frozen_cnt += 1
                frozen_modules_names[n[:n.rfind('.')]] = True

        self._frozen_bns_list = []

        for module_name, module in self.named_modules():
            if (module_name in frozen_modules_names) and \
                    isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
                self._frozen_bns_list.append(module)

        print('********************* NOTICE *********************')
        print('%d out of %d parameters were frozen!' % (frozen_cnt, total_cnt))
        print('%d BatchNorm layers will be frozen by being kept in eval mode, '
              'IF YOU HAVE NOT OVERRIDDEN IT IN YOUR MAIN MODEL!' % len(self._frozen_bns_list))
        print('***************************************************', flush=True)
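
    # A minimal usage sketch (not from the original file): freezing every parameter
    # whose full name starts with `backbone.`. The parameter naming is hypothetical;
    # the given regexes must fully match the names reported by `named_parameters()`.
    #
    #   model.freeze_parameters([r'backbone\..*'])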

    def train(self, mode: bool = True):
        super(Model, self).train(mode)
        # Keep the frozen batch-norm layers in eval mode regardless of the requested mode.
        for bn in self._frozen_bns_list:
            bn.train(False)
        return self

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""Returns a dictionary from additional `kwargs` names to their optionality."""
        return OrderedDict({})
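
    # A minimal sketch (not from the original file) of how a subclass whose forward
    # accepts **kwargs might override `additional_kwargs`; the `mask` name and the
    # meaning of the boolean flag (True = optional) are illustrative assumptions.
    #
    #   @property
    #   def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
    #       return OrderedDict({'mask': True})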

    def get_forward_required_kws_kwargs_and_defaults(self) -> Tuple[List[str], List[Any], Dict[str, Any]]:
        """Returns the names of the keyword arguments required by the forward function and the default values
        for the ones that have them. The second list might be shorter than the first; it contains the default
        values of the trailing arguments.

        Returns:
            Tuple[List[str], List[Any], Dict[str, Any]]: the list of argument names, the default values for
            the trailing arguments that have them, and the dictionary of additional kwargs.
        """
        model_argspec = inspect.getfullargspec(self.forward)

        model_forward_args = list(model_argspec.args)
        # skipping the self arg
        model_forward_args = model_forward_args[1:]

        args_default_values = model_argspec.defaults
        if args_default_values is None:
            args_default_values = []
        else:
            args_default_values = list(args_default_values)

        additional_kwargs = self.additional_kwargs if model_argspec.varkw is not None else {}

        return model_forward_args, args_default_values, additional_kwargs
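

# A minimal usage sketch, not part of the original module. `ToyModel`, its layer
# names, and its forward signature are illustrative assumptions made only for this
# demo; the calls themselves are the methods defined above.
if __name__ == '__main__':

    class ToyModel(Model):
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), BatchNorm2d(8))
            self.head = torch.nn.Linear(8, 2)

        def forward(self, x: torch.Tensor, mask: torch.Tensor = None, **kwargs) -> ModelIO:
            return {'out': self.head(self.backbone(x).mean(dim=(2, 3)))}

    model = ToyModel()

    # Freeze everything under `backbone`; the matching BatchNorm layer is then kept
    # in eval mode by the overridden `train`.
    model.freeze_parameters([r'backbone\..*'])
    model.train()

    # Introspect the forward signature; with the default `additional_kwargs` this
    # prints (['x', 'mask'], [None], OrderedDict()).
    print(model.get_forward_required_kws_kwargs_and_defaults())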