123456789101112131415161718192021 |
from typing import List, Optional
-
- def _is_it_hot(param_name: str, hot_modules: List[str]):
- for module_name in hot_modules:
- if module_name in param_name: # str contains
- return True
- return False
-
-
def auto_freeze(model, hot_modules: Optional[List[str]]) -> str:
    """Freeze every parameter of *model* except the "hot" ones.

    A parameter is hot when ``_is_it_hot`` reports that some entry of
    *hot_modules* is a substring of its name. Each parameter's
    ``requires_grad`` is set to that result: True for hot parameters,
    False for all others.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs. Presumably a ``torch.nn.Module``;
            confirm against callers.
        hot_modules: Name fragments of parameters to keep trainable.
            ``None`` disables freezing entirely, and no parameter is
            touched. An empty list freezes every parameter.

    Returns:
        A human-readable report. It is ``"No freezing!!!"`` when
        *hot_modules* is ``None``. Otherwise it is ``"Hot params are:"``
        followed by one hot parameter name per line.
    """
    if hot_modules is None:
        return "No freezing!!!"

    hot_names = []
    for param_name, weights in model.named_parameters():
        weights.requires_grad = _is_it_hot(param_name, hot_modules)
        if weights.requires_grad:
            hot_names.append(param_name)

    # One join instead of repeated ``+=`` concatenation; output is identical
    # ("Hot params are:" alone when nothing is hot).
    return "\n".join(["Hot params are:", *hot_names])
|