You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.7KB

2 weeks ago
123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import torch
  2. def clean_hyperparameters(hyperparameters: dict):
  3. if hyperparameters["scheduler"] == 0:
  4. hyperparameters.pop("scheduler_type", None)
  5. hyperparameters.pop("scheduler_warmup_ratio", None)
  6. hyperparameters.pop("scheduler_warmup_steps", None)
  7. hyperparameters.pop("scheduler_step_size", None)
  8. hyperparameters.pop("scheduler_gamma", None)
  9. if hyperparameters["peft_mode"] != "lora":
  10. hyperparameters.pop("rank", None)
  11. hyperparameters.pop("alpha", None)
  12. hyperparameters.pop("drop_out", None)
  13. if hyperparameters["peft_mode"] != "adapter" and hyperparameters["peft_mode"] != "adapterbitfit":
  14. hyperparameters.pop("reduction_factor", None)
  15. if hyperparameters["dp"] == 0:
  16. hyperparameters.pop("epsilon", None)
  17. hyperparameters.pop("delta", None)
  18. hyperparameters.pop("clipping_mode", None)
  19. hyperparameters.pop("clipping_threshold", None)
  20. if hyperparameters["use_wandb"] == 0:
  21. hyperparameters.pop("wandb_project_name", None)
  22. hyperparameters.pop("use_wandb", None)
  23. if hyperparameters["two_step_training"] == 0:
  24. hyperparameters.pop("lr_two", None)
  25. hyperparameters.pop("virtual_batch_size_two", None)
  26. hyperparameters.pop("epochs_two", None)
  27. hyperparameters.pop("weight_decay_two", None)
  28. hyperparameters.pop("f", None)
  29. hyperparameters.pop("media_path", None)
  30. hyperparameters.pop("model_cache_path", None)
  31. return hyperparameters
  32. def copy_model_weights(model1, model2):
  33. model1.eval()
  34. model2.eval()
  35. params1 = model1.parameters()
  36. params2 = model2.parameters()
  37. with torch.no_grad():
  38. for param1, param2 in zip(params1, params2):
  39. param2.data.copy_(param1.data)