You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

deep_lift.py 1.7KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """
  2. DeepLift Interpretation
  3. """
  4. from typing import Dict, Union
  5. import torch
  6. import torch.nn.functional as F
  7. from captum import attr
  8. from . import Interpreter, InterpretableModel, InterpretableWrapper
  9. class DeepLift(Interpreter): # pylint: disable=too-few-public-methods
  10. """
  11. Produces class activation map
  12. """
  13. def __init__(self, model: InterpretableModel):
  14. super().__init__(model)
  15. self._model_wrapper = InterpretableWrapper(model)
  16. self._interpreter = attr.DeepLift(self._model_wrapper)
  17. self.__B = None
  18. def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
  19. '''
  20. interprets given input and target class
  21. '''
  22. B = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape[0]
  23. if self.__B is None:
  24. self.__B = B
  25. if B < self.__B:
  26. padding = self.__B - B
  27. inputs = {
  28. k: F.pad(v, (*([0] * (v.ndim * 2 - 1)), padding))
  29. for k, v in inputs.items()
  30. if torch.is_tensor(v) and v.shape[0] == B
  31. }
  32. if torch.is_tensor(labels) and labels.shape[0] == B:
  33. labels = F.pad(labels, (0, padding))
  34. labels = self._get_target(labels, **inputs)
  35. separated_inputs = self._model_wrapper.convert_inputs_to_args(**inputs)
  36. result = self._interpreter.attribute(
  37. separated_inputs.inputs,
  38. target=labels,
  39. additional_forward_args=separated_inputs.additional_inputs)[0]
  40. result = result[:B]
  41. return {
  42. 'default': result
  43. } # TODO: must return and support tuple