from torch.optim.lr_scheduler import StepLR


class StepMinLR(StepLR):
    """StepLR variant that never lets the learning rate fall below a floor.

    Behaves exactly like ``torch.optim.lr_scheduler.StepLR`` (decay the LR of
    each parameter group by ``gamma`` every ``step_size`` epochs), except every
    computed LR is clamped to be at least ``min_lr``.
    """

    def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1):
        """
        Args:
            optimizer: Wrapped optimizer.
            step_size: Period (in epochs) of learning-rate decay.
            gamma: Multiplicative decay factor. Default: 0.1.
            min_lr: Lower bound applied to every computed LR. Default: 0.0
                (i.e. no effective floor, matching plain StepLR).
            last_epoch: Index of the last epoch. Default: -1.
        """
        # Must be set before super().__init__, which triggers an initial
        # step() and therefore calls get_lr().
        self.min_lr = min_lr
        super().__init__(optimizer, step_size=step_size, gamma=gamma,
                         last_epoch=last_epoch)

    def get_lr(self):
        """Compute per-group LRs via StepLR's rule, clamped to ``min_lr``."""
        raw_lrs = super().get_lr()
        return [max(lr, self.min_lr) for lr in raw_lrs]

    def _get_closed_form_lr(self):
        # StepLR also exposes a closed-form path (used e.g. when step() is
        # called with an explicit epoch); clamp there too so the floor is
        # honored on every code path, not only through get_lr().
        return [max(lr, self.min_lr) for lr in super()._get_closed_form_lr()]