
attention_interpreter_smooth_integrator.py

from typing import Dict, Optional
from collections import OrderedDict as OrdDict

import torch
import torch.nn.functional as F

from .attention_interpreter import AttentionInterpreter

AttentionDict = Dict[torch.Size, torch.Tensor]


class AttentionInterpreterSmoothIntegrator(AttentionInterpreter):
    def _process_attentions(self) -> Dict[str, torch.Tensor]:
        # Sort each group's attention maps by spatial size, smallest first.
        sorted_attentions = {
            group_name: OrdDict(sorted(group_attention.items(),
                                       key=lambda pair: pair[0]))
            for group_name, group_attention in self._attention_groups.items()
        }
        results: Dict[str, Optional[torch.Tensor]] = {
            group_name: None for group_name in sorted_attentions
        }
        for group_name, group_attention in sorted_attentions.items():
            self._impact = 1.0
            # Accumulate the raw maps from smallest to largest resolution,
            # upsampling the running sum to match each new map's size.
            sum_ = 0
            for attention in group_attention.values():
                if torch.is_tensor(sum_):
                    sum_ = F.interpolate(sum_, size=attention.shape[-2:])
                sum_ += attention
                # Threshold the map before folding it into the running result.
                attention = (attention - 0.5).relu()
                results[group_name] = self._process_attention(results[group_name], attention)
            # Keep the full sum where the thresholded result is positive;
            # elsewhere, damp it by twice the number of maps in the group.
            results[group_name] = torch.where(results[group_name] > 0,
                                              sum_,
                                              sum_ / (2 * len(group_attention)))
        interpretations = {}
        for group_name, group_result in results.items():
            # Upsample each group's result to the network input size.
            group_result = F.interpolate(group_result, self._input_size,
                                         mode='bilinear', align_corners=True)
            interpretations[group_name] = group_result
            # Per-channel splitting, currently disabled:
            # if group_result.shape[1] == 1:
            #     interpretations[group_name] = group_result
            # else:
            #     interpretations.update({
            #         f"{group_name}_{i}": group_result[:, i:i+1]
            #         for i in range(group_result.shape[1])
            #     })
        return interpretations
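
The inner accumulation loop is the core of the "smooth integration": each coarser sum is upsampled before the next, finer map is added, so the final sum lives at the group's finest resolution. Below is a minimal standalone sketch of just that loop, assuming (as the NCHW interpolation calls above suggest) that each group maps spatial sizes to 4D attention tensors; the toy dictionary and shapes here are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F

# Hypothetical toy group: two attention maps at 4x4 and 8x8, smallest first,
# keyed by spatial size as in the AttentionDict alias above.
group_attention = {
    torch.Size([4, 4]): torch.rand(1, 1, 4, 4),
    torch.Size([8, 8]): torch.rand(1, 1, 8, 8),
}

sum_ = 0
for attention in group_attention.values():
    if torch.is_tensor(sum_):
        # Upsample the running sum to the next (larger) map's resolution.
        sum_ = F.interpolate(sum_, size=attention.shape[-2:])
    sum_ += attention

print(sum_.shape)  # torch.Size([1, 1, 8, 8]) -- the finest resolution

Starting sum_ at the integer 0 lets the first iteration skip the interpolation (torch.is_tensor(0) is False) and promotes sum_ to a tensor on the first addition, avoiding a special case for the smallest map.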