
train.py
import os
import sys

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import T5TokenizerFast, T5ForConditionalGeneration

# Make the repository root importable so the shared helper modules resolve.
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from _config import load_config
from _utils import print_system_info, silent_logs
from _datasets import AutoLoad, generate_dataloader
from _mydelta import auto_freeze, LowdimEmbeddingWrapper
from _trainer import train_loop, valid_loop, BestFinder
configs = load_config('./config.yaml')

RANDOM_SEED = configs.shared.random_seed
WANDB_PROJECT_NAME = configs.shared.project_name
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
USE_TQDM = configs.shared.use_tqdm
def run_experiment(config):
    # Seed numpy and torch so data shuffling, weight init, and dropout are reproducible.
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # ______________________LOAD MODEL_____________________________
    model = T5ForConditionalGeneration.from_pretrained(config.model_name)
    tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)

    # ______________________MUTATE MODEL_____________________________
    delta_module = None  # stays None when no PEFT wrapper is configured
    if config.peft_params is not None:
        peft_params = config.peft_params.to_dict()
        peft_class = {
            'lowdim_prompt': LowdimEmbeddingWrapper
        }[peft_params.pop('kind')]
        delta_module = peft_class.mutate(
            model=model,
            **peft_params
        )
    elif config.best_finder.save:
        # Saving the best checkpoint requires a PEFT module to snapshot.
        raise NotImplementedError()
    freeze_notes = auto_freeze(model, config.hot_modules)

    # ______________________LOAD DATA_____________________________
    data_loader = AutoLoad(tokenizer)
    dataset = data_loader.get_and_map(config.tasks[0])
    train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)

    # ______________________TRAIN_____________________________
    print(delta_module)
    wandb.init(
        name=config.wandb_name,
        project=WANDB_PROJECT_NAME,
        config=config.to_dict(),
        notes=freeze_notes
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    best_finder = BestFinder(config.best_finder.higher_better)
    model.to(DEVICE)

    epochs_range = range(config.num_epochs)
    if USE_TQDM:
        epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                use_tqdm=USE_TQDM
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader=valid_loader,
                use_tqdm=USE_TQDM
            )
        )
        if config.best_finder.save:
            if best_finder.is_better(epoch_results[config.best_finder.metric]):
                # Persist only the PEFT weights, not the full T5 checkpoint.
                torch.save(delta_module.peft_state_dict(), './best.pt')
        wandb.log(epoch_results)
    wandb.finish()
if __name__ == '__main__':
    print_system_info()
    silent_logs()
    run_configs = configs.run_configs
    if USE_TQDM:
        run_configs = tqdm(run_configs, position=0, desc="Experiment")
    for run_config in run_configs:
        run_experiment(run_config)
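
For reference, a hypothetical config.yaml sketch follows. The key names are taken directly from the attribute accesses in train.py; all values are illustrative placeholders, and any extra keys under peft_params depend on the signature of LowdimEmbeddingWrapper.mutate, which is not shown here.

shared:
  random_seed: 42
  project_name: my-project          # placeholder wandb project name
  use_tqdm: true

run_configs:
  - wandb_name: t5-small-lowdim     # placeholder run name
    model_name: t5-small
    learning_rate: 0.0003
    num_epochs: 10
    tasks:
      - glue_mrpc                   # placeholder; only tasks[0] is read
    hot_modules:                    # placeholder; passed to auto_freeze
      - delta
    peft_params:                    # set to null to skip the PEFT wrapper
      kind: lowdim_prompt           # selects LowdimEmbeddingWrapper
      # remaining keys are forwarded to LowdimEmbeddingWrapper.mutate
    best_finder:
      save: true
      metric: valid_loss            # placeholder; must match a logged key
      higher_better: false

BestFinder itself is imported from _trainer, whose implementation is not shown. A minimal sketch of what such a helper might look like, assuming is_better only tracks the best metric seen so far and reports whether the latest value improves on it:

class BestFinder:
    """Hypothetical stand-in for _trainer.BestFinder (real implementation not shown)."""

    def __init__(self, higher_better):
        self.higher_better = higher_better
        self.best = None  # no metric observed yet

    def is_better(self, value):
        # The first observed value is trivially the best so far.
        if self.best is None:
            self.best = value
            return True
        improved = (value > self.best) if self.higher_better else (value < self.best)
        if improved:
            self.best = value
        return improved

With a config like the sketch above next to the script, it can be launched as python train.py; each entry under run_configs becomes one wandb run, opened by wandb.init and closed by wandb.finish.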