from typing import Callable, Optional, Type

import torch

from ..models.model import ModelIO
from ..utils.hooker import Hook, Hooker
from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .interpreter import Interpreter
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter


class ImagenetPredictionInterpreter(Interpreter):
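    """Interprets a model's attention restricted to its top-k predicted
    classes: a hook records the top-k ImageNet predictions of each batch,
    and attention processing keeps only those classes' maps before
    delegating to the wrapped base interpreter."""
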
    def __init__(self, model: AttentionInterpretableModel, k: int,
                 base_interpreter: Type[AttentionInterpreter]):
        super().__init__(model)
        self.__base_interpreter = base_interpreter(model)
        self.__k = k
        self.__topk_classes: Optional[torch.Tensor] = None
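        # Rebuild the base interpreter's hooker: a hook on the model itself
        # records the top-k predictions, and each existing attention hook is
        # re-attached to the second-to-last module of its attention layer.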
        self.__base_interpreter._hooker = Hooker(
            (model, self._generate_prediction_hook()),
            *[(attention_layer[-2], hook)
              for attention_layer, hook in self.__base_interpreter._hooker._layer_hook_pairs])

    @staticmethod
    def standard_factory(
            k: int,
            base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter,
    ) -> Callable[[AttentionInterpretableModel], 'ImagenetPredictionInterpreter']:
        return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter)

    def _generate_prediction_hook(self) -> Hook:
        def hook(_, __, output: ModelIO):
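            # Select the k most probable classes for each sample in the batch.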
            topk = output['categorical_probability'].detach().topk(self.__k, dim=-1)

            dataloader: DataLoader = DataloaderContext.instance.dataloader
            sample_names = dataloader.get_current_batch_samples_names()

            for name, top_class, top_prob in zip(sample_names, topk.indices, topk.values):
                print(f'Top classes of {name}: '
                      f'{top_class.cpu().numpy().tolist()} - '
                      f'{top_prob.cpu().numpy().tolist()}', flush=True)
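
            # Cache the flattened (batch * k) class indices for
            # _process_attention to pick out on the same batch.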
            self.__topk_classes = topk.indices.flatten().cpu()
        return hook

    def _process_attention(self, result: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
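        # Keep only each sample's top-k class attention maps:
        # (batch, classes, H, W) -> (batch * k, H, W) -> (batch, k, H, W).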
        batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k)
        attention = attention[batch_indices, self.__topk_classes]
        attention = attention.reshape(-1, self.__k, *attention.shape[-2:])
        return self.__base_interpreter._process_attention(result, attention)

    def interpret(self, labels, **inputs):
        return self.__base_interpreter.interpret(labels, **inputs)
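

# Minimal usage sketch (an assumption, not part of the original module):
# `MyAttentionModel` is a hypothetical AttentionInterpretableModel, and the
# keyword arguments of interpret() depend on the wrapped model's inputs.
#
#     factory = ImagenetPredictionInterpreter.standard_factory(k=5)
#     interpreter = factory(MyAttentionModel())
#     heatmaps = interpreter.interpret(labels, image=images)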
|