You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

auto_save.py 1.3KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import torch
  2. import json
  3. from pathlib import Path
  4. CONFIG_FILE_NAME = 'config.json'
  5. class AutoSave:
  6. def __init__(self, model, path):
  7. self.path = Path(path)
  8. self.path.mkdir(exist_ok=True, parents=True)
  9. self.model_name = model.name_or_path
  10. if hasattr(model, '_delta_module'):
  11. self.delta_module = model._delta_module
  12. else:
  13. self.model = model
  14. self._save_config()
  15. def _save_config(self):
  16. config = {
  17. 'model_name': self.model_name,
  18. }
  19. if self.has_delta:
  20. config['peft_config'] = self.delta_module.peft_config()
  21. with open(self.path / CONFIG_FILE_NAME, 'w') as f:
  22. json.dump(config, f)
  23. @property
  24. def has_delta(self):
  25. return hasattr(self, 'delta_module')
  26. def save(self, name):
  27. if self.has_delta:
  28. state_dict = self.delta_module.peft_state_dict()
  29. else:
  30. state_dict = self.model.state_dict()
  31. torch.save(state_dict, self.path / f'{name}.pt')
  32. def load(self, name):
  33. with open(self.path / CONFIG_FILE_NAME, 'r') as f:
  34. config = json.load(f)
  35. state_dict = torch.load(self.path / f'{name}.pt')
  36. self.delta_module.load_peft(config=config['peft_config'], state_dict=state_dict)