You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

run_loops.py 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import numpy as np
  2. import torch
  3. from tqdm import tqdm
  4. from _utils import prefix_dict_keys
  5. from .loss_hooks import get_hooks
  def train_loop(model, loader, optimizer, accelerator, use_tqdm=False, loss_hook_alpha=0.001, gradient_clipping=1.0):
      """Run one training pass over ``loader`` and report the mean batch loss.

      For every batch the model loss is augmented with each hook returned by
      ``get_hooks()`` (scaled by ``loss_hook_alpha``), back-propagated through
      ``accelerator.backward`` and followed by one optimizer step.

      Args:
          model: Model to train. It is called as ``model(**batch)`` and the
              result must expose ``.loss``; ``model.device`` is used as the
              target device for each batch.
          loader: Iterable of batches. Each batch must support
              ``.to(device)`` (returning the batch) and ``**`` unpacking.
          optimizer: Optimizer; zeroed and stepped once per batch.
          accelerator: Object providing ``backward``, ``sync_gradients`` and
              ``clip_grad_norm_`` (presumably a Hugging Face ``Accelerator``
              -- confirm against the caller).
          use_tqdm: When true, wrap ``loader`` in a ``tqdm`` progress bar.
          loss_hook_alpha: Weight applied to every hook's loss term.
          gradient_clipping: Maximum gradient norm passed to
              ``accelerator.clip_grad_norm_``. ``None`` disables clipping.

      Returns:
          Whatever ``prefix_dict_keys('train', {'loss': mean_loss})`` returns,
          where ``mean_loss`` is ``np.mean`` of the per-batch loss values.
      """
      model.train()
      if use_tqdm:
          loader = tqdm(loader, position=3, desc="Train Loop", leave=False)

      batch_losses = []
      for row in loader:
          optimizer.zero_grad()
          out = model(**row.to(model.device))
          loss = out.loss
          for loss_hook in get_hooks():
              loss += loss_hook_alpha * loss_hook()
          # Read the Python scalar before the backward pass is run.
          batch_loss_value = loss.item()
          accelerator.backward(loss)
          # Honour the caller-supplied threshold instead of a hard-coded 1.0;
          # clipping only happens on steps where gradients are synchronised.
          if gradient_clipping is not None and accelerator.sync_gradients:
              accelerator.clip_grad_norm_(model.parameters(), gradient_clipping)
          optimizer.step()
          batch_losses.append(batch_loss_value)

      # NOTE: with an empty ``loader`` this is ``np.mean([])``, i.e. NaN.
      loss_value = np.mean(batch_losses)
      return prefix_dict_keys('train', {'loss': loss_value})
  def _predict(model, row):
      """Return the model's predictions for a single batch.

      When ``model._is_seq2seq`` is truthy the batch is passed to
      ``model.generate`` with ``max_length=50``. Otherwise the model is called
      directly and the argmax over the last axis of ``.logits`` is returned.

      Args:
          model: Object exposing ``_is_seq2seq`` and either ``generate`` or a
              call interface whose result has ``.logits``.
          row: Mapping of keyword arguments forwarded to the model.
      """
      if not model._is_seq2seq:
          # Non-generative path: pick the highest-scoring entry per position.
          logits = model(**row).logits
          return logits.argmax(-1)
      return model.generate(**row, max_length=50)
  def valid_loop(model, loader_dict, compute_metrics, output_preprocess, use_tqdm=False):
      """Evaluate ``model`` on every loader in ``loader_dict``.

      For each ``(key, loader)`` pair, labels and predictions are collected
      across all batches under ``torch.no_grad()``, passed through
      ``output_preprocess`` and scored with ``compute_metrics``.

      Args:
          model: Model to evaluate; switched to eval mode and used via
              ``_predict``. ``model.device`` is the target device for batches.
          loader_dict: Mapping from a name to an iterable of batches. Each
              batch must support ``.to(device)`` (returning the batch) and
              expose ``.labels``.
          compute_metrics: Callable invoked as
              ``compute_metrics(y_true=..., y_pred=...)``; its result must
              contain a ``'mean'`` entry.
          output_preprocess: Callable applied separately to the collected
              labels and to the collected predictions before scoring.
          use_tqdm: When true, wrap each loader in a ``tqdm`` progress bar.

      Returns:
          A dict holding the metrics of every loader, with keys as produced by
          ``prefix_dict_keys(key, metrics)``, plus ``'valid_mean'``: the
          ``np.mean`` of the per-loader ``'mean'`` metrics.
      """
      model.eval()
      return_value = {}
      all_means = []
      for key, loader in loader_dict.items():
          all_true = []
          all_pred = []
          if use_tqdm:
              loader = tqdm(loader, position=3, desc="Valid Loop", leave=False)
          with torch.no_grad():
              for row in loader:
                  # Keep the value returned by ``.to`` (as train_loop does) so
                  # this also works for batch types whose ``.to`` returns a
                  # new object rather than moving the data in place.
                  row = row.to(model.device)
                  pred = _predict(model, row)
                  all_true += row.labels.detach().cpu().tolist()
                  all_pred += pred.detach().cpu().tolist()
          all_true = output_preprocess(all_true)
          all_pred = output_preprocess(all_pred)
          metrics = compute_metrics(y_true=all_true, y_pred=all_pred)
          all_means.append(metrics['mean'])
          return_value.update(prefix_dict_keys(key, metrics))
      # NOTE: with an empty ``loader_dict`` this is ``np.mean([])``, i.e. NaN.
      return_value['valid_mean'] = np.mean(all_means)
      return return_value