123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- """
- An adapter wrapper to adapt our models to Captum interpreters
- """
- import inspect
- from typing import Iterable, Dict, Tuple
- from dataclasses import dataclass
- from itertools import chain
-
- import torch
- from torch import nn
-
- from . import InterpretableModel
-
-
@dataclass
class InterpreterArgs:
    """
    Positional tensors prepared for an interpreter call on InterpretableWrapper.

    ``InterpretableWrapper.forward`` consumes them in this order: every
    tensor of ``inputs`` followed by every tensor of ``additional_inputs``.
    """

    # Tensors for the model's to-be-interpreted placeholders, in
    # ``ordered_placeholder_names_to_be_interpreted`` order (missing names skipped).
    inputs: Tuple[torch.Tensor, ...]
    # Tensors for the remaining parameters of ``model.forward``, in signature
    # order (missing names skipped).
    additional_inputs: Tuple[torch.Tensor, ...]


class InterpretableWrapper(nn.Module):
    """
    An adapter wrapper to adapt our models to Captum interpreters.

    ``forward`` accepts plain positional tensors and binds them, by position,
    to the wrapped model's input names: first the names listed in
    ``model.ordered_placeholder_names_to_be_interpreted``, then the remaining
    ("additional") parameters of ``model.forward`` in signature order. It then
    returns ``model.get_categorical_probabilities(**bound_inputs)``.
    """

    def __init__(self, model: "InterpretableModel"):
        super().__init__()

        self._model = model

    @property
    def _additional_names(self) -> Iterable[str]:
        """
        Names of ``model.forward`` parameters passed as additional positional inputs.

        The names are returned in signature order. Excluded are:
        - the to-be-interpreted names;
        - parameters whose default is ``None`` (these are never forwarded);
        - ``*args`` / ``**kwargs`` catch-alls, which cannot be bound to a
          single tensor by name.
        """
        # Materialise once: O(1) membership tests, and safe even if the model
        # exposes the names as a one-shot iterable.
        to_be_interpreted_names = set(self._model.ordered_placeholder_names_to_be_interpreted)
        parameters = inspect.signature(self._model.forward).parameters
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        return [name for name, param in parameters.items()
                if name not in to_be_interpreted_names
                and param.default is not None
                and param.kind not in variadic_kinds]

    def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Converts an ordered *args to **kwargs.

        ``args`` are bound positionally to the to-be-interpreted names
        followed by the additional names.

        - If fewer args than names are given, the trailing names are left
          out, so the model falls back to its own defaults for them.
        - Args beyond the known names are ignored.
        """
        names = chain(self._model.ordered_placeholder_names_to_be_interpreted,
                      self._additional_names)
        # zip stops at the shorter side: no IndexError when trailing inputs
        # were omitted by convert_inputs_to_args.
        return dict(zip(names, args))

    def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs:
        """
        Converts a **kwargs to ordered *args.

        Names missing from ``kwargs`` are skipped. ``forward`` re-binds the
        tensors purely by position, so only a trailing run of names may be
        missing. A gap followed by a provided name would shift every later
        tensor onto the wrong parameter, so that case is rejected.

        Raises:
            ValueError: if a name is missing from ``kwargs`` while a later
                name (in forward's positional order) is present.
        """
        interpreted_names = tuple(self._model.ordered_placeholder_names_to_be_interpreted)
        additional_names = tuple(self._additional_names)

        first_missing = None
        for name in chain(interpreted_names, additional_names):
            if name not in kwargs:
                if first_missing is None:
                    first_missing = name
            elif first_missing is not None:
                raise ValueError(
                    f"Input '{name}' was given but earlier input '{first_missing}' was not; "
                    "inputs are passed to forward() positionally, so only trailing inputs "
                    "may be omitted")

        return InterpreterArgs(
            tuple(kwargs[name] for name in interpreted_names if name in kwargs),
            tuple(kwargs[name] for name in additional_names if name in kwargs),
        )

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        """
        Forwards the model.

        Binds ``args`` to the model's input names via
        ``convert_inputs_to_kwargs`` and returns
        ``model.get_categorical_probabilities(**bound_inputs)``.
        """
        inputs = self.convert_inputs_to_kwargs(*args)
        return self._model.get_categorical_probabilities(**inputs)
|