
imagenet_attention_interpreter.py

from typing import Callable, Type

import torch

from ..models.model import ModelIO
from ..utils.hooker import Hook, Hooker
from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .interpreter import Interpreter
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter


class ImagenetPredictionInterpreter(Interpreter):
    """Interpreter that restricts attention maps to the model's top-k predicted classes."""

    def __init__(self, model: AttentionInterpretableModel, k: int,
                 base_interpreter: Type[AttentionInterpreter]):
        super().__init__(model)
        self.__base_interpreter = base_interpreter(model)
        self.__k = k
        self.__topk_classes: torch.Tensor = None
        # Rebuild the base interpreter's hooker: add a prediction hook on the model itself
        # and rebind each existing attention hook to the second-to-last module of its layer.
        self.__base_interpreter._hooker = Hooker(
            (model, self._generate_prediction_hook()),
            *[(attention_layer[-2], hook)
              for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs])

    @staticmethod
    def standard_factory(k: int, base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter
                         ) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']:
        return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter)

    def _generate_prediction_hook(self) -> Hook:
        def hook(_, __, output: ModelIO):
            # Record the k most probable classes for every sample in the current batch.
            topk = output['categorical_probability'].detach() \
                .topk(self.__k, dim=-1)
            dataloader: DataLoader = DataloaderContext.instance.dataloader
            sample_names = dataloader.get_current_batch_samples_names()
            for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values):
                print(f'Top classes of '
                      f'{name}: '
                      f'{top_class.detach().cpu().numpy().tolist()} - '
                      f'{top_prob.cpu().numpy().tolist()}', flush=True)
            self.__topk_classes = topk \
                .indices \
                .flatten() \
                .cpu()
        return hook

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
        # Keep only the attention maps of the top-k predicted classes for each sample,
        # regrouped as (batch, k, height, width), before delegating to the base interpreter.
        batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k)
        attention = attention[batch_indices, self.__topk_classes]
        attention = attention.reshape(-1, self.__k, *attention.shape[-2:])
        return self.__base_interpreter._process_attention(result, attention)

    def interpret(self, labels, **inputs):
        return self.__base_interpreter.interpret(labels, **inputs)
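The class-selection step in _process_attention can be illustrated in isolation. The sketch below is not code from the repository; it assumes attention has shape (batch, num_classes, height, width) and uses a random (batch * k,) index vector as a stand-in for self.__topk_classes.

import torch

batch, num_classes, h, w, k = 2, 10, 7, 7, 3
attention = torch.rand(batch, num_classes, h, w)
topk_classes = torch.randint(num_classes, (batch * k,))  # stand-in for self.__topk_classes

# Pair each sample index with each of its k selected classes, gather those maps,
# then regroup them per sample.
batch_indices = torch.arange(batch).repeat_interleave(k)  # tensor([0, 0, 0, 1, 1, 1])
selected = attention[batch_indices, topk_classes]         # shape: (batch * k, h, w)
selected = selected.reshape(-1, k, h, w)                  # shape: (batch, k, h, w)
print(selected.shape)                                     # torch.Size([2, 3, 7, 7])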
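A hypothetical usage sketch of standard_factory, assuming a caller that expects an interpreter factory and only later has the model instance available; model, labels and inputs are placeholders, not names taken from the repository.

# Build a factory that creates the interpreter for the 5 most probable classes.
factory = ImagenetPredictionInterpreter.standard_factory(k=5)

# Once the model instance exists, the factory produces the interpreter.
interpreter = factory(model)                      # model: AttentionInterpretableModel
result = interpreter.interpret(labels, **inputs)  # delegates to the wrapped AttentionInterpreter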