You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpreter.py 1.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from typing import Dict, Union
  2. import warnings
  3. import numpy as np
  4. import torch.nn.functional as F
  5. import torch
  6. from ..interpreter import Interpreter
  7. from ..interpretable import CamInterpretableModel
  8. from .relprop import RPProvider
  class RelCamInterpreter(Interpreter):
      """Interpreter that turns relevances captured by ``RPProvider`` into a map.

      The relevances captured at the model's ``target_conv_layers`` are summed
      (when there is more than one), min-max normalized per sample and resized
      with bilinear interpolation to the trailing dimensions (index 2 onward)
      of the first interpreted input.
      """

      def __init__(self, model: CamInterpretableModel):
          """Store the target layers and call ``RPProvider.create`` on sub-modules.

          Args:
              model: The model to interpret. Its ``target_conv_layers`` are the
                  layers whose relevances are captured in :meth:`interpret`.

          Every module yielded by ``model.named_modules()`` is checked with
          ``RPProvider.propable``; modules that fail the check only trigger a
          warning and are otherwise left untouched.
          """
          super().__init__(model)
          self.__targets = model.target_conv_layers

          for module_name, submodule in model.named_modules():
              if RPProvider.propable(submodule):
                  RPProvider.create(submodule)
              else:
                  warnings.warn(f"Module {module_name} of type {type(submodule)} is not propable! Hope you know what you are doing!")

      @staticmethod
      def __normalize(x: torch.Tensor) -> torch.Tensor:
          """Min-max normalize ``x`` independently for each entry of dim 0.

          The per-entry minimum and maximum are broadcast as ``(B, 1, 1, 1)``,
          so ``x`` is expected to be 4-dimensional. A ``1e-6`` term in the
          denominator avoids division by zero for constant inputs.
          """
          flat = x.flatten(1)
          # keepdim gives (B, 1); the two extra axes make it (B, 1, 1, 1).
          lowest = flat.min(dim=1, keepdim=True).values[:, :, None, None]
          highest = flat.max(dim=1, keepdim=True).values[:, :, None, None]
          span = 1e-6 + highest - lowest
          return (x - lowest) / span

      def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
          """Compute the normalized, input-sized relevance map for ``labels``.

          Args:
              labels: Target label(s), forwarded to ``_get_one_hot_output``
                  together with the model's categorical probabilities.
              **inputs: Model inputs keyed by placeholder name; they are passed
                  unchanged to ``get_categorical_probabilities``.

          Returns:
              A dict with the single key ``'default'`` holding the map.
          """
          first_input_name = self._model.ordered_placeholder_names_to_be_interpreted[0]
          # Everything after the first two dims is used as the resize target.
          target_size = inputs[first_input_name].shape[2:]

          # NOTE(review): the source this was recovered from had lost its
          # indentation; the extent of this ``with`` body is reconstructed.
          # Confirm that the captured values are still valid after the exit.
          with RPProvider.capture(self.__targets) as captured:
              probabilities = self._model.get_categorical_probabilities(**inputs)
              one_hot_target = self._get_one_hot_output(probabilities, labels)
              propagate = RPProvider.get(self._model)
              propagate(one_hot_target)

          relevance_maps = list(captured.values())
          if len(relevance_maps) > 1:
              # Several target layers: add their relevances element-wise.
              relcam = torch.stack(relevance_maps).sum(dim=0)
          else:
              relcam = relevance_maps[0]

          relcam = F.interpolate(
              self.__normalize(relcam),
              size=target_size,
              mode='bilinear',
              align_corners=False,
          )
          return {'default': relcam}