from typing import List, Callable, Union, Any
from math import floor

import torch
from torch import nn
from torch.nn import functional as F

from ._common_types import size_2_t
from . import GaussianMax2d, ActivatedConv2d, ScaledSigmoid
from ..interpreting.relcam.relprop import RPProvider, RelProp


class LAP(nn.Module):

    def __init__(self, in_channels: int,
                 kernel_size: size_2_t = 2,
                 stride: size_2_t = 2,
                 padding: size_2_t = 0,
                 hidden_channels: List[int] = [],
                 sigmoid_scale: float = 1.0,
                 n_attention_heads: int = 1,
                 hidden_activation: Union[
                     Callable[[Any], torch.nn.Module],
                     List[Callable[[Any], torch.nn.Module]]] = nn.ReLU,
                 discriminative_attention: bool = False):
        super().__init__()
        self._eps = 1e-4
        self._padding = padding if isinstance(padding, tuple) else (padding, padding)
        self._kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self._stride = stride if isinstance(stride, tuple) else (stride, stride)

        channels = [in_channels, *hidden_channels]
        if not isinstance(hidden_activation, list):
            hidden_activation = [hidden_activation] * len(hidden_channels)

        # 1x1 convolution tower ending in a scaled sigmoid: one attention
        # score per head and per pixel.
        self.attention = nn.Sequential(*[
            ActivatedConv2d(i, o, 1, bn=True, activation=hidden_activation[ind])
            for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
        ], ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                           activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        # The discriminative attention head is trained on all pixels, not just
        # the top-k ones, but never propagates gradients back to the inputs.
        # It therefore learns to discriminate the most informative pixels
        # without becoming biased towards the first choice of top-k.
        self.discriminative_attention = None
        if discriminative_attention:
            self.discriminative_attention = nn.Sequential(*[
                ActivatedConv2d(i, o, 1, bn=True, activation=hidden_activation[ind])
                for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
            ], ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                               activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)
        self.gaussian = GaussianMax2d()

    def calculate_scores(self, x: torch.Tensor) -> torch.Tensor:
        w = self.attention(x)              # B N_A H W
        w = w.sum(1, keepdim=True)         # B 1 H W
        w: torch.Tensor = self.unfold(w)   # B H_k*W_k H'W'
        w = self.gaussian(w)               # B H_k*W_k H'W'
        w = w.unsqueeze(1)                 # B 1 H_k*W_k H'W'
        w = w + self._eps                  # keep scores strictly positive
        return w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
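        # LAP pooling, per output window p and channel c (notation inferred
        # from the code below, not taken from an external reference):
        #   out[b, c, p] = sum_k s[b, p, k] * x[b, c, p, k] / sum_k s[b, p, k]
        # where k runs over the H_k*W_k pixels of the window and s are the
        # (strictly positive) attention scores from calculate_scores, i.e. a
        # score-weighted average of the features inside each pooling window.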
        # Run the discriminative head as well (its output is not used here);
        # during training its input is detached so that no gradients flow
        # back into the feature extractor.
        if self.discriminative_attention:
            _ = self.discriminative_attention(x.detach() if self.training else x)

        B, C, H, W = x.shape
        unfolded_s = self.calculate_scores(x)         # B 1 H_k*W_k H'W'
        unfolded_x: torch.Tensor = self.unfold(x)     # B C*H_k*W_k H'W'
        N = unfolded_x.shape[-1]
        unfolded_x = unfolded_x.reshape(B, C, -1, N)  # B C H_k*W_k H'W'

        numerator = (unfolded_x * unfolded_s).mean(dim=2)  # B C H'W'
        denominator = unfolded_s.mean(dim=2)               # B 1 H'W'
        out = numerator / denominator                      # B C H'W'

        Hp = self._get_Hp(H)
        Wp = self._get_Wp(W)
        out = out.reshape(B, C, Hp, Wp)  # B C H' W'

        return out

    def _get_Hp(self, H: int) -> int:
        return int(floor((H + 2 * self._padding[0] - self._kernel_size[0]) / self._stride[0] + 1))

    def _get_Wp(self, W: int) -> int:
        return int(floor((W + 2 * self._padding[1] - self._kernel_size[1]) / self._stride[1] + 1))

    @property
    def attention_layer(self) -> nn.Module:
        return self.attention

    @property
    def discrimination_layer(self) -> nn.Module:
        return self.discriminative_attention


@RPProvider.register(LAP)
class LAPRelProp(RelProp[LAP]):

    def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
        return F.fold(w, size,
                      self.module._kernel_size,
                      padding=self.module._padding,
                      stride=self.module._stride)

    def rel(self, R, alpha=1):
        # R: B C H' W'
        HpWp = R.shape[-2:]
        HW = self.X.shape[-2:]
        w = self.module.calculate_scores(self.X)  # B 1 H_k*W_k H'W'
        w = self.fold(w.squeeze(1), HW)           # B 1 H W

        def f(x):
            numerator = x * w                                           # B C H W
            avg = F.adaptive_avg_pool2d(numerator, output_size=HpWp)    # B C H' W'
            # Recover per-window sums from the adaptive average, assuming the
            # input height is an integer multiple of the output height.
            K = int(HW[0] / avg.shape[-2])
            sum_ = K ** 2 * avg                                         # B C H' W'
            denominator = F.interpolate(sum_, size=HW, mode='nearest')  # B C H W
            return numerator / (denominator + self.module._eps)

        beta = alpha - 1
        activator_relevances = f(self.X.clamp(min=0))  # B C H W
        inhibitor_relevances = f(self.X.clamp(max=0))  # B C H W
        R = F.interpolate(R, size=HW, mode='nearest')  # B C H W
        R = R * (alpha * activator_relevances - beta * inhibitor_relevances)  # B C H W
        return R
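
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It assumes a
# 64-channel, 14x14 feature map as a representative input and that this file
# is executed as a module inside its package, so the relative imports above
# resolve (e.g. via `python -m <package>.<module>`); shapes follow the
# comments in LAP.forward.
if __name__ == "__main__":
    # A 2x2 LAP with stride 2 halves the spatial resolution, like standard
    # 2x2 pooling, but weights each pixel by its learned attention score.
    pool = LAP(in_channels=64, kernel_size=2, stride=2,
               hidden_channels=[16], discriminative_attention=True)
    features = torch.randn(2, 64, 14, 14)  # B C H W
    pooled = pool(features)
    print(pooled.shape)  # expected: torch.Size([2, 64, 7, 7])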