You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

guided_gradcam.py 1.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """
  2. Guided Grad Cam Interpretation
  3. """
  4. from typing import Dict, Union
  5. import torch
  6. from captum import attr
  7. from . import Interpreter, CamInterpretableModel, InterpretableWrapper
  8. class GuidedGradCam(Interpreter):
  9. """ Produces class activation map """
  10. def __init__(self, model: CamInterpretableModel):
  11. super().__init__(model)
  12. self._model_wrapper = InterpretableWrapper(model)
  13. self._interpreters = [
  14. attr.GuidedGradCam(self._model_wrapper, conv)
  15. for conv in model.target_conv_layers
  16. ]
  17. def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
  18. """ Interprets given input and target class
  19. Args:
  20. y (Union[int, torch.Tensor]): target class
  21. **inputs (torch.Tensor): model inputs
  22. Returns:
  23. Dict[str, torch.Tensor]: Interpretation results
  24. """
  25. labels = self._get_target(labels, **inputs)
  26. separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
  27. gradcams = [interpreter.attribute(
  28. separated_inputs.inputs,
  29. target=labels,
  30. additional_forward_args=separated_inputs.additional_inputs)[0]
  31. for interpreter in self._interpreters]
  32. gradcams = torch.stack(gradcams).sum(dim=0)
  33. return {
  34. 'default': gradcams,
  35. } # TODO: must return and support tuple