import numpy as np
from tqdm import tqdm
import wandb
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

from _config import load_config
from _utils import print_system_info, silent_logs
from _datasets import AutoLoad, generate_dataloader
from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper
from _trainer import train_loop, valid_loop, BestFinder

configs = load_config('./config.yaml')

RANDOM_SEED = configs.shared.random_seed
WANDB_PROJECT_NAME = configs.shared.project_name
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
USE_TQDM = configs.shared.use_tqdm


def run_experiment(config):
    np.random.seed(RANDOM_SEED)

    # ______________________LOAD MODEL_____________________________
    model = T5ForConditionalGeneration.from_pretrained(config.model_name)
    tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)

    # ______________________MUTATE MODEL_____________________________
    if config.peft_params is not None:
        peft_params = config.peft_params.to_dict()
        # Initialize the soft prompt from `n_tokens` randomly drawn vocabulary ids.
        selected_tokens = torch.from_numpy(
            np.random.randint(0, tokenizer.vocab_size, size=(peft_params['n_tokens'],))
        )
        # `kind` selects which wrapper mutates the model; the remaining
        # entries are forwarded to its `mutate` classmethod.
        peft_class = {
            't5_encoder': T5Wrapper,
            'encoder_emb': EmbeddingWrapper
        }[peft_params.pop('kind')]
        delta_module = peft_class.mutate(
            model=model,
            slected_tokens=selected_tokens,  # keyword spelling follows the _mydelta signature
            **peft_params
        )
    elif config.best_finder.save:
        # Saving the best checkpoint is only supported for PEFT runs.
        raise NotImplementedError()

    # Freeze everything except the modules listed in `hot_modules`;
    # the returned notes are attached to the wandb run for bookkeeping.
    freeze_notes = auto_freeze(model, config.hot_modules)

    # ______________________LOAD DATA_____________________________
    data_loader = AutoLoad(tokenizer)
    dataset = data_loader.get_and_map(config.tasks[0])  # only the first task is used
    train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)

    # ______________________TRAIN_____________________________
    wandb.init(
        name=config.wandb_name,
        project=WANDB_PROJECT_NAME,
        config=config.to_dict(),
        notes=freeze_notes
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    best_finder = BestFinder(config.best_finder.higher_better)
    model.to(DEVICE)

    epochs_range = range(config.num_epochs)
    if USE_TQDM:
        epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                use_tqdm=USE_TQDM
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader=valid_loader,
                use_tqdm=USE_TQDM
            )
        )

        # Persist only the PEFT parameters whenever the monitored metric improves.
        if config.best_finder.save:
            if best_finder.is_better(epoch_results[config.best_finder.metric]):
                torch.save(delta_module.peft_state_dict(), './best.pt')

        wandb.log(epoch_results)

    wandb.finish()


if __name__ == '__main__':
    print_system_info()
    silent_logs()

    run_configs = configs.run_configs
    if USE_TQDM:
        run_configs = tqdm(run_configs, position=0, desc="Experiment")

    for run_config in run_configs:
        run_experiment(run_config)
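
# ---------------------------------------------------------------------------
# A minimal sketch of the `config.yaml` structure this script reads, inferred
# only from the attributes accessed above (`configs.shared.*`,
# `configs.run_configs`, and the per-run fields). Every value below is a
# placeholder assumption, not the project's actual settings.
#
#   shared:
#     random_seed: 42
#     project_name: my-wandb-project          # hypothetical wandb project
#     use_tqdm: true
#   run_configs:
#     - wandb_name: prompt-tuning-demo        # hypothetical run name
#       model_name: t5-base                   # any T5 checkpoint
#       learning_rate: 0.01
#       num_epochs: 10
#       tasks: [my_task]                      # only tasks[0] is used
#       hot_modules: [prompt]                 # name patterns left trainable (placeholder)
#       peft_params:
#         kind: encoder_emb                   # or: t5_encoder
#         n_tokens: 100                       # remaining keys are passed to `mutate`
#       best_finder:
#         save: true
#         higher_better: false
#         metric: valid_loss                  # placeholder metric key
# ---------------------------------------------------------------------------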