from typing import List, Optional, Type

import torch
from torch import nn

from .lap import LAP, LAPRelProp
from ..interpreting.relcam.relprop import RPProvider


class AdaptiveUnfold(nn.Module):
    """Collapse every non-batch dimension of the input into a single column.

    For an input of shape ``(N, d1, d2, ...)`` the output has shape
    ``(N, d1 * d2 * ..., 1)``. For an ``(N, C, H, W)`` input this matches the
    layout ``torch.nn.Unfold`` produces when its kernel spans the whole
    spatial extent (``C * H * W`` rows and a single block), but without
    having to know ``H`` and ``W`` up front.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` flattened from dim 1 onward, with a trailing size-1 dim."""
        return x.flatten(start_dim=1).unsqueeze(-1)


class AdaptiveLAP(LAP):
    """``LAP`` variant that treats the whole input as a single patch.

    After the regular ``LAP`` initialisation it sets ``self.unfold`` to an
    :class:`AdaptiveUnfold`, and ``_get_Hp`` / ``_get_Wp`` always report 1.
    NOTE(review): presumably this turns LAP into a global (input-size
    independent) variant; confirm against how ``LAP.forward`` uses
    ``unfold``, ``_get_Hp`` and ``_get_Wp``.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[List[int]] = None,
        sigmoid_scale: float = 1.0,
        n_attention_heads: int = 1,
        hidden_activation: Type[nn.Module] = nn.ReLU,
        discriminative_attention: bool = False,
    ) -> None:
        """Build the layer; all arguments are forwarded unchanged to ``LAP``.

        Args:
            in_channels: Forwarded to ``LAP``.
            hidden_channels: Forwarded to ``LAP``. ``None`` (the default) is
                replaced by a new empty list.
            sigmoid_scale: Forwarded to ``LAP``.
            n_attention_heads: Forwarded to ``LAP``.
            hidden_activation: Forwarded to ``LAP``.
            discriminative_attention: Forwarded to ``LAP``.
        """
        if hidden_channels is None:
            # Create a fresh list per call. The previous ``= []`` default was
            # a single list object shared by every instance, so any mutation
            # of it (here or inside LAP) would leak across instances.
            hidden_channels = []
        super().__init__(
            in_channels,
            hidden_channels=hidden_channels,
            sigmoid_scale=sigmoid_scale,
            n_attention_heads=n_attention_heads,
            hidden_activation=hidden_activation,
            discriminative_attention=discriminative_attention,
        )
        # Assigned after LAP.__init__, so it takes precedence over any
        # ``unfold`` the base class may have set up.
        self.unfold = AdaptiveUnfold()

    def _get_Hp(self, _: int) -> int:
        """Return 1 regardless of the argument (a single patch row).

        NOTE(review): the ignored argument is presumably the input height;
        confirm in ``LAP``.
        """
        return 1

    def _get_Wp(self, _: int) -> int:
        """Return 1 regardless of the argument (a single patch column).

        NOTE(review): the ignored argument is presumably the input width;
        confirm in ``LAP``.
        """
        return 1


@RPProvider.register(AdaptiveLAP)
class AdaptiveLAPRelProp(LAPRelProp):
    """Relevance-propagation rule registered with ``RPProvider`` for ``AdaptiveLAP``."""

    def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
        """Reshape ``w`` to ``(-1, 1, *size)``.

        NOTE(review): ``size`` is presumably the spatial size ``(H, W)`` of
        the original input, making this the inverse of ``AdaptiveUnfold``
        for single-channel maps; confirm against ``LAPRelProp``.
        """
        return w.reshape(-1, 1, *size)