1234567891011121314151617181920212223242526272829303132333435 |
- """
- Guided Backprop Interpretation
- """
- from typing import Dict, Union
-
- import torch
- from captum import attr
-
- from . import Interpreter, InterpretableModel, InterpretableWrapper
-
-
class GuidedBackprop(Interpreter):  # pylint: disable=too-few-public-methods
    """
    Guided Backpropagation interpreter.

    Wraps captum's ``attr.GuidedBackprop`` around an ``InterpretableWrapper``
    of the given model and exposes it through the ``Interpreter`` interface.
    The result is a gradient-based attribution over the model inputs
    (not a class activation map).
    """

    def __init__(self, model: InterpretableModel):
        """
        Args:
            model: the model to interpret. It is wrapped in an
                ``InterpretableWrapper``, and the captum GuidedBackprop
                instance is built on that wrapper rather than on the raw model.
        """
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreter = attr.GuidedBackprop(self._model_wrapper)

    def interpret(self, labels: Union[int, torch.Tensor],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute Guided Backprop attributions for the given inputs and target.

        Args:
            labels: target class(es). Resolved through ``self._get_target``
                (together with the inputs) before being passed to captum as
                ``target``.
            **inputs: named model inputs. The wrapper splits them into the
                tensors to attribute and additional forward arguments.

        Returns:
            ``{'default': attributions}``, where ``attributions`` is the
            attribution tensor of the first attributed input.
        """
        target = self._get_target(labels, **inputs)
        separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
        attributions = self._interpreter.attribute(
            separated_inputs.inputs,
            target=target,
            additional_forward_args=separated_inputs.additional_inputs)
        # captum returns a tuple only when ``inputs`` is a tuple and a bare
        # tensor otherwise. Unconditionally indexing with [0] would, for a
        # bare tensor, keep just the first batch sample instead of the first
        # input's attributions.
        if isinstance(attributions, tuple):
            attributions = attributions[0]
        # TODO: must return and support tuple (currently only the first
        # input's attributions are returned).
        return {'default': attributions}
|