You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_cont.py 2.2KB

3 months ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from tqdm import tqdm
  2. import numpy as np
  3. import torch
  4. import os
  5. import sys
  6. sys.path.insert(1, os.path.join(sys.path[0], '..'))
  7. from _datasets import AutoLoad
  8. from _trainer import auto_train
  9. from _mydelta import auto_mutate
  10. from _models import auto_model
  11. from _config import Config, load_config
  12. from _utils import print_system_info, silent_logs
  13. DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  14. def run_experminent(config, task_name):
  15. silent_logs()
  16. np.random.seed(config.random_seed)
  17. torch.manual_seed(config.random_seed)
  18. # ______________________LOAD MODEL_____________________________
  19. model, tokenizer = auto_model(config.model_name, AutoLoad.get_task_output(task_name))
  20. # ______________________MUTATE MODEL_____________________________
  21. n_prefix_token = 0
  22. if config.peft_params is not None:
  23. n_prefix_token = config.peft_params.n_tokens
  24. delta_module = auto_mutate(
  25. model=model,
  26. tokenizer=tokenizer,
  27. peft_params=config.peft_params.to_dict(),
  28. remove_dropout=config.remove_dropout
  29. )
  30. # ______________________LOAD DATA_____________________________
  31. autoload = AutoLoad(tokenizer, n_prefix_token=n_prefix_token)
  32. # ______________________TRAIN_____________________________
  33. dataset = autoload.get_and_map(task_name)
  34. return auto_train(model, tokenizer, dataset, config, device=DEVICE)
  35. if __name__ == '__main__':
  36. print_system_info()
  37. configs = load_config(sys.argv[1])
  38. run_configs = tqdm(configs.run_configs, position=0, desc="Experiment")
  39. for run_config in run_configs:
  40. tasks = tqdm(run_config.tasks, position=1, desc="Task:", leave=False)
  41. tasks_path = []
  42. for task_name in tasks:
  43. tasks.set_description(f'Task: {task_name}')
  44. torch.cuda.empty_cache()
  45. run_config.peft_params._write_mode = True
  46. orig_paths = run_config.peft_params.get('pretrained_paths', [])
  47. run_config.peft_params.pretrained_paths = list(orig_paths) + tasks_path
  48. delattr(run_config.peft_params, '_write_mode')
  49. saved_path = run_experminent(run_config, task_name)
  50. tasks_path.append(saved_path)