
scaled_sigmoid.py 561B

from typing import Callable

import torch
from torch import nn


class ScaledSigmoid(nn.Sigmoid):
    """Scaled sigmoid function."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self._scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(self._scale * x)

    @staticmethod
    def get_factory(scale: float = 1.0) -> Callable[[], 'ScaledSigmoid']:
        """Returns a function that creates a scaled sigmoid module."""
        return lambda: ScaledSigmoid(scale)
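
A minimal usage sketch (not part of the original file): since nn.Sigmoid.forward applies torch.sigmoid, ScaledSigmoid(scale)(x) equals torch.sigmoid(scale * x), and get_factory returns a zero-argument callable, which is convenient anywhere an activation constructor is expected. The model below is a hypothetical example.

    import torch
    from torch import nn

    # Factory returns a zero-argument callable that builds the module.
    make_act = ScaledSigmoid.get_factory(scale=2.0)
    model = nn.Sequential(nn.Linear(4, 8), make_act(), nn.Linear(8, 1))

    y = model(torch.randn(3, 4))  # shapes: (3, 4) -> (3, 1)

    # The module computes sigmoid(scale * x).
    act = ScaledSigmoid(scale=2.0)
    z = torch.randn(5)
    assert torch.allclose(act(z), torch.sigmoid(2.0 * z))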