""" Guided Backprop Interpretation """

from typing import Dict, Union

import torch
from captum import attr

from . import Interpreter, InterpretableModel, InterpretableWrapper


class GuidedBackprop(Interpreter):  # pylint: disable=too-few-public-methods
    """
    Produces Guided Backpropagation attributions for a model's inputs.

    Wraps captum's ``attr.GuidedBackprop`` around an ``InterpretableWrapper``
    built from the given model.
    """

    def __init__(self, model: InterpretableModel):
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreter = attr.GuidedBackprop(self._model_wrapper)

    def interpret(self, labels: Union[int, torch.Tensor],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Interprets the given inputs with respect to the target class(es).

        Args:
            labels: target class index or tensor of targets; normalized via
                ``self._get_target`` (presumably provided by ``Interpreter``
                -- confirm against the base class).
            **inputs: named input tensors. The wrapper splits them into the
                tensors to attribute and extra forward arguments.

        Returns:
            ``{'default': attribution}``, the attribution computed for the
            first attributed input.
        """
        labels = self._get_target(labels, **inputs)
        separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)

        attributions = self._interpreter.attribute(
            separated_inputs.inputs,
            target=labels,
            additional_forward_args=separated_inputs.additional_inputs)

        # captum returns a tuple of attributions when given a tuple of inputs,
        # but a bare tensor when given a single tensor. Only unpack the tuple
        # case, so a bare tensor is not sliced along its batch dimension.
        if isinstance(attributions, tuple):
            attributions = attributions[0]

        # TODO: must return and support tuple (only the first attribution is
        # exposed for now)
        return {'default': attributions}