You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 407B

123456789101112
  1. from torch.optim.lr_scheduler import StepLR
  class StepMinLR(StepLR):
      """Step-decay scheduler whose learning rate never falls below ``min_lr``.

      Behaves like ``torch.optim.lr_scheduler.StepLR`` (each param group's lr is
      multiplied by ``gamma`` every ``step_size`` epochs), but every lr it
      produces is clamped to be at least ``min_lr``. The clamp applies both to
      the usual chainable update (``get_lr``) and to the closed-form path that
      the base scheduler uses when ``step()`` is given an explicit epoch.

      Args:
          optimizer: Wrapped optimizer.
          step_size: Period of learning-rate decay, in epochs.
          gamma: Multiplicative factor of learning-rate decay.
          min_lr: Lower bound on the learning rate. Either a single number
              applied to every param group, or a list/tuple with one value per
              param group (same convention as ``ReduceLROnPlateau``).
          last_epoch: Index of the last epoch; ``-1`` starts fresh.

      Raises:
          ValueError: If ``min_lr`` is a list/tuple whose length differs from
              the number of optimizer param groups.

      Note:
          The floor also applies to a starting lr below ``min_lr``. The base
          constructor performs an initial step, so such an lr is raised to
          ``min_lr`` as soon as the scheduler is constructed.
      """

      def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1):
          n_groups = len(optimizer.param_groups)
          if isinstance(min_lr, (list, tuple)) and len(min_lr) != n_groups:
              raise ValueError(
                  f"expected {n_groups} min_lr values (one per param group), "
                  f"got {len(min_lr)}"
              )
          # Must be assigned before super().__init__(): the base constructor
          # performs an initial step, which calls get_lr() and reads min_lr.
          self.min_lr = min_lr
          super().__init__(optimizer, step_size, gamma, last_epoch)

      def _floors(self):
          """Return the lr floor for each param group, in param-group order."""
          # Read lazily so that reassigning self.min_lr at runtime takes effect.
          if isinstance(self.min_lr, (list, tuple)):
              return list(self.min_lr)
          return [self.min_lr] * len(self.optimizer.param_groups)

      def _clamp(self, lrs):
          """Raise each lr in *lrs* to at least its param group's floor."""
          return [max(lr, floor) for lr, floor in zip(lrs, self._floors())]

      def get_lr(self):
          """Return StepLR's chainable lrs for the current epoch, clamped to ``min_lr``."""
          return self._clamp(super().get_lr())

      def _get_closed_form_lr(self):
          """Return StepLR's closed-form lrs, clamped to ``min_lr``."""
          # LRScheduler.step(epoch) uses this instead of get_lr() when an
          # explicit epoch is passed; without the clamp, that path would
          # ignore the floor.
          return self._clamp(super()._get_closed_form_lr())