from typing import List def _is_it_hot(param_name: str, hot_modules: List[str]): for module_name in hot_modules: if module_name in param_name: # str contains return True return False def auto_freeze(model, hot_modules: List[str]) -> str: if hot_modules is None: return "No freezing!!!" return_value = "Hot params are:" for param_name, weights in model.named_parameters(): weights.requires_grad = _is_it_hot(param_name, hot_modules) if weights.requires_grad: return_value += '\n' + param_name return return_value