05_T5_custom_finetune.py
import numpy as np
from tqdm import tqdm
import wandb
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

from _config import load_config
from _utils import print_system_info, silent_logs
from _datasets import AutoLoad, generate_dataloader
from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper
from _trainer import train_loop, valid_loop

configs = load_config('./config.yaml')

RANDOM_SEED = configs.shared.random_seed
WANDB_PROJECT_NAME = configs.shared.project_name
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
USE_TQDM = configs.shared.use_tqdm
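
# A minimal config.yaml shape this script expects, inferred from the attribute
# accesses in this file. The real schema is defined by _config.load_config, and
# every value below is an illustrative placeholder:
#
#   shared:
#     random_seed: 42
#     project_name: my-t5-peft-project
#     use_tqdm: true
#   run_configs:
#     - wandb_name: t5-base-encoder-emb
#       model_name: t5-base
#       learning_rate: 3.0e-4
#       num_epochs: 10
#       tasks: [glue-mrpc]
#       hot_modules: [sadcl]
#       peft_params:
#         kind: encoder_emb   # or: t5_encoder
#         n_tokens: 100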

def run_experiment(config):
    np.random.seed(RANDOM_SEED)

    # ______________________LOAD MODEL_____________________________
    model = T5ForConditionalGeneration.from_pretrained(config.model_name)
    tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)

    # ______________________MUTATE MODEL_____________________________
    if config.peft_params is not None:
        peft_params = config.peft_params.to_dict()
        # Random vocabulary ids that initialize the learned prompt embedding.
        selected_tokens = torch.from_numpy(
            np.random.randint(0, tokenizer.vocab_size, size=(peft_params['n_tokens'],))
        )
        peft_class = {
            't5_encoder': T5Wrapper,
            'encoder_emb': EmbeddingWrapper
        }[peft_params.pop('kind')]
        delta_module = peft_class.mutate(
            model=model,
            slected_tokens=selected_tokens,  # keyword spelled as in _mydelta's API
            **peft_params
        )
        # Restore pretrained PEFT weights, except the learned embedding, which
        # keeps its fresh random-token initialization.
        loaded_weights = torch.load('./best.pt')
        loaded_weights.pop('sadcl_learned_embedding')
        delta_module.load_peft_state_dict(loaded_weights)

    freeze_notes = auto_freeze(model, config.hot_modules)
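
    # What the block above presumably sets up: a soft-prompt / delta-tuning
    # module. T5Wrapper appears to wrap the whole T5 encoder, EmbeddingWrapper
    # just its input embedding; best.pt supplies the remaining pretrained PEFT
    # weights. auto_freeze then freezes every parameter outside
    # config.hot_modules and returns a human-readable summary that is attached
    # to the wandb run below as notes.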

    # ______________________LOAD DATA_____________________________
    data_loader = AutoLoad(tokenizer)
    dataset = data_loader.get_and_map(config.tasks[0])
    train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)

    # ______________________TRAIN_____________________________
    wandb.init(
        name=config.wandb_name,
        project=WANDB_PROJECT_NAME,
        config=config.to_dict(),
        notes=freeze_notes
    )
    # Frozen parameters never receive gradients, so AdamW effectively only
    # updates the hot modules.
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    model.to(DEVICE)

    epochs_range = range(config.num_epochs)
    if USE_TQDM:
        epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                use_tqdm=USE_TQDM
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader=valid_loader,
                use_tqdm=USE_TQDM
            )
        )
        wandb.log(epoch_results)

    wandb.finish()
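
# Note that run_experiment opens and closes its own wandb run, so each entry in
# configs.run_configs below is logged as a separate run.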

if __name__ == '__main__':
    print_system_info()
    silent_logs()

    run_configs = configs.run_configs
    if USE_TQDM:
        run_configs = tqdm(run_configs, position=0, desc="Experiment")

    for run_config in run_configs:
        run_experiment(run_config)
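
# Usage (assuming config.yaml and best.pt sit next to this script):
#   python 05_T5_custom_finetune.py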