
adaptive_lap.py
from typing import List, Type

import torch
from torch import nn

from .lap import LAP, LAPRelProp
from ..interpreting.relcam.relprop import RPProvider


class AdaptiveUnfold(nn.Module):
    """Flattens a (B, C, H, W) feature map to (B, C*H*W, 1), so the whole
    map is treated as the patch of a single output location."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(start_dim=1).unsqueeze(-1)


class AdaptiveLAP(LAP):
    """LAP variant that, like torch's adaptive pooling layers, always
    reduces its input to a 1x1 spatial output regardless of input size."""

    def __init__(self, in_channels: int, hidden_channels: List[int] = [],
                 sigmoid_scale: float = 1.0,
                 n_attention_heads: int = 1,
                 hidden_activation: Type[nn.Module] = nn.ReLU,
                 discriminative_attention: bool = False) -> None:
        super().__init__(in_channels,
                         hidden_channels=hidden_channels,
                         sigmoid_scale=sigmoid_scale,
                         n_attention_heads=n_attention_heads,
                         hidden_activation=hidden_activation,
                         discriminative_attention=discriminative_attention)
        # Swap the base class's sliding-window unfold for one that gathers
        # the entire feature map into a single pooling location.
        self.unfold = AdaptiveUnfold()

    def _get_Hp(self, _: int) -> int:
        # Pooled height is always 1, whatever the input height.
        return 1

    def _get_Wp(self, _: int) -> int:
        # Pooled width is always 1, whatever the input width.
        return 1


@RPProvider.register(AdaptiveLAP)
class AdaptiveLAPRelProp(LAPRelProp):
    def fold(self, w: torch.Tensor, size: torch.Size) -> torch.Tensor:
        # Invert AdaptiveUnfold: reshape the flattened weights back into a
        # single-channel spatial map of the given size.
        return w.reshape(-1, 1, *size)
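
For orientation, a minimal usage sketch. The AdaptiveUnfold check is self-contained and runs with plain torch; the AdaptiveLAP call is an assumption, since the LAP base class lives in .lap and is not shown in this file, so the constructor-plus-forward usage and the expected output shape are illustrative rather than confirmed.

import torch

# Shape check for AdaptiveUnfold (self-contained, runnable as-is):
unfold = AdaptiveUnfold()
x = torch.randn(2, 64, 7, 7)   # (batch, channels, H, W)
print(unfold(x).shape)         # torch.Size([2, 3136, 1])

# Assumed AdaptiveLAP usage (an attention-weighted analogue of
# nn.AdaptiveAvgPool2d(1)); presumes LAP.forward accepts a 4D tensor:
pool = AdaptiveLAP(in_channels=64, hidden_channels=[8], n_attention_heads=2)
y = pool(x)                    # expected: torch.Size([2, 64, 1, 1])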