""" Guided Grad Cam Interpretation """ from typing import Dict, Union import torch from captum import attr from . import Interpreter, CamInterpretableModel, InterpretableWrapper class GuidedGradCam(Interpreter): """ Produces class activation map """ def __init__(self, model: CamInterpretableModel): super().__init__(model) self._model_wrapper = InterpretableWrapper(model) self._interpreters = [ attr.GuidedGradCam(self._model_wrapper, conv) for conv in model.target_conv_layers ] def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: """ Interprets given input and target class Args: y (Union[int, torch.Tensor]): target class **inputs (torch.Tensor): model inputs Returns: Dict[str, torch.Tensor]: Interpretation results """ labels = self._get_target(labels, **inputs) separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) gradcams = [interpreter.attribute( separated_inputs.inputs, target=labels, additional_forward_args=separated_inputs.additional_inputs)[0] for interpreter in self._interpreters] gradcams = torch.stack(gradcams).sum(dim=0) return { 'default': gradcams, } # TODO: must return and support tuple