You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

gaussianmax.py 801B

123456789101112131415161718192021222324
  1. import torch
  2. from torch import nn
  3. class GaussianMax2d(nn.Module):
  4. def __init__(self):
  5. super().__init__()
  6. self._eps = 1e-4
  7. self.alpha = nn.Parameter(torch.tensor(1.), requires_grad=True)
  8. def forward(self, x: torch.Tensor) -> torch.Tensor:
  9. r"""
  10. :return: exp(-alpha^2 * (x_max - x) ^ 2) * x
  11. """
  12. # B HW N
  13. B, HW, N = x.shape
  14. x = x.transpose(1, 2).reshape(-1, HW) # BN HW
  15. x_max = (x.detach() if self.training else x)\
  16. .max(dim=1, keepdim=True).values # BN 1
  17. out = torch.exp(-self.alpha ** 2 * (x_max - x) ** 2) * x # BN HW
  18. out = out.reshape(B, N, HW).transpose(1, 2) # B HW N
  19. return out + self._eps