- from typing import List, Type
-
- from torch import nn
- import torch
-
- from .lap import LAP, LAPRelProp
- from ..interpreting.relcam.relprop import RPProvider
-
-
- class AdaptiveUnfold(nn.Module):
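-     """Flattens the whole feature map into a single column of shape
-     (N, C*H*W, 1), i.e. one "patch" that covers the entire input."""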
-
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return x.flatten(start_dim=1).unsqueeze(-1)
-
-
- class AdaptiveLAP(LAP):
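-     """LAP variant whose pooled output is always 1x1: the unfold step
-     treats the entire feature map as a single pooling window."""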
-
-     def __init__(self, in_channels: int, hidden_channels: List[int] = [],
-                  sigmoid_scale: float = 1.0,
-                  n_attention_heads: int = 1,
-                  hidden_activation: Type[nn.Module] = nn.ReLU,
-                  discriminative_attention: bool = False) -> None:
-         super().__init__(in_channels,
-                          hidden_channels=hidden_channels,
-                          sigmoid_scale=sigmoid_scale,
-                          n_attention_heads=n_attention_heads,
-                          hidden_activation=hidden_activation,
-                          discriminative_attention=discriminative_attention)
-
-         self.unfold = AdaptiveUnfold()
-
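-     # The pooled output is fixed to 1x1 regardless of the input spatial size.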
-     def _get_Hp(self, _: int) -> int:
-         return 1
-
-     def _get_Wp(self, _: int) -> int:
-         return 1
-
-
- @RPProvider.register(AdaptiveLAP)
- class AdaptiveLAPRelProp(LAPRelProp):
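-     """Relevance propagation rules for AdaptiveLAP, registered with RPProvider."""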
-
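-     # The pooled map is 1x1, so fold just reshapes the flattened relevance
-     # weights back into a single-channel map of the requested spatial size.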
-     def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
-         return w.reshape(-1, 1, *size)