
auto_train.py

from pathlib import Path

import torch
import wandb
from accelerate import Accelerator
from tqdm import tqdm

from .auto_save import AutoSave
from .run_loops import train_loop, valid_loop
from .best_finder import BestFinder
from _datasets import generate_dataloader, generate_output_preprocess
from _mydelta import auto_freeze

def _extract_name(model_name, candidates):
    # Return the first candidate substring that appears in the model name, or 'none'.
    for candid in candidates:
        if candid in model_name:
            return candid
    return 'none'

def get_project_name(config, model_name, dataset_name):
    # Compose the wandb project name from an optional prefix, the model family
    # ('t5'/'bert'/'bart'), the model size, and the dataset name.
    name_stack = []
    model_name = model_name.lower()
    if config.project_name_prefix is not None:
        name_stack.append(config.project_name_prefix)
    name_stack.append(_extract_name(model_name, ['t5', 'bert', 'bart']))
    name_stack.append(_extract_name(model_name, ['small', 'base', 'large']))
    name_stack.append(dataset_name)
    return '_'.join(name_stack)
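

# Illustration (hypothetical inputs, not taken from a real run): with
# project_name_prefix unset, model_name='google/t5-base' and dataset_name='mrpc',
# get_project_name returns 't5_base_mrpc'; models outside the t5/bert/bart
# families fall back to 'none' for that slot.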


def get_experiment_name(config):
    # Full fine-tuning runs are simply named 'full'; PEFT runs encode the
    # relevant hyperparameters in the name.
    if config.peft_params is None:
        return 'full'
    name_stack = [config.peft_params.n_tokens, config.peft_params.kind]
    if config.peft_params.kind == 'combine':
        name_stack.append(config.peft_params.n_comb_tokens)
        if len(config.peft_params.get('pretrained_paths', [])) > 0:
            name_stack.append(config.peft_params.use_pretrained_mode)
            if config.peft_params.use_pretrained_mode == 'softmax':
                name_stack.append(config.peft_params.tempreture)
    elif config.peft_params.kind == 'residual':
        name_stack.append(config.peft_params.mlp_size)
    if config.experiment_name_suffix is not None:
        name_stack.append(config.experiment_name_suffix)
    return '_'.join([str(x) for x in name_stack])
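

# Illustration (hypothetical config, not taken from a real run): with
# peft_params carrying n_tokens=10, kind='residual', mlp_size=128 and no
# experiment_name_suffix, get_experiment_name returns '10_residual_128';
# with peft_params=None it returns 'full'.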


def auto_train(model, tokenizer, dataset, config, device):
    # Track the best checkpoint according to the configured validation metric.
    best_finder = BestFinder(config.best_finder.higher_better)

    project_name = get_project_name(config=config, model_name=model.name_or_path, dataset_name=dataset['name'])
    experiment_name = get_experiment_name(config)
    save_path = Path(config.base_save_path) / project_name / experiment_name

    saver = AutoSave(
        model=model,
        path=save_path
    )

    train_loader, valid_loader_dict = generate_dataloader(
        tokenizer,
        dataset['train'],
        dataset['valid_dict'],
        train_bs=config.train_batch_size,
        valid_bs=config.valid_batch_size
    )
    output_preprocess = generate_output_preprocess(tokenizer)

    # Freeze everything except the configured "hot" modules; the returned notes
    # are attached to the wandb run for reference.
    freeze_notes = auto_freeze(model, config.hot_modules)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    accelerator = Accelerator(log_with="wandb")  # gradient_accumulation_steps=8
    model, optimizer, train_loader = accelerator.prepare(
        model, optimizer, train_loader
    )
    accelerator.init_trackers(
        project_name=project_name,
        config=config.to_dict(),
        init_kwargs={"wandb": {"name": experiment_name, "notes": freeze_notes}}
    )

    # Save the initial weights so every run has a 'first' checkpoint to compare against.
    saver.save('first')

    epochs_range = range(config.num_epochs)
    if config.use_tqdm:
        epochs_range = tqdm(epochs_range, position=2, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                accelerator=accelerator,
                use_tqdm=config.use_tqdm
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader_dict=valid_loader_dict,
                use_tqdm=config.use_tqdm,
                compute_metrics=dataset['compute_metrics'],
                output_preprocess=output_preprocess
            )
        )
        accelerator.log(epoch_results)

        # Checkpoint when the tracked metric improves, and keep the most recent epoch as 'last'.
        if best_finder.is_better(epoch_results[config.best_finder.metric]):
            saver.save('best')
        saver.save('last')

    accelerator.end_training()
    return str(save_path)
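
A minimal usage sketch, for orientation only: the config loader, the dataset builder, the config file path, and the model name below are hypothetical placeholders standing in for this project's real config and _datasets plumbing, which is not shown on this page; only auto_train's own signature and the dataset-dict keys it reads ('name', 'train', 'valid_dict', 'compute_metrics') are taken from the code above.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from auto_train import auto_train  # adjust to the actual package path

    # Hypothetical stand-ins for the project's own config / dataset helpers.
    config = load_project_config('configs/example.yaml')   # placeholder loader and path
    tokenizer = AutoTokenizer.from_pretrained('t5-base')
    model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')

    # auto_train expects a dict with 'name', 'train', 'valid_dict' and
    # 'compute_metrics' keys (consumed by generate_dataloader / valid_loop above).
    dataset = build_task_dataset('mrpc', tokenizer)         # placeholder builder

    save_path = auto_train(model, tokenizer, dataset, config, device='cuda')
    print('checkpoints written to', save_path)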