from typing import Callable, Optional, Type

import torch

from ..models.model import ModelIO
from ..utils.hooker import Hook, Hooker
from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .interpreter import Interpreter
from .attention_interpreter import AttentionInterpretableModel, AttentionInterpreter


class ImagenetPredictionInterpreter(Interpreter):
    """Interpret the attention maps of a model's top-k predicted classes.

    Wraps an :class:`AttentionInterpreter` (or subclass).  A prediction
    hook is installed on the model itself to capture the top-k predicted
    class indices of each forward pass; ``_process_attention`` then keeps
    only the attention slices belonging to those classes before
    delegating to the wrapped interpreter.
    """

    def __init__(self,
                 model: AttentionInterpretableModel,
                 k: int,
                 base_interpreter: Type[AttentionInterpreter]):
        """
        :param model: model whose attention layers are interpreted.
        :param k: number of top predicted classes to keep per sample.
        :param base_interpreter: interpreter class to wrap; it is
            instantiated here with ``model``.
        """
        super().__init__(model)
        self.__base_interpreter = base_interpreter(model)
        self.__k = k
        # Written by the prediction hook on every forward pass; None until
        # the first batch has been processed.
        self.__topk_classes: Optional[torch.Tensor] = None
        # Rebuild the wrapped interpreter's hooker so that, in addition to
        # its existing attention-layer hooks, the model as a whole reports
        # its top-k predictions.  attention_layer[-2] re-targets each pair
        # at the same sub-layer the base interpreter originally hooked.
        self.__base_interpreter._hooker = Hooker(
            (model, self._generate_prediction_hook()),
            *[(attention_layer[-2], hook)
              for attention_layer, hook
              in self.__base_interpreter._hooker._layer_hook_pairs])

    @staticmethod
    def standard_factory(k: int,
                         base_interpreter: Type[AttentionInterpreter] = AttentionInterpreter
                         ) -> Callable[[AttentionInterpretableModel],
                                       'ImagenetPredictionInterpreter']:
        """Return a one-argument factory with ``k`` and
        ``base_interpreter`` already bound, suitable for call sites that
        only supply the model."""
        return lambda model: ImagenetPredictionInterpreter(model, k, base_interpreter)

    def _generate_prediction_hook(self) -> Hook:
        """Create a forward hook that records (and prints) the top-k
        predicted classes of the current batch."""
        def hook(_, __, output: ModelIO):
            topk = output['categorical_probability'].detach()\
                .topk(self.__k, dim=-1)
            dataloader: DataLoader = DataloaderContext.instance.dataloader
            sample_names = dataloader.get_current_batch_samples_names()
            for name, top_class, top_prob in zip(sample_names,
                                                 topk.indices,
                                                 topk.values):
                print(f'Top classes of '
                      f'{name}: '
                      f'{top_class.detach().cpu().numpy().tolist()} - '
                      f'{top_prob.cpu().numpy().tolist()}',
                      flush=True)
            # Flattened (batch * k) class indices, consumed later by
            # _process_attention to select the matching attention maps.
            self.__topk_classes = topk\
                .indices\
                .flatten()\
                .cpu()
        return hook

    def _process_attention(self, result: torch.Tensor,
                           attention: torch.Tensor) -> torch.Tensor:
        """Select the attention maps of the top-k predicted classes and
        delegate the rest of the processing to the wrapped interpreter.

        :param result: intermediate result passed through unchanged to the
            base interpreter.
        :param attention: per-class attention; assumed shape
            (batch, classes, H, W) — TODO confirm against the base
            interpreter's contract.
        """
        # Repeat each batch index k times so it pairs element-wise with
        # the k class indices captured for that sample by the hook.
        batch_indices = torch.arange(attention.shape[0]).repeat_interleave(self.__k)
        attention = attention[batch_indices, self.__topk_classes]
        # Regroup the flat (batch * k) selection into (batch, k, H, W).
        attention = attention.reshape(-1, self.__k, *attention.shape[-2:])
        return self.__base_interpreter._process_attention(result, attention)

    def interpret(self, labels, **inputs):
        """Delegate interpretation to the wrapped base interpreter (whose
        hooker was replaced in ``__init__``)."""
        return self.__base_interpreter.interpret(labels, **inputs)