You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpretable.py 1.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. """
  2. An Interpretable Pytorch Model
  3. """
  4. from abc import abstractmethod, ABC
  5. from typing import Iterable, List
  6. import torch
  7. from torch import nn
  8. from ..models.model import Model
  9. class InterpretableModel(Model, ABC):
  10. """
  11. An Interpretable Pytorch Model
  12. """
  13. @property
  14. @abstractmethod
  15. def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
  16. """
  17. Returns:
  18. Input module for interpretation
  19. """
  20. @abstractmethod
  21. def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
  22. """
  23. A method to get probabilities assigned to all the classes in the model's forward,
  24. with shape (B, C), in which B is the batch size and C is number of classes
  25. Args:
  26. *inputs: Inputs to the model
  27. **kwargs: Additional arguments
  28. Returns:
  29. Tensor of categorical probabilities
  30. """
  31. class CamInterpretableModel(InterpretableModel, ABC):
  32. """
  33. A model interpretable by gradcam
  34. """
  35. @property
  36. @abstractmethod
  37. def target_conv_layers(self) -> List[nn.Module]:
  38. """
  39. Returns:
  40. The convolutional layers to be interpreted. The result of
  41. the interpretation will be sum of the grad-cam of these layers
  42. """