from abc import ABC, abstractmethod
from collections import OrderedDict as OrdDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from .interpreter import Interpreter
from .interpretable import InterpretableModel
from ..modules import LAP
from ..utils.hooker import Hooker, Hook


# Accumulated attention maps of one group, keyed by their spatial (H, W) size.
AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpretableModel(InterpretableModel, ABC):

    @property
    @abstractmethod
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        The attention layers to interpret, grouped by name: a mapping from
        attention-group name to the LAP layers whose maps belong to that group.
        """


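# Illustrative sketch (not part of the original module): one way a model might
# satisfy this contract. The class, attribute and group names are assumptions,
# the LAP layers are taken as constructor arguments to avoid guessing LAP's own
# signature, and the remaining InterpretableModel members are omitted. Only the
# structure of the returned dictionary matters to AttentionInterpreter: each key
# names a group, and each value lists the LAP layers whose attention maps are
# merged into one interpretation for that group.
#
#     class ExampleAttentionModel(AttentionInterpretableModel):
#
#         def __init__(self, coarse_lap: LAP, fine_lap: LAP):
#             super().__init__()
#             self.coarse_lap = coarse_lap
#             self.fine_lap = fine_lap
#
#         @property
#         def attention_layers(self) -> Dict[str, List[LAP]]:
#             return {"attention": [self.coarse_lap, self.fine_lap]}

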
class AttentionInterpreter(Interpreter):

    def __init__(self, model: AttentionInterpretableModel):
        assert isinstance(model, AttentionInterpretableModel), \
            f"Expected AttentionInterpretableModel, got {type(model)}"
        super().__init__(model)

        # Register one forward hook per LAP attention layer, tagged with the
        # name of the group it belongs to.
        layer_hooks = sum([
            [
                (layer.attention_layer, self._generate_forward_hook(group_name, layer))
                for layer in group_layers
            ] for group_name, group_layers in model.attention_layers.items()
        ], [])
        self._hooker = Hooker(*layer_hooks)

        self._attention_groups: Dict[str, AttentionDict] = {
            group_name: {} for group_name in model.attention_layers}
        self._impact = 1.
        self._impact_decay = .8
        self._input_size: Optional[torch.Size] = None
        self._batch_size: Optional[int] = None
        self._device: Optional[torch.device] = None

    def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook:
        def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
            # Accumulate this layer's attention map into its group's bucket for
            # the map's spatial size, creating the bucket on first use.
            processed = out.clone()
            shape = processed.shape[2:]
            if shape not in self._attention_groups[group_name]:
                self._attention_groups[group_name][shape] = torch.zeros_like(processed)
            self._attention_groups[group_name][shape] += processed

        return forward_hook

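    # For illustration (sizes are hypothetical): after one forward pass of a
    # model with a single group named "attention" whose LAP layers emit 14x14
    # and 28x28 maps, the accumulator holds one summed tensor per spatial size:
    #
    #     self._attention_groups == {
    #         "attention": {
    #             torch.Size([14, 14]): <sum of the (B, C, 14, 14) maps>,
    #             torch.Size([28, 28]): <sum of the (B, C, 28, 28) maps>,
    #         },
    #     }
    #
    # Maps that share a spatial size are summed together; different sizes are
    # kept apart and merged coarse-to-fine later by _process_attention.
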
    def _reset(self) -> None:
        self._attention_groups = {
            group_name: {} for group_name in self._model.attention_layers}
        self._impact = 1.

    def _process_attention(self, result: Optional[torch.Tensor], attention: torch.Tensor) -> torch.Tensor:
        """ Merge the current attention level into the result so far.

        Args:
            result (torch.Tensor): Attention map accumulated over the previous,
                coarser levels, or None for the first level. (B, 1, H1, W1)
            attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

        Returns:
            torch.Tensor: New accumulated attention map. (B, 1, H2, W2)
        """
        self._impact *= self._impact_decay

        # first (smallest) attention level: nothing to merge yet
        if result is None:
            return attention

        # attention is larger than result; scale factor between the two levels
        scale = torch.ceil(
            torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:])
        ).int().tolist()

        # mark each fine cell whose pooling block contains at least one active
        # attention value
        max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale)
        max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest')
        is_any_active = max_map > 0

        # upsample the coarse result to the current attention size
        result = F.interpolate(result, attention.shape[2:], mode='nearest')

        # take the larger of the coarse value and the decayed fine value ...
        impacted = torch.max(result, attention * self._impact)
        # ... but only where both the fine attention and the coarse result are active
        impacted[attention == 0] = 0
        impacted[result == 0] = 0
        # cells inside an active fine block take the merged value; cells inside
        # fully inactive blocks keep the upsampled coarse result
        result[is_any_active] = impacted[is_any_active]
        return result

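    # Worked example (shapes and numbers are illustrative only). Suppose the
    # coarser `result` is (B, 1, 2, 2) and the current `attention` is (B, 1, 4, 4):
    #
    #   * scale = ceil(4 / 2) = 2, so max_pool2d marks each 2x2 block of the
    #     fine map as active if any of its cells is positive (`is_any_active`).
    #   * `result` is upsampled (nearest) to 4x4 so the two levels align.
    #   * `impacted` keeps the larger of the coarse value and the decayed fine
    #     value, zeroed wherever either level is inactive.
    #   * Cells inside an active block take `impacted`; cells in blocks with no
    #     active fine attention keep the upsampled coarse value.
    #
    # With the default decay of 0.8 and the per-group reset to 1.0, the first
    # level is returned as-is and the second merged level is weighted by
    # 1.0 * 0.8 * 0.8 = 0.64, since `_impact` is decayed at the start of every
    # call, including the first one.
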
    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # order each group's accumulated maps by spatial size, smallest first
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(),
                                       key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }

        results: Dict[str, Optional[torch.Tensor]] = {
            group_name: None for group_name in sorted_attentions}
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0
            # merge the levels from the smallest (coarsest) to the largest (finest)
            for attention in group_attention.values():
                # keep only activations above the 0.5 threshold
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)

        # upsample each merged map to the input resolution
        interpretations = {}
        for group_name, group_result in results.items():
            group_result = F.interpolate(
                group_result, self._input_size, mode='bilinear', align_corners=True)
            interpretations[group_name] = group_result
            # Multi-channel results could alternatively be split into one
            # interpretation per channel (e.g. keyed as f"{group_name}_{i}").
        return interpretations

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        self._reset()

        # cache the primary input's properties so the attention maps can be
        # resized back to the input resolution
        primary_input = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]]
        self._batch_size = primary_input.shape[0]
        self._input_size = primary_input.shape[2:]
        self._device = primary_input.device

        # run a forward pass with the hooks attached to collect the attention maps
        with self._hooker:
            self._model(**inputs)

        return self._process_attentions()
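

# Illustrative usage sketch (assumptions: a concrete AttentionInterpretableModel
# subclass instance `model`, a single input placeholder named "x", and a
# 224x224 RGB input; the placeholder name and shape must match what
# model.ordered_placeholder_names_to_be_interpreted expects).
#
#     interpreter = AttentionInterpreter(model)
#     interpretations = interpreter.interpret(labels=0, x=torch.rand(1, 3, 224, 224))
#     for group_name, attention_map in interpretations.items():
#         # one map per attention group, upsampled to the input's spatial size
#         print(group_name, attention_map.shape)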