from typing import Dict, Union

import torch
import torch.nn.functional as F
from captum import attr

from . import Interpreter, CamInterpretableModel, InterpretableWrapper


class GradCam(Interpreter):
    """Grad-CAM interpreter that aggregates maps over several layers.

    One captum ``LayerGradCam`` is built for every layer listed in
    ``model.target_conv_layers``; ``interpret`` sums their maps after
    resizing each one to the spatial size of the interpreted input.
    """

    def __init__(self, model: CamInterpretableModel):
        """Wrap ``model`` and create one ``LayerGradCam`` per target layer.

        Args:
            model: Model exposing ``target_conv_layers``, the layers for
                which Grad-CAM maps are computed.
        """
        super().__init__(model)
        self.__model_wrapper = InterpretableWrapper(model)
        self.__interpreters = [
            attr.LayerGradCam(self.__model_wrapper, conv)
            for conv in model.target_conv_layers
        ]

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the summed Grad-CAM map for the given inputs.

        Args:
            labels: Target specification; it is passed through
                ``self._get_target`` before being handed to captum.
            **inputs: Model inputs keyed by placeholder name. The entry
                named by the first element of
                ``ordered_placeholder_names_to_be_interpreted`` defines the
                output spatial size (its shape from dimension 2 onward).
                NOTE(review): assumes a (batch, channel, H, W) layout, as
                'bilinear' interpolation expects 4-D tensors - confirm.

        Returns:
            A dict with the single key ``'default'`` holding the sum of all
            per-layer Grad-CAM maps, each resized to the input spatial size.
        """
        reference_name = self._model.ordered_placeholder_names_to_be_interpreted[0]
        inp_shape = inputs[reference_name].shape[2:]
        labels = self._get_target(labels, **inputs)
        separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs)

        resized_cams = []
        for interpreter in self.__interpreters:
            cam = interpreter.attribute(
                separated_inputs.inputs,
                target=labels,
                additional_forward_args=separated_inputs.additional_inputs)
            # attribute() may hand back a tuple; only its first entry is used.
            if not torch.is_tensor(cam):
                cam = cam[0]
            # Resize before aggregating: target layers may have different
            # spatial resolutions, and torch.stack needs identical shapes.
            # Bilinear interpolation is linear, so for equally sized maps
            # this equals resizing the sum (up to float rounding).
            resized_cams.append(
                F.interpolate(cam, size=inp_shape, mode='bilinear', align_corners=True))

        gradcams = torch.stack(resized_cams).sum(dim=0)

        return {
            'default': gradcams,
        }  # TODO: must return and support tuple