1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
"""
Abstract interfaces for interpretable PyTorch models.
"""
- from abc import abstractmethod, ABC
- from typing import Iterable, List
-
- import torch
- from torch import nn
-
- from ..models.model import Model
-
-
class InterpretableModel(Model, ABC):
    """
    Abstract interface for an interpretable PyTorch model.

    Subclasses declare which named inputs are to be interpreted and expose
    the probabilities their forward pass assigns to every class.
    """

    @property
    @abstractmethod
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        Names of the input placeholders to be interpreted, in order.

        Returns:
            An iterable of placeholder names (strings), not input modules.
            NOTE(review): the order presumably matches the positional
            ``*inputs`` passed to ``get_categorical_probabilities`` -- confirm
            against the interpreters that consume this property.
        """

    @abstractmethod
    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """
        Get the probabilities assigned to all classes by the model's forward.

        Args:
            *inputs: Inputs to the model.
            **kwargs: Additional arguments.

        Returns:
            Tensor of categorical probabilities with shape (B, C), where B is
            the batch size and C is the number of classes.
        """
-
-
class CamInterpretableModel(InterpretableModel, ABC):
    """
    Interface for a model that can be interpreted with Grad-CAM.
    """

    @property
    @abstractmethod
    def target_conv_layers(self) -> List[nn.Module]:
        """
        Convolutional layers targeted by the Grad-CAM interpretation.

        Returns:
            The convolutional layers to interpret. The final interpretation
            is the sum of the Grad-CAM results computed for these layers.
        """
        ...
|