123456789101112 |
- from torch.optim.lr_scheduler import StepLR
-
-
class StepMinLR(StepLR):
    """StepLR whose learning rate never falls below a floor.

    Every ``step_size`` epochs the learning rate of each param group is
    multiplied by ``gamma`` (exactly as in ``StepLR``). The result is then
    clamped from below by ``min_lr``. The clamp is applied both on the normal
    ``step()`` path (``get_lr``) and on the explicit-epoch ``step(epoch)``
    path (``_get_closed_form_lr``).

    Args:
        optimizer: Wrapped optimizer.
        step_size: Period (in epochs) of learning-rate decay.
        gamma: Multiplicative decay factor.
        min_lr: Lower bound on the learning rate. Either a single number
            applied to every param group, or a list/tuple with one value per
            param group.
        last_epoch: Index of the last epoch (``-1`` to start fresh).

    Raises:
        ValueError: If ``min_lr`` is a sequence whose length differs from
            the number of param groups.
    """

    def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1):
        # These attributes must exist before super().__init__, because the
        # parent constructor performs an initial step() that calls get_lr().
        self.min_lr = min_lr
        n_groups = len(optimizer.param_groups)
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != n_groups:
                raise ValueError(
                    f"expected {n_groups} min_lr values (one per param group), "
                    f"got {len(min_lr)}"
                )
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * n_groups
        super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch)

    def _clamp(self, lrs):
        """Return ``lrs`` with each entry raised to at least its group's floor."""
        return [max(lr, floor) for lr, floor in zip(lrs, self.min_lrs)]

    def get_lr(self):
        # Use StepLR's chainable formula, then apply the floor.
        return self._clamp(super().get_lr())

    def _get_closed_form_lr(self):
        # step(epoch) uses this instead of get_lr(); without the override an
        # explicit epoch would bypass the floor entirely.
        return self._clamp(super()._get_closed_form_lr())
|