You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

guided_backprop.py 1.1KB

1234567891011121314151617181920212223242526272829303132333435
  1. """
  2. Guided Backprop Interpretation
  3. """
  4. from typing import Dict, Union
  5. import torch
  6. from captum import attr
  7. from . import Interpreter, InterpretableModel, InterpretableWrapper
  8. class GuidedBackprop(Interpreter): # pylint: disable=too-few-public-methods
  9. """
  10. Produces class activation map
  11. """
  12. def __init__(self, model: InterpretableModel):
  13. super().__init__(model)
  14. self._model_wrapper = InterpretableWrapper(model)
  15. self._interpreter = attr.GuidedBackprop(self._model_wrapper)
  16. def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
  17. '''
  18. interprets given input and target class
  19. '''
  20. labels = self._get_target(labels, **inputs)
  21. separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
  22. return {
  23. 'default': self._interpreter.attribute(
  24. separated_inputs.inputs,
  25. target=labels,
  26. additional_forward_args=separated_inputs.additional_inputs)[0]
  27. } # TODO: must return and support tuple