from typing import Optional

import numpy as np
from tqdm import tqdm
import wandb
import torch
import torch.nn as nn
from transformers import T5TokenizerFast, T5ForConditionalGeneration

import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from _config import load_config
from _utils import print_system_info, silent_logs
from _datasets import AutoLoad, generate_dataloader
from _mydelta import auto_freeze, LowdimEmbeddingWrapper
from _trainer import train_loop, valid_loop, BestFinder

configs = load_config('./config.yaml')

RANDOM_SEED = configs.shared.random_seed
WANDB_PROJECT_NAME = configs.shared.project_name
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
USE_TQDM = configs.shared.use_tqdm


def run_experiment(config):
    """Run one training experiment described by ``config``.

    Loads a pretrained T5 model/tokenizer, optionally wraps it with a PEFT
    delta module, freezes everything except ``config.hot_modules``, then
    trains/validates for ``config.num_epochs`` epochs while logging each
    epoch's metrics to wandb. When ``config.best_finder.save`` is set, the
    delta module's state dict is checkpointed to ``./best.pt`` whenever the
    tracked metric improves.

    Parameters
    ----------
    config : run configuration object (from ``load_config``) providing
        ``model_name``, ``peft_params``, ``hot_modules``, ``tasks``,
        ``wandb_name``, ``learning_rate``, ``num_epochs`` and
        ``best_finder.{save,higher_better,metric}``.

    Raises
    ------
    NotImplementedError
        If checkpoint saving is requested without a PEFT module (there is
        nothing lightweight to save).
    """
    # Seed both numpy AND torch so model init / shuffling are reproducible.
    # (The original seeded only numpy, leaving torch nondeterministic.)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # ______________________LOAD MODEL_____________________________
    model = T5ForConditionalGeneration.from_pretrained(config.model_name)
    tokenizer = T5TokenizerFast.from_pretrained(
        config.model_name, model_max_length=2048
    )

    # ______________________MUTATE MODEL_____________________________
    # BUG FIX: delta_module was previously unbound when peft_params is None,
    # causing a NameError at the unconditional print() below.
    delta_module = None
    if config.peft_params is not None:
        peft_params = config.peft_params.to_dict()
        # 'kind' selects the wrapper class; remaining keys are its kwargs.
        peft_class = {
            'lowdim_prompt': LowdimEmbeddingWrapper
        }[peft_params.pop('kind')]
        delta_module = peft_class.mutate(
            model=model,
            **peft_params
        )
    elif config.best_finder.save:
        # Checkpointing relies on delta_module.peft_state_dict(); saving a
        # full (non-PEFT) model is not supported here.
        raise NotImplementedError()

    freeze_notes = auto_freeze(model, config.hot_modules)

    # ______________________LOAD DATA_____________________________
    data_loader = AutoLoad(tokenizer)
    dataset = data_loader.get_and_map(config.tasks[0])
    train_loader, valid_loader = generate_dataloader(
        tokenizer, dataset['train'], dataset['valid'], config
    )

    # ______________________TRAIN_____________________________
    if delta_module is not None:
        print(delta_module)

    wandb.init(
        name=config.wandb_name,
        project=WANDB_PROJECT_NAME,
        config=config.to_dict(),
        notes=freeze_notes
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    best_finder = BestFinder(config.best_finder.higher_better)

    model.to(DEVICE)

    epochs_range = range(config.num_epochs)
    if USE_TQDM:
        epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                use_tqdm=USE_TQDM
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader=valid_loader,
                use_tqdm=USE_TQDM
            )
        )

        # Persist only the lightweight PEFT weights when the tracked metric
        # improves (delta_module is guaranteed non-None here: the
        # save-without-PEFT case raised NotImplementedError above).
        if config.best_finder.save:
            if best_finder.is_better(epoch_results[config.best_finder.metric]):
                torch.save(delta_module.peft_state_dict(), './best.pt')

        wandb.log(epoch_results)

    wandb.finish()


# Backward-compatible alias for the original (misspelled) entry-point name.
run_experminent = run_experiment


if __name__ == '__main__':
    print_system_info()
    silent_logs()
    run_configs = configs.run_configs
    if USE_TQDM:
        run_configs = tqdm(run_configs, position=0, desc="Experiment")
    for run_config in run_configs:
        run_experiment(run_config)