from typing import Dict
from collections import OrderedDict as OrdDict

import torch
import torch.nn.functional as F

from .attention_interpreter import AttentionInterpreter

AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpreterSmoothIntegrator(AttentionInterpreter):
    """Attention interpreter that accumulates attention maps across resolutions
    and smooths the integrated result with the running sum."""

    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # Sort each group's attention maps by their spatial size (smallest first).
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(), key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }
        results: Dict[str, torch.Tensor] = {group_name: None for group_name in sorted_attentions}
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0
            # Accumulate the maps from smallest to largest resolution.
            sum_ = 0
            for attention in group_attention.values():
                if torch.is_tensor(sum_):
                    # Upsample the running sum to the current map's resolution.
                    sum_ = F.interpolate(sum_, size=attention.shape[-2:])
                sum_ += attention
                # Keep only responses above 0.5 before integrating them.
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)
            # Where the integrated result is positive, keep the raw sum;
            # elsewhere, attenuate it by twice the number of maps in the group.
            results[group_name] = torch.where(
                results[group_name] > 0, sum_, sum_ / (2 * len(group_attention))
            )

        interpretations = {}
        for group_name, group_result in results.items():
            # Upscale each group's result to the input resolution.
            group_result = F.interpolate(
                group_result, self._input_size, mode='bilinear', align_corners=True
            )
            interpretations[group_name] = group_result
            # Per-channel splitting kept for reference (disabled):
            # if group_result.shape[1] == 1:
            #     interpretations[group_name] = group_result
            # else:
            #     interpretations.update({
            #         f"{group_name}_{i}": group_result[:, i:i+1]
            #         for i in range(group_result.shape[1])
            #     })
        return interpretations