|
123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- """
- Guided Grad Cam Interpretation
- """
- from typing import Dict, Union
-
- import torch
- from captum import attr
-
- from . import Interpreter, CamInterpretableModel, InterpretableWrapper
-
-
class GuidedGradCam(Interpreter):
    """Produces class activation maps using captum's Guided Grad-CAM.

    One ``captum.attr.GuidedGradCam`` interpreter is built for every layer in
    ``model.target_conv_layers``; :meth:`interpret` sums their attributions.
    """

    def __init__(self, model: CamInterpretableModel):
        """
        Args:
            model (CamInterpretableModel): model to interpret. Its
                ``target_conv_layers`` select the layers Grad-CAM is computed on.
        """
        super().__init__(model)
        # NOTE(review): the wrapper presumably adapts the model's keyword-input
        # interface to captum's positional one (see convert_inputs_to_args in
        # interpret) -- confirm against InterpretableWrapper.
        self._model_wrapper = InterpretableWrapper(model)
        self._interpreters = [
            attr.GuidedGradCam(self._model_wrapper, conv)
            for conv in model.target_conv_layers
        ]

    def interpret(self, labels: Union[int, torch.Tensor],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Interprets the given inputs with respect to the target class.

        Args:
            labels (Union[int, torch.Tensor]): target class; resolved through
                ``self._get_target`` before being passed to captum as ``target``.
            **inputs (torch.Tensor): model inputs, passed by keyword.

        Returns:
            Dict[str, torch.Tensor]: ``{'default': attribution}``, where the
            attribution is the element-wise sum of the per-layer Guided
            Grad-CAM results over all target conv layers.

        Raises:
            RuntimeError: if the model exposes no target conv layers.
        """
        labels = self._get_target(labels, **inputs)
        separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
        # NOTE(review): ``[0]`` presumably keeps the attribution of the first
        # input when captum returns one attribution per input tensor (tuple
        # inputs); if ``separated_inputs.inputs`` is a single tensor it would
        # instead select the first batch element -- confirm against
        # InterpretableWrapper.convert_inputs_to_args.
        per_layer = [
            interpreter.attribute(
                separated_inputs.inputs,
                target=labels,
                additional_forward_args=separated_inputs.additional_inputs,
            )[0]
            for interpreter in self._interpreters
        ]
        if not per_layer:
            # torch.stack([]) would fail here with a less descriptive
            # RuntimeError; keep the same exception type but explain why.
            raise RuntimeError(
                "GuidedGradCam requires at least one layer in "
                "model.target_conv_layers")
        # Combine the layers by summing their attribution maps element-wise.
        gradcam = torch.stack(per_layer).sum(dim=0)

        return {
            'default': gradcam,
        }  # TODO: must return and support tuple
|