from pathlib import Path

import torch
import wandb
from accelerate import Accelerator
from tqdm import tqdm

from .auto_save import AutoSave
from .run_loops import train_loop, valid_loop
from .best_finder import BestFinder
from _datasets import generate_dataloader, generate_output_preprocess
from _mydelta import auto_freeze


def _extract_name(model_name, candidates):
    # Return the first candidate substring found in the model name, else 'none'.
    for candid in candidates:
        if candid in model_name:
            return candid
    return 'none'


def get_project_name(config, model_name, dataset_name):
    """Build a wandb project name from the model family, model size, and dataset."""
    name_stack = []
    model_name = model_name.lower()
    if config.project_name_prefix is not None:
        name_stack.append(config.project_name_prefix)
    name_stack.append(_extract_name(model_name, ['t5', 'bert', 'bart']))
    name_stack.append(_extract_name(model_name, ['small', 'base', 'large']))
    name_stack.append(dataset_name)
    return '_'.join(name_stack)


def get_experiment_name(config):
    """Build a run name from the PEFT configuration; 'full' means full fine-tuning."""
    if config.peft_params is None:
        return 'full'
    name_stack = [config.peft_params.n_tokens, config.peft_params.kind]
    if config.peft_params.kind == 'combine':
        name_stack.append(config.peft_params.n_comb_tokens)
        if len(config.peft_params.get('pretrained_paths', [])) > 0:
            name_stack.append(config.peft_params.use_pretrained_mode)
            if config.peft_params.use_pretrained_mode == 'softmax':
                name_stack.append(config.peft_params.tempreture)
    elif config.peft_params.kind == 'residual':
        name_stack.append(config.peft_params.mlp_size)
    if config.experiment_name_suffix is not None:
        name_stack.append(config.experiment_name_suffix)
    return '_'.join([str(x) for x in name_stack])


def auto_train(model, tokenizer, dataset, config, device):
    """Run the full train/validate loop, log to wandb via Accelerate, and save checkpoints."""
    best_finder = BestFinder(config.best_finder.higher_better)

    project_name = get_project_name(
        config=config,
        model_name=model.name_or_path,
        dataset_name=dataset['name']
    )
    experiment_name = get_experiment_name(config)
    save_path = Path(config.base_save_path) / project_name / experiment_name
    saver = AutoSave(model=model, path=save_path)

    train_loader, valid_loader_dict = generate_dataloader(
        tokenizer,
        dataset['train'],
        dataset['valid_dict'],
        train_bs=config.train_batch_size,
        valid_bs=config.valid_batch_size
    )
    output_preprocess = generate_output_preprocess(tokenizer)

    # Freeze everything except the configured hot modules; the returned notes
    # are attached to the wandb run for bookkeeping.
    freeze_notes = auto_freeze(model, config.hot_modules)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    accelerator = Accelerator(log_with="wandb")  # gradient_accumulation_steps=8
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
    accelerator.init_trackers(
        project_name=project_name,
        config=config.to_dict(),
        init_kwargs={"wandb": {"name": experiment_name, "notes": freeze_notes}}
    )

    saver.save('first')

    epochs_range = range(config.num_epochs)
    if config.use_tqdm:
        epochs_range = tqdm(epochs_range, position=2, desc="EPOCHS", leave=False)

    for epoch in epochs_range:
        epoch_results = {}
        epoch_results.update(
            train_loop(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                accelerator=accelerator,
                use_tqdm=config.use_tqdm
            )
        )
        epoch_results.update(
            valid_loop(
                model=model,
                loader_dict=valid_loader_dict,
                use_tqdm=config.use_tqdm,
                compute_metrics=dataset['compute_metrics'],
                output_preprocess=output_preprocess
            )
        )
        accelerator.log(epoch_results)

        # Checkpoint whenever the monitored metric improves.
        if best_finder.is_better(epoch_results[config.best_finder.metric]):
            saver.save('best')

    saver.save('last')
    accelerator.end_training()
    return str(save_path)
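

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of how `auto_train` might be driven, assuming the config is
# an attribute-style mapping (e.g. ml_collections.ConfigDict) exposing the
# fields read above (`.to_dict()`, `.get()`, `best_finder.*`, `peft_params.*`),
# and that the dataset dict carries 'name', 'train', 'valid_dict', and
# 'compute_metrics' entries as consumed by this module. All field values and
# the `build_dataset` helper below are hypothetical placeholders, not defaults
# shipped with this repository.
#
#     from ml_collections import ConfigDict
#     from transformers import AutoTokenizer, T5ForConditionalGeneration
#
#     config = ConfigDict({
#         'project_name_prefix': 'demo',
#         'experiment_name_suffix': None,
#         'base_save_path': './checkpoints',
#         'train_batch_size': 32,
#         'valid_batch_size': 32,
#         'learning_rate': 3e-4,
#         'weight_decay': 0.01,
#         'num_epochs': 5,
#         'use_tqdm': True,
#         'hot_modules': ['prompt'],          # hypothetical trainable-module filter
#         'peft_params': None,                # None -> full fine-tuning ('full')
#         'best_finder': {'higher_better': True, 'metric': 'valid_mean'},
#     })
#
#     tokenizer = AutoTokenizer.from_pretrained('t5-base')
#     model = T5ForConditionalGeneration.from_pretrained('t5-base')
#     dataset = build_dataset(...)            # hypothetical helper returning the
#                                             # dict shape described above
#     save_path = auto_train(model, tokenizer, dataset, config, device='cuda')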