1234567891011121314151617181920212223242526272829303132333435363738 |
- from typing import Dict, Union
-
- import torch
- import torch.nn.functional as F
- from captum import attr
-
- from . import Interpreter, CamInterpretableModel, InterpretableWrapper
-
-
class GradCam(Interpreter):
    """Grad-CAM interpreter that aggregates captum ``LayerGradCam`` maps.

    One ``captum.attr.LayerGradCam`` attributor is built per layer listed in
    ``model.target_conv_layers``. When interpreting, every layer's map is
    upsampled to the spatial size of the first interpreted input, and the
    per-layer maps are summed into a single map.
    """

    # Interpolation mode for each supported number of spatial dimensions.
    _INTERPOLATION_MODES: Dict[int, str] = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}

    def __init__(self, model: CamInterpretableModel):
        """Wrap ``model`` and create one Grad-CAM attributor per target conv layer.

        Args:
            model: model exposing ``target_conv_layers``. It is wrapped in an
                ``InterpretableWrapper``, and that wrapper is the forward
                function handed to every ``LayerGradCam``.
        """
        super().__init__(model)
        self.__model_wrapper = InterpretableWrapper(model)
        self.__interpreters = [
            attr.LayerGradCam(self.__model_wrapper, conv)
            for conv in model.target_conv_layers
        ]

    def interpret(self, labels: Union[int, torch.Tensor],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the summed Grad-CAM map for ``labels``.

        Args:
            labels: target index or tensor of targets. It is resolved through
                ``self._get_target`` before being passed to captum.
            **inputs: model inputs keyed by placeholder name.

        Returns:
            ``{'default': cam}``. ``cam`` is the sum, over all target layers,
            of each layer's Grad-CAM map upsampled to ``shape[2:]`` of the
            first input named in ``ordered_placeholder_names_to_be_interpreted``.

        Raises:
            RuntimeError: if the model declared no target conv layers.
            NotImplementedError: if that input does not have 1, 2 or 3
                dimensions after the first two.
        """
        if not self.__interpreters:
            raise RuntimeError(
                'GradCam needs at least one target conv layer '
                '(model.target_conv_layers is empty)')

        reference = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]]
        # Assumes a (batch, channels, *spatial) layout: everything after dim 1
        # is treated as spatial. TODO confirm for all supported models.
        inp_shape = reference.shape[2:]
        # Resolve the mode before running any (expensive) attribution.
        mode = self._interpolation_mode(len(inp_shape))

        labels = self._get_target(labels, **inputs)
        separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs)

        upsampled = []
        for interpreter in self.__interpreters:
            cam = interpreter.attribute(
                separated_inputs.inputs,
                target=labels,
                additional_forward_args=separated_inputs.additional_inputs,
            )
            # The attributor may return a tuple instead of a tensor; only its
            # first element is used.
            if not torch.is_tensor(cam):
                cam = cam[0]
            # Upsample per layer *before* summing. Target layers can have
            # different spatial resolutions, and torch.stack requires equal
            # shapes.
            upsampled.append(self._upsample(cam, inp_shape, mode))

        gradcams = torch.stack(upsampled).sum(dim=0)

        return {
            'default': gradcams,
        }  # TODO: must return and support tuple

    @classmethod
    def _interpolation_mode(cls, n_spatial: int) -> str:
        """Return the ``F.interpolate`` mode for ``n_spatial`` spatial dims."""
        try:
            return cls._INTERPOLATION_MODES[n_spatial]
        except KeyError:
            raise NotImplementedError(
                f'GradCam supports inputs with 1, 2 or 3 spatial dimensions, '
                f'got {n_spatial}') from None

    @staticmethod
    def _upsample(cam: torch.Tensor, size: torch.Size, mode: str) -> torch.Tensor:
        """Resize one layer's CAM to ``size`` with corner-aligned interpolation."""
        return F.interpolate(cam, size=size, mode=mode, align_corners=True)
|