import torch
import torch.nn.functional as F
from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel
-
class AttentionSumInterpreter(AttentionInterpreter):
    """Attention interpreter that aggregates multi-level attention maps by
    upsampling the accumulated map to each new level's resolution and summing.

    Compare with other ``AttentionInterpreter`` subclasses, which may combine
    levels differently (e.g. by multiplication).
    """

    def __init__(self, model: AttentionInterpretableModel):
        """Initialize the interpreter for an attention-interpretable model.

        Args:
            model (AttentionInterpretableModel): Model whose attention maps
                will be interpreted.

        Raises:
            TypeError: If ``model`` is not an ``AttentionInterpretableModel``.
                (Previously an ``assert``, which would be silently stripped
                under ``python -O``.)
        """
        if not isinstance(model, AttentionInterpretableModel):
            raise TypeError(
                f"Expected AttentionInterpretableModel, got {type(model)}")
        super().__init__(model)

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        """Process the current attention level by summing with the running map.

        Args:
            result (torch.Tensor): Attention map accumulated up to the
                previous level, shape (B, 1, H1, W1); ``None`` at the first
                (smallest) level.
            attention (torch.Tensor): Current level's attention map,
                shape (B, 1, H2, W2).

        Returns:
            torch.Tensor: New accumulated attention map, shape (B, 1, H2, W2).
        """
        # NOTE: the impact decay is applied unconditionally, including on the
        # very first level where ``result`` is None — preserved as-is.
        self._impact *= self._impact_decay

        # Smallest (first) attention level: nothing accumulated yet.
        if result is None:
            return attention

        # Each subsequent level is larger than the accumulated map, so
        # upsample the accumulated result to the current level's spatial size
        # before summing.
        result = F.interpolate(result, attention.shape[2:], mode='bilinear',
                               align_corners=True)
        return result + attention
|