from typing import Optional

import torch
import torch.nn.functional as F

from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel


class AttentionSumInterpreter(AttentionInterpreter):
    """Attention interpreter that aggregates attention levels by summation.

    Each call to ``_process_attention`` resizes the running attention map to
    the current level's spatial size (bilinear, ``align_corners=True``) and
    adds the current level's attention map to it.
    """

    def __init__(self, model: AttentionInterpretableModel):
        """Wrap ``model`` for attention-sum interpretation.

        Args:
            model (AttentionInterpretableModel): Model whose attention maps
                will be interpreted.

        Raises:
            AssertionError: If ``model`` is not an ``AttentionInterpretableModel``.
        """
        # NOTE(review): this type check is an ``assert`` and is skipped under
        # ``python -O``; kept as an assert so the raised exception type stays
        # the same for existing callers.
        assert isinstance(model, AttentionInterpretableModel), \
            f"Expected AttentionInterpretableModel, got {type(model)}"
        super().__init__(model)

    def _process_attention(self, result: Optional[torch.Tensor], attention: torch.Tensor) -> torch.Tensor:
        """
        Process current attention level.

        Args:
            result (Optional[torch.Tensor]): Attention map accumulated up to the
                previous attention level, (B, 1, H1, W1), or ``None`` when no
                level has been processed yet.
            attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

        Returns:
            torch.Tensor: New attention map, i.e. ``result`` resized to
            (H2, W2) plus ``attention``; just ``attention`` when ``result`` is
            ``None``. (B, 1, H2, W2)
        """
        # Decay the impact factor once per processed level. ``_impact`` and
        # ``_impact_decay`` are not set in this class -- presumably provided by
        # ``AttentionInterpreter``; confirm there.
        self._impact *= self._impact_decay

        # First level processed (expected to be the smallest attention layer):
        # nothing has been accumulated yet.
        if result is None:
            return attention

        # Resize the accumulated map to the current level's spatial size (works
        # for up- or down-sampling). Skip the call when the sizes already match,
        # since a same-size bilinear resize with align_corners=True only
        # resamples the same grid points.
        if result.shape[2:] != attention.shape[2:]:
            result = F.interpolate(result, attention.shape[2:], mode='bilinear', align_corners=True)
        return result + attention