You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Optim.py 958B

123456789101112131415161718192021222324252627282930
  1. '''A wrapper class for optimizer '''
  2. import numpy as np
  3. class ScheduledOptim(object):
  4. '''A simple wrapper class for learning rate scheduling'''
  5. def __init__(self, optimizer, d_model, n_warmup_steps):
  6. self.optimizer = optimizer
  7. self.d_model = d_model
  8. self.n_warmup_steps = n_warmup_steps
  9. self.n_current_steps = 0
  10. def step(self):
  11. "Step by the inner optimizer"
  12. self.optimizer.step()
  13. def zero_grad(self):
  14. "Zero out the gradients by the inner optimizer"
  15. self.optimizer.zero_grad()
  16. def update_learning_rate(self):
  17. ''' Learning rate scheduling per step '''
  18. self.n_current_steps += 1
  19. new_lr = np.power(self.d_model, -0.5) * np.min([
  20. np.power(self.n_current_steps, -0.5),
  21. np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
  22. for param_group in self.optimizer.param_groups:
  23. param_group['lr'] = new_lr