123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- from typing import Optional
-
- import numpy as np
- from tqdm import tqdm
-
- import wandb
- import torch
- import torch.nn as nn
- from transformers import T5TokenizerFast, T5ForConditionalGeneration
-
- from _config import load_config
- from _utils import print_system_info, silent_logs
- from _datasets import AutoLoad, generate_dataloader
- from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper
- from _trainer import train_loop, valid_loop, BestFinder
-
# Experiment configuration is loaded once at import time; every run in this
# file shares the values below.
configs = load_config('./config.yaml')

RANDOM_SEED = configs.shared.random_seed          # seed for reproducible token selection / training
WANDB_PROJECT_NAME = configs.shared.project_name  # wandb project all runs are logged under
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # prefer the first GPU
USE_TQDM = configs.shared.use_tqdm                # toggle progress bars for epochs/experiments
-
def run_experminent(config):
    """Run a single fine-tuning experiment described by ``config``.

    Loads a T5 model and tokenizer, optionally mutates the model with a
    PEFT wrapper (``config.peft_params``), freezes modules not listed in
    ``config.hot_modules``, then trains and validates for
    ``config.num_epochs`` epochs while logging metrics to wandb.  When
    ``config.best_finder.save`` is set, the PEFT state dict is written to
    ``./best.pt`` every time the tracked metric improves.

    Raises:
        NotImplementedError: if saving is requested without PEFT params
            (checkpointing is only implemented for the PEFT path).
    """
    # Seed BOTH numpy and torch: numpy drives the random token selection
    # below, torch drives dropout/init/shuffling.  The original code seeded
    # numpy only, so model-side randomness was not reproducible.
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # ______________________LOAD MODEL_____________________________

    model = T5ForConditionalGeneration.from_pretrained(config.model_name)
    tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)

    # ______________________MUTATE MODEL_____________________________

    delta_module = None  # stays None when no PEFT mutation is requested
    if config.peft_params is not None:
        peft_params = config.peft_params.to_dict()
        # Random prompt-token ids drawn uniformly from the tokenizer vocab.
        slected_tokens = torch.from_numpy(
            np.random.randint(0, tokenizer.vocab_size, size=(peft_params['n_tokens'],))
        )

        # 'kind' selects which wrapper mutates the model; the remaining
        # entries are forwarded to the wrapper as keyword arguments.
        peft_class = {
            't5_encoder': T5Wrapper,
            'encoder_emb': EmbeddingWrapper
        }[peft_params.pop('kind')]

        # NOTE(review): `slected_tokens` (sic) is the keyword name the
        # project wrappers expect — do not "fix" the spelling here alone.
        delta_module = peft_class.mutate(
            model=model,
            slected_tokens=slected_tokens,
            **peft_params
        )
    elif config.best_finder.save:
        # Saving checkpoints is only implemented for the PEFT path.
        raise NotImplementedError()

    freeze_notes = auto_freeze(model, config.hot_modules)

    # ______________________LOAD DATA_____________________________

    data_loader = AutoLoad(tokenizer)
    dataset = data_loader.get_and_map(config.tasks[0])
    train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)

    # ______________________TRAIN_____________________________

    wandb.init(
        name=config.wandb_name,
        project=WANDB_PROJECT_NAME,
        config=config.to_dict(),
        notes=freeze_notes
    )

    try:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

        best_finder = BestFinder(config.best_finder.higher_better)

        model.to(DEVICE)

        epochs_range = range(config.num_epochs)
        if USE_TQDM:
            epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False)

        for epoch in epochs_range:
            epoch_results = {}

            epoch_results.update(
                train_loop(
                    model=model,
                    loader=train_loader,
                    optimizer=optimizer,
                    use_tqdm=USE_TQDM
                )
            )

            epoch_results.update(
                valid_loop(
                    model=model,
                    loader=valid_loader,
                    use_tqdm=USE_TQDM
                )
            )

            # Persist the PEFT weights whenever the tracked metric improves.
            # `delta_module` is guaranteed non-None here: save without PEFT
            # params raised NotImplementedError above.
            if config.best_finder.save:
                if best_finder.is_better(epoch_results[config.best_finder.metric]):
                    torch.save(delta_module.peft_state_dict(), './best.pt')

            wandb.log(epoch_results)
    finally:
        # Close the wandb run even if training raises, so the next
        # experiment in this process starts a fresh run instead of
        # logging into a half-finished one.
        wandb.finish()
-
-
if __name__ == '__main__':
    # Report the hardware/software environment, then mute noisy library logs.
    print_system_info()
    silent_logs()

    experiments = configs.run_configs
    if USE_TQDM:
        # Outer progress bar: one tick per experiment (epochs use position=1).
        experiments = tqdm(experiments, position=0, desc="Experiment")

    for experiment_config in experiments:
        run_experminent(experiment_config)
|