"""
An abstract class to apply different interpretation algorithms on torch models.
"""

from abc import abstractmethod, ABC
from typing import Dict, List, Optional, Union

import torch
from torch import nn
import numpy as np

from . import InterpretableModel


class Interpreter(ABC):  # pylint: disable=too-few-public-methods
    """
    An abstract class to apply different interpretation algorithms on torch models.

    Subclasses implement :meth:`interpret`. This base class provides shared
    helpers for walking the wrapped model's sub-modules, choosing a target
    class, and building one-hot target tensors.
    """

    def __init__(self, model: InterpretableModel):
        # The wrapped model. The helpers below call ``children()`` and
        # ``get_categorical_probabilities`` on it.
        self._model = model

    def _get_all_leaf_children(self, module: Optional[nn.Module] = None) -> List[nn.Module]:
        """
        Gets all leaf modules (modules without children) recursively.

        Args:
            module: Root of the traversal. Defaults to the wrapped model.

        Returns:
            The leaf descendants of ``module``, in ``children()`` order.
            ``module`` itself is never included, so a root that is already a
            leaf yields an empty list.
        """
        if module is None:
            module = self._model

        children: List[nn.Module] = []
        for child in module.children():
            # Peek for a single grandchild rather than materializing the list.
            if next(iter(child.children()), None) is None:
                children.append(child)
            else:
                children += self._get_all_leaf_children(child)
        return children

    @staticmethod
    def _get_one_hot_output(output: torch.Tensor,
                            target: Optional[Union[int, torch.Tensor, np.ndarray]] = None
                            ) -> torch.Tensor:
        """
        Builds a tensor with a 1 at each sample's target class and 0 elsewhere.

        Args:
            output: Model output. This is assumed to be (batch, num_classes),
                with the class axis at dim 1 (TODO confirm for all callers).
            target: Class index per sample (a tensor or ndarray of length
                batch), or a single int applied to every sample. Defaults to
                the argmax of ``output`` along dim 1.

        Returns:
            A zero tensor with the same shape, dtype and device as ``output``,
            with ones at ``[i, target[i]]``.
        """
        batch_size = output.shape[0]
        if target is None:
            target = output.argmax(dim=1)
        # Build both index tensors on output's device so the row and column
        # indices never straddle devices. as_tensor also normalizes ints and
        # ndarrays, and the int64 cast accepts e.g. int32 numpy index arrays.
        target = torch.as_tensor(target, dtype=torch.long, device=output.device)
        rows = torch.arange(batch_size, device=output.device)

        one_hot_output = torch.zeros_like(output)
        one_hot_output[rows, target] = 1
        return one_hot_output

    def _get_target(self, target: Optional[Union[int, torch.Tensor]],
                    **inputs: torch.Tensor) -> Union[int, torch.Tensor]:
        """
        Returns ``target`` unchanged if it is given, otherwise the predicted class.

        When ``target`` is None, the wrapped model's
        ``get_categorical_probabilities`` is called with ``inputs``, and the
        argmax along dim 1 of the result is returned. This is presumably one
        class index per sample; confirm against InterpretableModel.
        """
        if target is not None:
            return target
        return self._model.get_categorical_probabilities(**inputs).argmax(dim=1)

    @abstractmethod
    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        An abstract method to interpret the given input for a target class.

        Args:
            labels: Target class(es) to explain.
            **inputs: Model inputs, passed by keyword.

        Returns:
            A dictionary of interpretation results.
        """

    def dynamic_threshold(self, x: np.ndarray) -> np.ndarray:
        """
        A function to dynamically threshold one sample.

        The threshold is ``x.mean() + x.std()``, computed over all of ``x``.
        Values at or below it become 0. Values above it are shifted down by
        the threshold.
        """
        return (x - (x.mean() + x.std())).clip(min=0)