""" Guided Grad Cam Interpretation """ from typing import Dict, Union import torch import torch.nn.functional as F from captum import attr from . import Interpreter, InterpretableModel, InterpretableWrapper class DeepLift(Interpreter): # pylint: disable=too-few-public-methods """ Produces class activation map """ def __init__(self, model: InterpretableModel): super().__init__(model) self._model_wrapper = InterpretableWrapper(model) self._interpreter = attr.DeepLift(self._model_wrapper) self.__B = None def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]: ''' interprets given input and target class ''' B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0] if self.__B is None: self.__B = B if B < self.__B: padding = self.__B - B inputs = { k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding)) for k, v in inputs.items() if torch.is_tensor(v) and v.shape[0] == B } if torch.is_tensor(labels) and labels.shape[0] == B: labels = F.pad(labels, (0, padding)) labels = self._get_target(labels, **inputs) separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs) result = self._interpreter.attribute( separated_inputs.inputs, target=labels, additional_forward_args=separated_inputs.additional_inputs)[0] result = result[:B] return { 'default': result } # TODO: must return and support tuple