You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_single.py 1.5KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import numpy as np
  2. import torch
  3. import os
  4. import sys
  5. sys.path.insert(1, os.path.join(sys.path[0], '..'))
  6. from _utils import silent_logs, sp_decode
  7. from _datasets import AutoLoad
  8. from _trainer import auto_train
  9. from _mydelta import auto_mutate
  10. from _models import auto_model
  11. from _config import Config
  12. DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  13. def run_experminent(config, task_name):
  14. silent_logs()
  15. np.random.seed(config.random_seed)
  16. torch.manual_seed(config.random_seed)
  17. # ______________________LOAD MODEL_____________________________
  18. model, tokenizer = auto_model(config.model_name, AutoLoad.get_task_output(task_name))
  19. # ______________________MUTATE MODEL_____________________________
  20. n_prefix_token = 0
  21. if config.peft_params is not None:
  22. n_prefix_token = config.peft_params.n_tokens
  23. delta_module = auto_mutate(
  24. model=model,
  25. tokenizer=tokenizer,
  26. peft_params=config.peft_params.to_dict(),
  27. remove_dropout=config.remove_dropout
  28. )
  29. # ______________________LOAD DATA_____________________________
  30. autoload = AutoLoad(tokenizer, n_prefix_token=n_prefix_token)
  31. # ______________________TRAIN_____________________________
  32. dataset = autoload.get_and_map(task_name)
  33. auto_train(model, tokenizer, dataset, config, device=DEVICE)
  34. if __name__ == '__main__':
  35. config_json = sp_decode(sys.argv[1])
  36. config = Config(config_json, '')
  37. task_name = sp_decode(sys.argv[2])
  38. run_experminent(config, task_name)