import numpy as np
import torch
from tqdm import tqdm

from _utils import prefix_dict_keys
from .loss_hooks import get_hooks


def train_loop(model, loader, optimizer, accelerator, use_tqdm=False,
               loss_hook_alpha=0.001, gradient_clipping=1.0):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: model exposing ``.device`` whose forward call returns an
            object with a ``.loss`` attribute (HF-style — assumed, confirm
            with callers).
        loader: iterable of batches; each batch supports ``.to(device)``
            and unpacks as keyword arguments to the model.
        optimizer: optimizer stepping the model's parameters.
        accelerator: HF ``Accelerator`` used for ``backward`` and
            gradient clipping.
        use_tqdm: wrap the loader in a tqdm progress bar when True.
        loss_hook_alpha: weight applied to each auxiliary loss from
            ``get_hooks()`` before adding it to the model loss.
        gradient_clipping: max gradient norm passed to
            ``clip_grad_norm_``.

    Returns:
        dict with key ``'train_loss'`` mapping to the mean of the
        per-batch loss values.
    """
    model.train()
    batch_losses = []
    if use_tqdm:
        loader = tqdm(loader, position=3, desc="Train Loop", leave=False)
    for row in loader:
        optimizer.zero_grad()
        out = model(**row.to(model.device))
        loss = out.loss
        # Fold in any registered auxiliary losses, scaled by alpha.
        for loss_hook in get_hooks():
            loss += loss_hook_alpha * loss_hook()
        # Capture the scalar before backward so hooks can't mutate it.
        batch_loss_value = loss.item()
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            # BUG FIX: previously hardcoded 1.0, silently ignoring the
            # gradient_clipping parameter.
            accelerator.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        batch_losses.append(batch_loss_value)
    loss_value = np.mean(batch_losses)
    return prefix_dict_keys('train', {
        'loss': loss_value
    })


def _predict(model, row, max_length=50):
    """Produce predictions for one batch.

    Seq2seq models generate token sequences (up to ``max_length``);
    otherwise the argmax over the final logits dimension is returned.
    ``max_length`` was previously hardcoded; default preserves behavior.
    """
    if model._is_seq2seq:
        return model.generate(
            **row,
            max_length=max_length
        )
    else:
        return model(
            **row
        ).logits.argmax(-1)


def valid_loop(model, loader_dict, compute_metrics, output_preprocess,
               use_tqdm=False):
    """Evaluate the model on every loader and aggregate metrics.

    Args:
        model: model exposing ``.device`` and ``.eval()``.
        loader_dict: mapping of split name -> batch iterable; batches
            expose ``.labels`` and support ``.to(device)``.
        compute_metrics: callable ``(y_true=..., y_pred=...) -> dict``
            that must include a ``'mean'`` key.
        output_preprocess: callable applied to both the collected label
            and prediction lists before metric computation.
        use_tqdm: wrap each loader in a tqdm progress bar when True.

    Returns:
        dict of ``'<split>_<metric>'`` entries plus ``'valid_mean'``,
        the mean of each split's ``'mean'`` metric.
    """
    model.eval()
    return_value = {}
    all_means = []
    for key, loader in loader_dict.items():
        all_true = []
        all_pred = []
        if use_tqdm:
            loader = tqdm(loader, position=3, desc="Valid Loop", leave=False)
        with torch.no_grad():
            for row in loader:
                # BUG FIX: the result of .to() was discarded; rebind so
                # batch types whose .to returns a new object still land
                # on the right device.
                row = row.to(model.device)
                pred = _predict(model, row)
                all_true += row.labels.detach().cpu().tolist()
                all_pred += pred.detach().cpu().tolist()
        all_true = output_preprocess(all_true)
        all_pred = output_preprocess(all_pred)
        metrics = compute_metrics(y_true=all_true, y_pred=all_pred)
        all_means.append(metrics['mean'])
        return_value.update(prefix_dict_keys(key, metrics))
    return_value['valid_mean'] = np.mean(all_means)
    return return_value