|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
"""
DeepLift Interpretation
"""
- from typing import Dict, Union
-
- import torch
- import torch.nn.functional as F
- from captum import attr
-
- from . import Interpreter, InterpretableModel, InterpretableWrapper
-
-
class DeepLift(Interpreter):  # pylint: disable=too-few-public-methods
    """
    DeepLift attribution interpreter.

    Wraps the model in an ``InterpretableWrapper`` and delegates attribution to
    ``captum.attr.DeepLift``. Remembers the batch size of the first interpreted
    batch and zero-pads smaller subsequent batches up to it — presumably the
    underlying attribution requires a fixed batch size; TODO confirm — then
    trims the result back to the real batch size.
    """

    def __init__(self, model: InterpretableModel):
        super().__init__(model)
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreter = attr.DeepLift(self._model_wrapper)
        # Batch size seen on the first call to `interpret`; later smaller
        # batches are padded up to this size before attribution.
        self.__B = None

    def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Interprets the given inputs with respect to the target class(es).

        Args:
            labels: a single target class index or a per-sample tensor of
                target class indices.
            **inputs: named input tensors; the first ordered placeholder name
                of the model determines the batch size.

        Returns:
            A dict with key ``'default'`` mapping to the attribution tensor,
            trimmed back to the incoming batch size.
        """
        B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
        if self.__B is None:
            self.__B = B

        if B < self.__B:
            padding = self.__B - B
            # Zero-pad the right side of the batch (first) dimension up to the
            # remembered batch size: F.pad consumes the pad tuple from the last
            # dim backwards, so a length-2*ndim tuple ending in `padding` pads
            # dim 0's end.
            # BUGFIX: the original comprehension *filtered* on the condition,
            # silently dropping non-tensor inputs and tensors whose leading dim
            # differed from B; such entries are now passed through unchanged.
            inputs = {
                k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding))
                if torch.is_tensor(v) and v.shape[0] == B
                else v
                for k, v in inputs.items()
            }
            if torch.is_tensor(labels) and labels.shape[0] == B:
                labels = F.pad(labels, (0, padding))

        labels = self._get_target(labels, **inputs)
        separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
        result = self._interpreter.attribute(
            separated_inputs.inputs,
            target=labels,
            additional_forward_args=separated_inputs.additional_inputs)[0]
        # Drop the padded rows so callers only see attributions for real samples.
        result = result[:B]

        return {
            'default': result
        }  # TODO: must return and support tuple
|