- import numpy as np
- import torch
- from tqdm import tqdm
-
- from _utils import prefix_dict_keys
- from .loss_hooks import get_hooks
-
def train_loop(model, loader, optimizer, accelerator, use_tqdm=False, loss_hook_alpha=0.001, gradient_clipping=1.0):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: a HuggingFace-style model; called as ``model(**batch)`` and
            expected to expose ``.device``, ``.train()`` and an output with
            a ``.loss`` attribute.
        loader: iterable of batches supporting ``.to(device)``.
        optimizer: torch optimizer stepped once per batch.
        accelerator: ``accelerate.Accelerator`` used for ``backward`` and
            gradient clipping.
        use_tqdm: wrap the loader in a tqdm progress bar when True.
        loss_hook_alpha: weight applied to each auxiliary loss from
            ``get_hooks()`` before adding it to the main loss.
        gradient_clipping: max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        dict with a single ``'train_loss'`` key (via ``prefix_dict_keys``).
        NOTE: if ``loader`` is empty, ``np.mean([])`` yields ``nan``.
    """
    model.train()

    batch_losses = []

    if use_tqdm:
        loader = tqdm(loader, position=3, desc="Train Loop", leave=False)

    for row in loader:
        optimizer.zero_grad()

        out = model(**row.to(model.device))
        loss = out.loss

        # Auxiliary regularization terms, scaled so they stay secondary
        # to the primary task loss.
        for loss_hook in get_hooks():
            loss += loss_hook_alpha * loss_hook()

        # Capture the scalar before backward; accelerator may scale the loss.
        batch_loss_value = loss.item()
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            # BUG FIX: the threshold was hard-coded to 1.0, silently
            # ignoring the gradient_clipping parameter.
            accelerator.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()

        batch_losses.append(batch_loss_value)

    loss_value = np.mean(batch_losses)
    return prefix_dict_keys('train', {
        'loss': loss_value
    })
-
- def _predict(model, row):
- if model._is_seq2seq:
- return model.generate(
- **row,
- max_length=50
- )
- else:
- return model(
- **row
- ).logits.argmax(-1)
-
def valid_loop(model, loader_dict, compute_metrics, output_preprocess, use_tqdm=False):
    """Evaluate ``model`` on every loader in ``loader_dict``.

    For each named loader, gathers labels and predictions, preprocesses
    both lists, computes metrics, and prefixes the metric keys with the
    loader's name. Also records ``'valid_mean'``, the average of each
    loader's ``metrics['mean']``.

    Args:
        model: model consumed by ``_predict``; must expose ``.device``.
        loader_dict: mapping of split name -> batch iterable; batches must
            provide ``.to(device)`` and a ``.labels`` tensor.
        compute_metrics: callable taking ``y_true=``/``y_pred=`` lists and
            returning a dict that includes a ``'mean'`` entry.
        output_preprocess: callable applied to both label and prediction
            lists before metric computation.
        use_tqdm: wrap each loader in a tqdm progress bar when True.

    Returns:
        dict of prefixed per-split metrics plus ``'valid_mean'``.
    """
    model.eval()

    results = {}
    per_split_means = []

    for split_name, split_loader in loader_dict.items():
        iterator = tqdm(split_loader, position=3, desc="Valid Loop", leave=False) if use_tqdm else split_loader

        y_true = []
        y_pred = []
        with torch.no_grad():
            for batch in iterator:
                # NOTE(review): return value of .to() is discarded — this
                # relies on the batch moving in place (HF BatchEncoding
                # semantics); confirm against the loader's batch type.
                batch.to(model.device)
                prediction = _predict(model, batch)

                y_true.extend(batch.labels.detach().cpu().tolist())
                y_pred.extend(prediction.detach().cpu().tolist())

        metrics = compute_metrics(
            y_true=output_preprocess(y_true),
            y_pred=output_preprocess(y_pred),
        )
        per_split_means.append(metrics['mean'])
        results.update(prefix_dict_keys(split_name, metrics))

    results['valid_mean'] = np.mean(per_split_means)

    return results