You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

attention_sum_interpreter.py 1.1KB

1 year ago
12345678910111213141516171819202122232425262728293031
  1. import torch
  2. import torch.nn.functional as F
  3. from .attention_interpreter import AttentionInterpreter, AttentionInterpretableModel
  4. class AttentionSumInterpreter(AttentionInterpreter):
  5. def __init__(self, model: AttentionInterpretableModel):
  6. assert isinstance(model, AttentionInterpretableModel),\
  7. f"Expected AttentionInterpretableModel, got {type(model)}"
  8. super().__init__(model)
  9. def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
  10. """ Process current attention level.
  11. Args:
  12. result (torch.Tensor): Attention map till the previous attention level. (B, 1, H1, W1)
  13. attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)
  14. Returns:
  15. torch.Tensor: New attention map. (B, 1, H2, W2)
  16. """
  17. self._impact *= self._impact_decay
  18. # smallest attention layer
  19. if result is None:
  20. return attention
  21. # attention is larger than result
  22. result = F.interpolate(result, attention.shape[2:], mode='bilinear', align_corners=True)
  23. return result + attention