from typing import Any, Callable, List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from ._common_types import size_2_t
from . import GaussianMax2d, ActivatedConv2d, ScaledSigmoid
from ..interpreting.relcam.relprop import RPProvider, RelProp


class LAP(nn.Module):
    """Attention-based pooling: each kernel_size window is reduced to the
    weighted average of its pixels, with weights predicted by a small
    1x1-convolution attention head (see forward)."""

    def __init__(self, in_channels: int, kernel_size: size_2_t = 2,
                 stride: size_2_t = 2, padding: size_2_t = 0,
                 hidden_channels: Optional[List[int]] = None,
                 sigmoid_scale: float = 1.0,
                 n_attention_heads: int = 1,
                 hidden_activation: Union[
                     Callable[[Any], torch.nn.Module],
                     List[Callable[[Any], torch.nn.Module]]] = nn.ReLU,
                 discriminative_attention: bool = False):
        super().__init__()
        hidden_channels = hidden_channels or []  # avoid a mutable default argument
        self._eps = 1e-4
        # Normalize scalar hyper-parameters to (height, width) tuples.
        self._padding = padding if isinstance(padding, tuple) else (padding, padding)
        self._kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self._stride = stride if isinstance(stride, tuple) else (stride, stride)

        channels = [in_channels, *hidden_channels]
        if not isinstance(hidden_activation, list):
            hidden_activation = [hidden_activation] * len(hidden_channels)
        self.attention = nn.Sequential(*[
            ActivatedConv2d(i, o, 1, bn=True,
                            activation=hidden_activation[ind])
            for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
        ],
            ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                            activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        # The discriminative attention head is trained on all pixels, not just the
        # top-k ones, but never returns gradients to its inputs. It therefore tries
        # to discriminate the most informative pixels without becoming biased
        # toward the first choice of top-k.
        self.discriminative_attention = None
        if discriminative_attention:
            self.discriminative_attention = nn.Sequential(*[
                ActivatedConv2d(i, o, 1, bn=True,
                                activation=hidden_activation[ind])
                for ind, (i, o) in enumerate(zip(channels[:-1], channels[1:]))
            ],
                ActivatedConv2d(channels[-1], n_attention_heads, 1, bn=True,
                                activation=ScaledSigmoid.get_factory(sigmoid_scale)))

        self.unfold = nn.Unfold(kernel_size, padding=padding, stride=stride)
        self.gaussian = GaussianMax2d()

    def calculate_scores(self, x: torch.Tensor) -> torch.Tensor:
        w = self.attention(x)             # B N_A H W
        w = w.sum(1, keepdim=True)        # B 1 H W  (attention heads are summed)
        w: torch.Tensor = self.unfold(w)  # B H_k*W_k H'W'
        w = self.gaussian(w)              # B H_k*W_k H'W'
        w = w.unsqueeze(1)                # B 1 H_k*W_k H'W'
        w = w + self._eps                 # keep scores strictly positive
        return w
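
    # Note: GaussianMax2d and ActivatedConv2d are package-internal modules. From
    # the usage above, GaussianMax2d appears to re-weight the unfolded scores
    # within each pooling window (each column of the B x H_k*W_k x H'W' tensor
    # is one window); see its definition for the exact behavior.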

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the discriminative head too, purely for its own training signal;
        # its output is discarded, and during training its input is detached so
        # that no gradients flow back to the backbone.
        if self.discriminative_attention is not None:
            _ = self.discriminative_attention(x.detach() if self.training else x)

        B, C, H, W = x.shape  # PyTorch layout is B C H W
        unfolded_s = self.calculate_scores(x)      # B 1 H_k*W_k H'W'
        unfolded_x: torch.Tensor = self.unfold(x)  # B C*H_k*W_k H'W'
        N = unfolded_x.shape[-1]
        unfolded_x = unfolded_x.reshape(B, C, -1, N)       # B C H_k*W_k H'W'
        numerator = (unfolded_x * unfolded_s).mean(dim=2)  # B C H'W'
        denominator = unfolded_s.mean(dim=2)               # B 1 H'W'
        out = numerator / denominator                      # B C H'W'
        Hp = self._get_Hp(H)
        Wp = self._get_Wp(W)
        out = out.reshape(B, C, Hp, Wp)  # B C H' W'
        return out
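
    # For each output position p and channel c, forward() computes the
    # score-weighted average over the corresponding input window:
    #     out[b, c, p] = sum_k x[b, c, k, p] * s[b, k, p] / sum_k s[b, k, p]
    # (both mean(dim=2) calls carry the same 1/(H_k*W_k) factor, which cancels).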

    def _get_Hp(self, H: int) -> int:
        # Standard convolution output-size formula.
        return (H + 2 * self._padding[0] - self._kernel_size[0]) // self._stride[0] + 1

    def _get_Wp(self, W: int) -> int:
        return (W + 2 * self._padding[1] - self._kernel_size[1]) // self._stride[1] + 1

    @property
    def attention_layer(self) -> nn.Module:
        return self.attention

    @property
    def discrimination_layer(self) -> nn.Module:
        return self.discriminative_attention


@RPProvider.register(LAP)
class LAPRelProp(RelProp[LAP]):

    def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
        # F.fold is the adjoint of nn.Unfold: it sums the per-window scores back
        # onto the original H x W grid.
        return F.fold(w, size, self.module._kernel_size,
                      padding=self.module._padding,
                      stride=self.module._stride)

    def rel(self, R, alpha=1):
        # R: B C H' W'
        HpWp = R.shape[-2:]
        HW = self.X.shape[-2:]
        w = self.module.calculate_scores(self.X)  # B 1 H_k*W_k H'W'
        w = self.fold(w.squeeze(1), HW)           # B 1 H W

        def f(x):
            numerator = x * w  # B C H W
            avg = F.adaptive_avg_pool2d(numerator, output_size=HpWp)  # B C H' W'
            # Recover per-window sums from the averages; this assumes square,
            # non-overlapping, unpadded windows (kernel_size == stride).
            K = HW[0] // avg.shape[-2]
            sum_ = K ** 2 * avg  # B C H' W'
            denominator = F.interpolate(sum_, size=HW, mode='nearest')  # B C H W
            return numerator / (denominator + self.module._eps)

        beta = alpha - 1
        # Alpha-beta LRP rule: positive and negative contributions are weighted separately.
        activator_relevances = f(self.X.clamp(min=0))  # B C H W
        inhibitor_relevances = f(self.X.clamp(max=0))  # B C H W
        R = F.interpolate(R, size=HW, mode='nearest')  # B C H W
        R = R * (alpha * activator_relevances - beta * inhibitor_relevances)  # B C H W
        return R
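

# A minimal usage sketch (illustrative only). It assumes the package-internal
# GaussianMax2d, ActivatedConv2d, and ScaledSigmoid dependencies resolve, and
# uses the default kernel_size=stride=2, so spatial dimensions halve:
#
#     pool = LAP(in_channels=64, hidden_channels=[16], n_attention_heads=2)
#     x = torch.randn(8, 64, 32, 32)   # B C H W
#     y = pool(x)
#     assert y.shape == (8, 64, 16, 16)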