12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- """
- An abstract class to apply different interpretation algorithms on torch models.
- """
-
- from abc import abstractmethod, ABC
- from typing import Dict, List, Union
- import torch
- from torch import nn
- import numpy as np
- from . import InterpretableModel
-
-
class Interpreter(ABC):  # pylint: disable=too-few-public-methods
    """
    An abstract class to apply different interpretation algorithms on torch models.

    Subclasses implement :meth:`interpret`; the remaining methods are shared
    helpers (leaf-module discovery, one-hot target construction, target
    resolution and thresholding).
    """

    def __init__(self, model: "InterpretableModel"):
        """Store the model that the interpretation algorithm operates on."""
        self._model = model

    def _get_all_leaf_children(self, module: Union[nn.Module, None] = None) -> List[nn.Module]:
        """
        Gets all leaf modules recursively.

        Args:
            module: Module whose descendants are searched. When ``None`` the
                wrapped model (``self._model``) is used.

        Returns:
            The descendants of ``module`` that have no children themselves, in
            the order ``children()`` yields them (depth first). ``module``
            itself is never part of the result, so a module without children
            yields an empty list.
        """
        root = self._model if module is None else module

        leaves: List[nn.Module] = []
        for child in root.children():
            # Probe lazily for at least one grandchild instead of
            # materialising the whole children list just to test emptiness.
            if any(True for _ in child.children()):
                leaves.extend(self._get_all_leaf_children(child))
            else:
                leaves.append(child)

        return leaves

    @staticmethod
    def _get_one_hot_output(output: torch.Tensor,
                            target: Union[int, torch.Tensor, np.ndarray, None] = None) -> torch.Tensor:
        """
        Build a one-hot tensor shaped like ``output`` that marks the target column of each row.

        Args:
            output: Model output; the first dimension is used as the batch
                size and the second dimension is indexed by ``target``.
            target: Column index to mark. An ``int`` marks the same column in
                every row; a tensor or array supplies one index per row. When
                ``None`` the arg-max of ``output`` along dimension 1 is used.

        Returns:
            A tensor created with ``torch.zeros_like(output)`` in which the
            selected entry of every row is set to 1.
        """
        batch_size = output.shape[0]

        if target is None:
            target = output.argmax(dim=1)
        elif isinstance(target, np.ndarray):
            # Normalise numpy targets to a long index tensor living next to
            # ``output`` so the indexing below does not depend on how torch
            # treats foreign array types or non-int64 index dtypes.
            target = torch.as_tensor(target, dtype=torch.long, device=output.device)

        row_index = torch.arange(batch_size, device=output.device)
        one_hot_output = torch.zeros_like(output)
        one_hot_output[row_index, target] = 1
        return one_hot_output

    def _get_target(self, target: Union[int, torch.Tensor, None],
                    **inputs: torch.Tensor) -> Union[int, torch.Tensor]:
        """
        Resolve the target to interpret.

        Args:
            target: Explicit target; returned unchanged when it is not ``None``.
            **inputs: Keyword inputs forwarded to the model's
                ``get_categorical_probabilities`` when no target is given.

        Returns:
            ``target`` itself, or the arg-max along dimension 1 of the model's
            categorical probabilities for ``inputs``.
        """
        if target is not None:
            return target

        probabilities = self._model.get_categorical_probabilities(**inputs)
        return probabilities.argmax(dim=1)

    @abstractmethod
    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        An abstract method to interpret given input and target class.

        Returns a dictionary of interpretation results.
        """

    def dynamic_threshold(self, x: np.ndarray) -> np.ndarray:
        """
        A function to dynamically threshold one sample.

        The threshold is ``x.mean() + x.std()``; it is subtracted from every
        element and negative results are clipped to 0.
        """
        threshold = x.mean() + x.std()
        return (x - threshold).clip(min=0)
|