from typing import Dict, Optional
from collections import OrderedDict as OrdDict

import torch
import torch.nn.functional as F

from .attention_interpreter import AttentionInterpreter


AttentionDict = Dict[torch.Size, torch.Tensor]
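# Keys are spatial sizes (torch.Size), so sorting a group's entries by key
# orders the captured attention maps from smallest to largest resolution.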


class AttentionInterpreterSmoothIntegrator(AttentionInterpreter):
    """Integrates a group's attention maps across resolutions.

    For each group, maps are accumulated from the smallest to the largest
    spatial size, thresholded at 0.5, folded through ``_process_attention``,
    and the running sum is kept where the folded result is positive
    (and damped elsewhere).
    """

    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # Sort each group's attention maps by spatial size, ascending.
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(),
                                       key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }

        results: Dict[str, Optional[torch.Tensor]] = {
            group_name: None for group_name in sorted_attentions
        }
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0
            # Accumulate maps from the smallest to the largest resolution,
            # upsampling the running sum to each new map's spatial size.
            sum_ = 0
            for attention in group_attention.values():
                if torch.is_tensor(sum_):
                    sum_ = F.interpolate(sum_, size=attention.shape[-2:])
                sum_ += attention
                # Keep only the confident part of the map before folding it in.
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)
            # Where the folded result is positive, keep the accumulated sum;
            # elsewhere damp it by twice the number of maps in the group.
            results[group_name] = torch.where(results[group_name] > 0,
                                              sum_,
                                              sum_ / (2 * len(group_attention)))

        interpretations = {}
        for group_name, group_result in results.items():
            # Upsample each group's result to the model input resolution.
            group_result = F.interpolate(group_result, self._input_size,
                                         mode='bilinear', align_corners=True)
            interpretations[group_name] = group_result
            # Disabled in the source: split multi-channel results into
            # per-channel interpretations.
            # if group_result.shape[1] == 1:
            #     interpretations[group_name] = group_result
            # else:
            #     interpretations.update({
            #         f"{group_name}_{i}": group_result[:, i:i+1]
            #         for i in range(group_result.shape[1])
            #     })
        return interpretations
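

# A minimal, self-contained sketch of the accumulate-and-gate step above,
# run on dummy tensors. The shapes, the group size of 2, and the simple
# threshold-and-add stand-in for ``self._process_attention`` are assumptions
# for illustration only, not part of the interpreter's API.
if __name__ == "__main__":
    maps = [torch.rand(1, 1, 7, 7), torch.rand(1, 1, 14, 14)]  # smallest -> largest
    sum_, folded = 0, None
    for attention in maps:
        if torch.is_tensor(sum_):
            sum_ = F.interpolate(sum_, size=attention.shape[-2:])
        sum_ += attention
        thresholded = (attention - 0.5).relu()
        # Stand-in fold: the class delegates this to ``self._process_attention``.
        folded = (thresholded if folded is None
                  else F.interpolate(folded, size=thresholded.shape[-2:]) + thresholded)
    gated = torch.where(folded > 0, sum_, sum_ / (2 * len(maps)))
    print(gated.shape)  # torch.Size([1, 1, 14, 14])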