123456789101112131415161718192021222324 |
- import torch
- from torch import nn
-
-
class GaussianMax2d(nn.Module):
    r"""Gaussian gate centred on the per-column maximum along the HW axis.

    Every value is multiplied by a Gaussian of its distance to the maximum of
    its ``(batch, channel)`` column along dim 1. The column maximum itself is
    scaled by ``exp(0) = 1``, and values further below it are attenuated more
    strongly:

    .. math::
        \mathrm{out} = \exp\!\big(-\alpha^2 (x_{\max} - x)^2\big)\, x + \varepsilon

    :math:`\alpha` is a learnable scalar parameter.

    Args:
        eps: constant added to every output element. Defaults to ``1e-4``.
            NOTE(review): the reason for this offset is not visible here.
        alpha: initial value of the learnable ``alpha`` parameter.
            Defaults to ``1.0``.
    """

    def __init__(self, eps: float = 1e-4, alpha: float = 1.0) -> None:
        super().__init__()
        self._eps = eps
        # Stored as a 0-dim parameter so it appears in state_dict as "alpha".
        self.alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Apply the Gaussian max gate.

        Args:
            x: tensor of shape ``(B, HW, N)``.

        Returns:
            Tensor of shape ``(B, HW, N)`` equal to
            ``exp(-alpha^2 * (x_max - x)^2) * x + eps``, where ``x_max`` is the
            maximum of ``x`` over dim 1, taken independently for each
            ``(b, n)`` column.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "GaussianMax2d expects a 3-D (B, HW, N) input, "
                f"got shape {tuple(x.shape)}"
            )
        # In training mode the maximum is taken from a detached tensor, so no
        # gradient flows through x_max; gradients reach x only via the other
        # two occurrences of x in the formula.
        # NOTE(review): in eval mode the maximum is NOT detached. This only
        # matters if gradients are computed in eval mode; confirm it's intended.
        source = x.detach() if self.training else x
        # keepdim=True yields (B, 1, N), which broadcasts against (B, HW, N),
        # so no transpose/reshape round-trip (and its copies) is needed.
        x_max = source.max(dim=1, keepdim=True).values
        gate = torch.exp(-self.alpha ** 2 * (x_max - x) ** 2)
        return gate * x + self._eps
|