"""Adapter that lets our models be used with Captum interpreters."""
import inspect
from dataclasses import dataclass
from itertools import chain
from typing import Dict, List, Tuple

import torch
from torch import nn

from . import InterpretableModel


@dataclass
class InterpreterArgs:
    """Positional arguments split the way Captum consumes them."""

    # Tensors to attribute (passed to Captum as ``inputs``).
    inputs: Tuple[torch.Tensor, ...]
    # Tensors passed through unchanged (passed as ``additional_forward_args``).
    additional_inputs: Tuple[torch.Tensor, ...]


class InterpretableWrapper(nn.Module):
    """Wraps an ``InterpretableModel`` so Captum can call it with positional tensors."""

    def __init__(self, model: InterpretableModel):
        super().__init__()
        self._model = model

    @property
    def _additional_names(self) -> List[str]:
        """Names of ``forward`` parameters that are not interpreted.

        Parameters whose default is ``None`` are treated as optional and skipped.
        """
        to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
        signature = inspect.signature(self._model.forward)
        return [
            name
            for name, parameter in signature.parameters.items()
            if name not in to_be_interpreted_names and parameter.default is not None
        ]

    def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Converts ordered positional tensors to keyword arguments.

        Positions follow the interpreted names first, then the additional names.
        """
        to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
        additional_names = self._additional_names
        inputs = {}
        for i, name in enumerate(chain(to_be_interpreted_names, additional_names)):
            inputs[name] = args[i]
        return inputs

    def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs:
        """Converts keyword arguments to ordered positional tuples.

        Names missing from ``kwargs`` are skipped.
        """
        return InterpreterArgs(
            tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted
                  if name in kwargs),
            tuple(kwargs[name] for name in self._additional_names if name in kwargs),
        )

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        """Runs the wrapped model and returns its categorical probabilities."""
        inputs = self.convert_inputs_to_kwargs(*args)
        return self._model.get_categorical_probabilities(**inputs)