You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

auto_freeze.py 614B

123456789101112131415161718192021
from typing import List, Optional
  2. def _is_it_hot(param_name: str, hot_modules: List[str]):
  3. for module_name in hot_modules:
  4. if module_name in param_name: # str contains
  5. return True
  6. return False
  7. def auto_freeze(model, hot_modules: List[str]) -> str:
  8. if hot_modules is None:
  9. return "No freezing!!!"
  10. return_value = "Hot params are:"
  11. for param_name, weights in model.named_parameters():
  12. weights.requires_grad = _is_it_hot(param_name, hot_modules)
  13. if weights.requires_grad:
  14. return_value += '\n' + param_name
  15. return return_value