attention_interpreter.py

from abc import ABC, abstractmethod
from collections import OrderedDict as OrdDict
from typing import Dict, List, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from .interpreter import Interpreter
from .interpretable import InterpretableModel
from ..modules import LAP
from ..utils.hooker import Hooker, Hook

AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpretableModel(InterpretableModel, ABC):

    @property
    @abstractmethod
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        Attention (LAP) layers of the model, grouped by name.
        """


class AttentionInterpreter(Interpreter):

    def __init__(self, model: AttentionInterpretableModel):
        assert isinstance(model, AttentionInterpretableModel), \
            f"Expected AttentionInterpretableModel, got {type(model)}"
        super().__init__(model)

        # register one forward hook per LAP layer, keyed by its attention group
        layer_hooks = sum([
            [
                (layer.attention_layer, self._generate_forward_hook(group_name, layer))
                for layer in group_layers
            ] for group_name, group_layers in model.attention_layers.items()
        ], [])

        self._hooker = Hooker(*layer_hooks)
        self._attention_groups: Dict[str, AttentionDict] = {
            group_name: {} for group_name in model.attention_layers}
        self._impact = 1.
        self._impact_decay = .8
        self._input_size: torch.Size = None
        self._device = None

    def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook:

        def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
            # accumulate the layer's attention output, bucketed by spatial size
            processed = out.clone()
            shape = processed.shape[2:]
            if shape not in self._attention_groups[group_name]:
                self._attention_groups[group_name][shape] = torch.zeros_like(processed)
            self._attention_groups[group_name][shape] += processed

        return forward_hook

    def _reset(self) -> None:
        self._attention_groups = {
            group_name: {} for group_name in self._model.attention_layers}
        self._impact = 1.

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        """ Process the current attention level.

        Args:
            result (torch.Tensor): Attention map up to the previous attention level. (B, 1, H1, W1)
            attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

        Returns:
            torch.Tensor: New attention map. (B, 1, H2, W2)
        """
        self._impact *= self._impact_decay

        # smallest attention level: nothing to merge with yet
        if result is None:
            return attention

        # attention is larger than result; compute the integer scale factor
        scale = torch.ceil(
            torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:])
        ).int().tolist()

        # find the maximum of attention in each kernel window
        max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale)
        max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest')
        is_any_active = max_map > 0

        # interpolate result to attention size
        result = F.interpolate(result, attention.shape[2:], mode='nearest')

        # # maximum between result and attention
        # impacted = torch.max(result, max_map * self._impact)
        impacted = torch.max(result, attention * self._impact)
        # when attention is zero, we use zero
        impacted[attention == 0] = 0
        # impacted = torch.where(attention > 0, impacted, 0)
        # when result is zero, we use zero
        impacted[result == 0] = 0
        # impacted = torch.where(result > 0, impacted, 0)
        # if max_map is zero, we keep result, else we use impacted
        result[is_any_active] = impacted[is_any_active]
        # result = torch.where(is_any_active, impacted, result)
        # mask = (max_map > 0) & (result > 0)
        # result[mask] = impacted[mask]
        # result[mask & (attention == 0)] = 0

        return result
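
    # Worked example of the merge above (illustrative numbers, assumed here rather
    # than taken from the repository): with _impact_decay = 0.8 the impact factor
    # is decayed before each level is merged, so the smallest map is kept as-is,
    # the second-smallest enters the max() with weight 0.8**2 = 0.64, the next with
    # 0.8**3 = 0.512, and so on. Positions whose pooling window has no active
    # attention keep the upsampled coarser result; positions in an active window
    # take the impacted value, which is zeroed wherever either the current level
    # or the coarser result is zero, so the final map highlights regions that stay
    # active across all attention levels.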

    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # sort each group's accumulated maps by spatial size, smallest first
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(),
                                       key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }

        results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions}
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0
            # from smallest to largest
            for attention in group_attention.values():
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)

        interpretations = {}
        for group_name, group_result in results.items():
            group_result = F.interpolate(group_result, self._input_size, mode='bilinear', align_corners=True)
            interpretations[group_name] = group_result
            '''
            if group_result.shape[1] == 1:
                interpretations[group_name] = group_result
            else:
                interpretations.update({
                    f"{group_name}_{i}": group_result[:, i:i+1]
                    for i in range(group_result.shape[1])
                })
            '''
        return interpretations

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        self._reset()

        # batch size, spatial size, and device come from the first interpretable input
        main_input = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]]
        self._batch_size = main_input.shape[0]
        self._input_size = main_input.shape[2:]
        self._device = main_input.device

        with self._hooker:
            self._model(**inputs)

        return self._process_attentions()
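

# Usage sketch (hypothetical; `SomeAttentionModel`, the placeholder name "x", and
# the tensor shapes below are assumptions for illustration, not part of this file):
#
#   model = SomeAttentionModel()                  # an AttentionInterpretableModel
#   interpreter = AttentionInterpreter(model)
#   x = torch.rand(4, 3, 224, 224)                # (B, C, H, W) input batch
#   labels = torch.zeros(4, dtype=torch.long)
#   # the keyword must match model.ordered_placeholder_names_to_be_interpreted[0]
#   maps = interpreter.interpret(labels, x=x)
#   # maps: Dict[str, torch.Tensor], one attention map per group, resized to (H, W)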