from typing import List, Callable, Union, Any
from math import floor

import torch
from torch import nn
from torch.nn import functional as F

from ._common_types import size_2_t
from . import GaussianMax2d, ActivatedConv2d, ScaledSigmoid
from ..interpreting.relcam.relprop import RPProvider, RelProp


class LAP(nn.Module):
    """Attention-based pooling: a stack of 1x1 convolutions scores every spatial
    location, and each pooling window returns the attention-weighted average of
    its inputs."""

    def __init__(self, in_channels: int, kernel_size: size_2_t = 2,
                 stride: size_2_t = 2, padding: size_2_t = 0,
                 hidden_channels: List[int] = [], sigmoid_scale: float = 1.0,
                 n_attention_heads: int = 1,
                 hidden_activation: Union[
                     Callable[[Any], torch.nn.Module],
                     List[Callable[[Any], torch.nn.Module]]] = nn.ReLU,
                 discriminative_attention: bool = False):
        super().__init__()
        self._eps = 1e-4
        self._padding = padding if isinstance(padding, tuple) else (padding, padding)
        self._kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self._stride = stride if isinstance(stride, tuple) else (stride, stride)
        channels = [in_channels, *hidden_channels]
        if not isinstance(hidden_activation, list):
            hidden_activation = [hidden_activation] * len(hidden_channels)

        self.attention = nn.Sequential(*[
            ActivatedConv2d(i, o, 1, bn=True,
                            activation=hidden_activation[ind])
            for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
        ],
            ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                            activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        # The discriminative attention head is trained on all pixels, not just the
        # top-k ones, but it never propagates gradients back to its inputs, so it
        # learns to discriminate the most informative pixels without becoming
        # biased towards the first choice of top-k.
        self.discriminative_attention = None
        if discriminative_attention:
            self.discriminative_attention = nn.Sequential(*[
                ActivatedConv2d(i, o, 1, bn=True,
                                activation=hidden_activation[ind])
                for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
            ],
                ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                                activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)
        self.gaussian = GaussianMax2d()

    def calculate_scores(self, x: torch.Tensor) -> torch.Tensor:
        w = self.attention(x)  # B N_A H W
        w = w.sum(1, keepdim=True)  # B 1 H W

        w: torch.Tensor = self.unfold(w)  # B W_k*H_k H'W'
        w = self.gaussian(w)  # B W_k*H_k H'W'
        w = w.unsqueeze(1)  # B 1 W_k*H_k H'W'
        w = w + self._eps
        return w
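    # Shape walkthrough (a sketch; the 14x14 spatial size is only an illustrative
    # assumption): with kernel_size=2, stride=2, padding=0 and x of shape
    # (B, C, 14, 14), self.attention(x) is (B, n_attention_heads, 14, 14), summing
    # the heads gives (B, 1, 14, 14), nn.Unfold yields (B, 2*2, 7*7) = (B, 4, 49),
    # and the final unsqueeze returns scores of shape (B, 1, 4, 49).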

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        # Run the discriminative head as well; during training its input is
        # detached so no gradients flow back to the features.
        if self.discriminative_attention is not None:
            _ = self.discriminative_attention(x.detach() if self.training else x)

        B, C, H, W = x.shape
        unfolded_s = self.calculate_scores(x)  # B 1 W_k*H_k H'W'

        unfolded_x: torch.Tensor = self.unfold(x)  # B C*W_k*H_k H'W'
        N = unfolded_x.shape[-1]
        unfolded_x = unfolded_x.reshape(B, C, -1, N)  # B C W_k*H_k H'W'

        # Attention-weighted average over each window; the 1/(W_k*H_k) factors of
        # the two means cancel, so out = sum(x * s) / sum(s) per window.
        numerator = (unfolded_x * unfolded_s).mean(dim=2)  # B C H'W'
        denominator = unfolded_s.mean(dim=2)  # B 1 H'W'
        out = numerator / denominator  # B C H'W'

        Hp = self._get_Hp(H)
        Wp = self._get_Wp(W)

        out = out.reshape(B, C, Hp, Wp)  # B C H' W'
        return out

    def _get_Hp(self, H: int) -> int:
        return int(floor((H + 2 * self._padding[0] - self._kernel_size[0]) / self._stride[0] + 1))

    def _get_Wp(self, W: int) -> int:
        return int(floor((W + 2 * self._padding[1] - self._kernel_size[1]) / self._stride[1] + 1))
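    # Worked example of the formula above (the 14-pixel input height is an assumed,
    # illustrative value): with kernel_size=2, stride=2, padding=0,
    # Hp = floor((14 + 2*0 - 2) / 2 + 1) = 7, matching the number of windows
    # nn.Unfold produces along that axis.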

    @property
    def attention_layer(self) -> nn.Module:
        return self.attention

    @property
    def discrimination_layer(self) -> nn.Module:
        return self.discriminative_attention


@RPProvider.register(LAP)
class LAPRelProp(RelProp[LAP]):

    def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
        return F.fold(w, size, self.module._kernel_size,
                      padding=self.module._padding,
                      stride=self.module._stride)
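    # Note: with the default configuration (stride == kernel_size, no padding) the
    # unfold windows do not overlap, so F.fold simply places each window's scores
    # back at their original spatial positions; with overlapping windows, F.fold
    # would sum the contributions of the overlapping patches.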

    def rel(self, R, alpha=1):
        # R  # B C H' W'
        HpWp = R.shape[-2:]
        HW = self.X.shape[-2:]

        w = self.module.calculate_scores(self.X)  # B 1 W_k*H_k H'W'
        w = self.fold(w.squeeze(1), HW)  # B 1 H W

        def f(x):
            numerator = x * w  # B C H W
            avg = F.adaptive_avg_pool2d(numerator, output_size=HpWp)  # B C H' W'
            # Recover each window's sum from its average; this assumes square,
            # non-overlapping windows (stride == kernel_size, no padding).
            K = int(HW[0] / avg.shape[-2])
            sum_ = K ** 2 * avg  # B C H' W'
            denominator = F.interpolate(sum_, size=HW, mode='nearest')  # B C H W
            return numerator / (denominator + self.module._eps)

        beta = alpha - 1
        activator_relevances = f(self.X.clamp(min=0))  # B C H W
        inhibitor_relevances = f(self.X.clamp(max=0))  # B C H W

        R = F.interpolate(R, size=HW, mode='nearest')  # B C H W
        R = R * (alpha * activator_relevances - beta * inhibitor_relevances)  # B C H W
        return R
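

# Minimal usage sketch, assuming the package context required by the relative
# imports above (run via `python -m <package>.<module>`); the channel count and
# the 14x14 input size are illustrative assumptions, not values from the codebase.
if __name__ == "__main__":
    pool = LAP(in_channels=64, kernel_size=2, stride=2)
    pool.eval()  # eval mode so any bn=True layers in the attention head use running stats
    dummy = torch.randn(2, 64, 14, 14)
    with torch.no_grad():
        pooled = pool(dummy)
    # Each 2x2 window is reduced to one attention-weighted value per channel.
    print(pooled.shape)  # expected: torch.Size([2, 64, 7, 7])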