123456789101112131415161718192021 |
- from typing import Callable
- from torch import nn
- import torch
-
-
class ScaledSigmoid(nn.Sigmoid):
    """Sigmoid activation applied to a pre-scaled input.

    The module computes ``sigmoid(scale * x)``: the input is multiplied by
    a fixed scalar before the regular sigmoid of the parent class is
    applied.  The scale is a plain Python attribute, not an ``nn.Parameter``,
    so it is neither trained nor stored in the ``state_dict``.
    """

    def __init__(self, scale: float = 1.0):
        """Create the module.

        Args:
            scale: Multiplier applied to the input before the sigmoid.
                The default of ``1.0`` reproduces a plain ``nn.Sigmoid``.
        """
        super().__init__()
        self._scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(scale * x)`` computed element-wise on ``x``."""
        # Scale first, then delegate to nn.Sigmoid for the activation itself.
        return super().forward(self._scale * x)

    def extra_repr(self) -> str:
        """Expose the scale in ``repr(module)`` / printed model summaries."""
        return f"scale={self._scale}"

    @staticmethod
    def get_factory(scale: float = 1.0) -> Callable[[], 'ScaledSigmoid']:
        """Return a zero-argument callable that builds a scaled sigmoid.

        Every call of the returned callable creates a fresh
        ``ScaledSigmoid`` instance configured with ``scale``.

        Args:
            scale: Multiplier the created modules apply to their input.

        Returns:
            A callable taking no arguments and returning a new
            ``ScaledSigmoid(scale)``.
        """
        return lambda: ScaledSigmoid(scale)
|