import torch
from torch import nn


class GaussianMax2d(nn.Module):
    r"""Gaussian gate centred on the maximum of ``x`` over dimension 1.

    For an input ``x`` of shape ``(B, HW, N)`` this computes, element-wise,

    .. math::

        y = \exp\left(-\alpha^2 (x_{\max} - x)^2\right) \cdot x + \varepsilon

    where ``x_max`` is the maximum of ``x`` over dimension 1, taken separately
    for every ``(b, n)`` pair. The gate equals 1 wherever ``x`` attains that
    maximum and decays towards 0 as ``x`` moves away from it; ``alpha`` is a
    learnable parameter controlling how fast it decays.

    Args:
        eps: Constant added to every output element. Defaults to ``1e-4``.
            (Its purpose is not visible from this module — NOTE(review):
            confirm with callers.)
        alpha: Initial value of the learnable ``alpha`` parameter.
            Defaults to ``1.0``.
    """

    def __init__(self, *, eps: float = 1e-4, alpha: float = 1.0) -> None:
        super().__init__()
        self._eps = eps
        # nn.Parameter already defaults to requires_grad=True.
        self.alpha = nn.Parameter(torch.tensor(float(alpha)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Apply the Gaussian max gate.

        Args:
            x: Tensor of shape ``(B, HW, N)``. HW is presumably the flattened
                spatial positions and N the channels — TODO confirm.

        Returns:
            ``exp(-alpha**2 * (x_max - x)**2) * x + eps``, same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "GaussianMax2d expects a 3-D tensor of shape (B, HW, N), "
                f"got shape {tuple(x.shape)}"
            )
        # In training mode the max is computed on a detached copy so no
        # gradient flows through the max selection itself; in eval mode x is
        # used directly (behaviour preserved from the original).
        source = x.detach() if self.training else x
        x_max = source.max(dim=1, keepdim=True).values  # (B, 1, N)
        # Broadcasting x_max over dim 1 avoids the transpose/reshape copies.
        gate = torch.exp(-self.alpha ** 2 * (x_max - x) ** 2)  # (B, HW, N)
        return gate * x + self._eps