You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

gradcam.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Dict, Union
  2. import torch
  3. import torch.nn.functional as F
  4. from captum import attr
  5. from . import Interpreter, CamInterpretableModel, InterpretableWrapper
  class GradCam(Interpreter):
      """Grad-CAM interpreter aggregated over a model's target conv layers.

      One captum ``LayerGradCam`` is created per layer in
      ``model.target_conv_layers``. In :meth:`interpret`, each layer's
      attribution map is resized to the spatial size of the first interpreted
      input, and the resized maps are summed into a single map.
      """

      def __init__(self, model: CamInterpretableModel):
          """Wrap *model* and build one ``LayerGradCam`` per target conv layer.

          Args:
              model: model exposing ``target_conv_layers`` (the layers to
                  attribute) and ``ordered_placeholder_names_to_be_interpreted``
                  (used by :meth:`interpret`).
          """
          super().__init__(model)
          self.__model_wrapper = InterpretableWrapper(model)
          # All per-layer interpreters share the same wrapped forward function.
          self.__interpreters = [
              attr.LayerGradCam(self.__model_wrapper, conv)
              for conv in model.target_conv_layers
          ]

      def interpret(self, labels: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
          """Compute a summed Grad-CAM map over all target conv layers.

          Args:
              labels: target class(es). They are passed through
                  ``self._get_target`` before being handed to captum.
              **inputs: named model inputs. The input named first in
                  ``ordered_placeholder_names_to_be_interpreted`` fixes the
                  output's spatial size.

          Returns:
              ``{'default': cam}``, where ``cam`` is the sum of the per-layer
              Grad-CAM maps, each bilinearly resized to the input's spatial size.

          Raises:
              RuntimeError: from ``torch.stack`` if the model has no target
                  conv layers.
          """
          first_name = self._model.ordered_placeholder_names_to_be_interpreted[0]
          # Everything after the first two dims is treated as spatial size;
          # presumably inputs are (N, C, H, W). TODO confirm.
          inp_shape = inputs[first_name].shape[2:]
          labels = self._get_target(labels, **inputs)
          separated_inputs = self.__model_wrapper.convert_inputs_to_args(**inputs)

          resized_cams = []
          for interpreter in self.__interpreters:
              cam = interpreter.attribute(
                  separated_inputs.inputs,
                  target=labels,
                  additional_forward_args=separated_inputs.additional_inputs,
              )
              # captum may return a tuple of attributions; only the first is used.
              if not torch.is_tensor(cam):
                  cam = cam[0]
              # Resize each layer's map *before* combining, so layers with
              # different spatial resolutions can be stacked. Bilinear resizing
              # is linear, so for equal-sized maps this matches resizing the sum.
              resized_cams.append(
                  F.interpolate(cam, size=inp_shape, mode='bilinear', align_corners=True)
              )

          gradcams = torch.stack(resized_cams).sum(dim=0)
          return {
              'default': gradcams,
          }  # TODO: must return and support tuple