"""
An Interpretable Pytorch Model
"""
from abc import abstractmethod, ABC
from typing import Iterable, List

import torch
from torch import nn

from ..models.model import Model


class InterpretableModel(Model, ABC):
    """
    An interpretable PyTorch model.

    Subclasses declare which placeholder inputs are subject to
    interpretation and expose the per-class probabilities computed by
    their forward pass.
    """

    @property
    @abstractmethod
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        Returns:
            The ordered names of the input placeholders to be interpreted.
        """

    @abstractmethod
    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """
        Return the probabilities assigned to all classes in the model's forward.

        Args:
            *inputs: Inputs to the model
            **kwargs: Additional arguments

        Returns:
            Tensor of categorical probabilities with shape (B, C), in which
            B is the batch size and C is the number of classes.
        """


class CamInterpretableModel(InterpretableModel, ABC):
    """
    A model interpretable by Grad-CAM.
    """

    @property
    @abstractmethod
    def target_conv_layers(self) -> List[nn.Module]:
        """
        Returns:
            The convolutional layers to be interpreted. The result of the
            interpretation is the sum of the Grad-CAM maps of these layers.
        """