from abc import ABC, abstractmethod
from collections import OrderedDict as OrdDict
from typing import Dict, List, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from .interpreter import Interpreter
from .interpretable import InterpretableModel
from ..modules import LAP
from ..utils.hooker import Hooker, Hook


AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpretableModel(InterpretableModel, ABC):

    @property
    @abstractmethod
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        Returns:
            Dict[str, List[LAP]]: Attention layers (LAP modules), grouped by name.
        """


class AttentionInterpreter(Interpreter):

    def __init__(self, model: AttentionInterpretableModel):
        assert isinstance(model, AttentionInterpretableModel), \
            f"Expected AttentionInterpretableModel, got {type(model)}"
        super().__init__(model)

        # register one forward hook per LAP attention layer, grouped by name
        layer_hooks = sum([
            [
                (layer.attention_layer, self._generate_forward_hook(group_name, layer))
                for layer in group_layers
            ]
            for group_name, group_layers in model.attention_layers.items()
        ], [])

        self._hooker = Hooker(*layer_hooks)

        # per-group accumulated attention maps, keyed by spatial shape
        self._attention_groups: Dict[str, AttentionDict] = {
            group_name: {} for group_name in model.attention_layers}

        self._impact = 1.
        self._impact_decay = .8
        self._input_size: torch.Size = None
        self._batch_size = None
        self._device = None

    def _generate_forward_hook(self, group_name: str, layer: LAP) -> Hook:

        def forward_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
            processed = out.clone()
            shape = processed.shape[2:]

            # accumulate attention maps that share the same spatial size
            if shape not in self._attention_groups[group_name]:
                self._attention_groups[group_name][shape] = torch.zeros_like(processed)
            self._attention_groups[group_name][shape] += processed

        return forward_hook

    def _reset(self) -> None:
        self._attention_groups = {
            group_name: {} for group_name in self._model.attention_layers}
        self._impact = 1.

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        """ Process current attention level.

        Args:
            result (torch.Tensor): Attention map up to the previous attention level. (B, 1, H1, W1)
            attention (torch.Tensor): Current level's attention map. (B, 1, H2, W2)

        Returns:
            torch.Tensor: New attention map. (B, 1, H2, W2)
        """
        self._impact *= self._impact_decay

        # smallest attention layer: nothing to merge yet
        if result is None:
            return attention

        # attention is larger than result
        scale = torch.ceil(
            torch.tensor(attention.shape[2:]) / torch.tensor(result.shape[2:])
        ).int().tolist()

        # find maximum of attention in each kernel
        max_map = F.max_pool2d(attention, kernel_size=scale, stride=scale)
        max_map = F.interpolate(max_map, attention.shape[2:], mode='nearest')
        is_any_active = max_map > 0

        # interpolate result to attention size
        result = F.interpolate(result, attention.shape[2:], mode='nearest')

        # # maximum between result and attention
        # impacted = torch.max(result, max_map * self._impact)
        impacted = torch.max(result, attention * self._impact)

        # where attention is zero, use zero
        impacted[attention == 0] = 0
        # impacted = torch.where(attention > 0, impacted, 0)

        # where result is zero, use zero as well
        impacted[result == 0] = 0
        # impacted = torch.where(result > 0, impacted, 0)

        # if max_map is zero, use result, else use impacted
        result[is_any_active] = impacted[is_any_active]
        # result = torch.where(is_any_active, impacted, result)
        # mask = (max_map > 0) & (result > 0)
        # result[mask] = impacted[mask]
        # result[mask & (attention == 0)] = 0

        return result

    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # sort each group's accumulated maps by spatial shape (smallest first)
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(),
                                       key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }

        results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions}
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0

            # from smallest to largest
            for attention in group_attention.values():
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)

        interpretations = {}
        for group_name, group_result in results.items():
            group_result = F.interpolate(group_result, self._input_size,
                                         mode='bilinear', align_corners=True)
            interpretations[group_name] = group_result
            '''
            if group_result.shape[1] == 1:
                interpretations[group_name] = group_result
            else:
                interpretations.update({
                    f"{group_name}_{i}": group_result[:, i:i+1]
                    for i in range(group_result.shape[1])
                })
            '''

        return interpretations

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        self._reset()

        first_input = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]]
        self._batch_size = first_input.shape[0]
        self._input_size = first_input.shape[2:]
        self._device = first_input.device

        with self._hooker:
            self._model(**inputs)

        return self._process_attentions()
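

# Usage sketch (illustrative only, not executed here): it assumes a hypothetical
# `MyAttentionModel` that implements `AttentionInterpretableModel`, whose
# `attention_layers` property groups its LAP modules (e.g. {"main": [...]}) and
# whose `ordered_placeholder_names_to_be_interpreted` starts with "x":
#
#     model = MyAttentionModel()
#     interpreter = AttentionInterpreter(model)
#     x = torch.randn(4, 3, 224, 224)              # (B, C, H, W) input batch
#     maps = interpreter.interpret(labels=0, x=x)
#     # `maps` holds one attention map per group, upsampled to the input's
#     # spatial size, e.g. maps["main"] with spatial size 224x224.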