import functools
from typing import Callable

import torch
from torch import nn


class ScaledSigmoid(nn.Sigmoid):
    """
    Scaled sigmoid function.

    Computes ``sigmoid(scale * x)`` element-wise. A larger ``scale`` makes
    the transition around zero steeper; ``scale == 1.0`` is the ordinary
    sigmoid.
    """

    def __init__(self, scale: float = 1.0):
        """
        Args:
            scale: Multiplier applied to the input before the sigmoid.
        """
        super().__init__()
        self._scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(scale * x)``, delegating the sigmoid to ``nn.Sigmoid``."""
        return super().forward(self._scale * x)

    def extra_repr(self) -> str:
        # Surfaces the scale in repr()/print(model), e.g. "ScaledSigmoid(scale=2.0)".
        return f"scale={self._scale}"

    @staticmethod
    def get_factory(scale: float = 1.0) -> Callable[[], 'ScaledSigmoid']:
        """
        Returns a function that creates a scaled sigmoid module.

        Each call of the returned zero-argument callable builds a fresh
        ``ScaledSigmoid(scale)``. A ``functools.partial`` is used instead of
        a lambda so the factory can be pickled (e.g. stored in a saved
        config or sent to worker processes).
        """
        return functools.partial(ScaledSigmoid, scale=scale)