You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpretable_wrapper.py 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. """
  2. An adapter wrapper to adapt our models to Captum interpreters
  3. """
  4. import inspect
  5. from typing import Iterable, Dict, Tuple
  6. from dataclasses import dataclass
  7. from itertools import chain
  8. import torch
  9. from torch import nn
  10. from . import InterpretableModel
  11. @dataclass
  12. class InterpreterArgs:
  13. inputs: Tuple[torch.Tensor, ...]
  14. additional_inputs: Tuple[torch.Tensor, ...]
  15. class InterpretableWrapper(nn.Module):
  16. """
  17. An adapter wrapper to adapt our models to Captum interpreters
  18. """
  19. def __init__(self, model: InterpretableModel):
  20. super().__init__()
  21. self._model = model
  22. @property
  23. def _additional_names(self) -> Iterable[str]:
  24. to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
  25. signature = inspect.signature(self._model.forward)
  26. return [name for name in signature.parameters
  27. if name not in to_be_interpreted_names
  28. and signature.parameters[name].default is not None]
  29. def convert_inputs_to_kwargs(self, *args: torch.Tensor) -> Dict[str, torch.Tensor]:
  30. """
  31. Converts an ordered *args to **kwargs
  32. """
  33. to_be_interpreted_names = self._model.ordered_placeholder_names_to_be_interpreted
  34. additional_names = self._additional_names
  35. inputs = {}
  36. for i, name in enumerate(chain(to_be_interpreted_names, additional_names)):
  37. inputs[name] = args[i]
  38. return inputs
  39. def convert_inputs_to_args(self, **kwargs: torch.Tensor) -> InterpreterArgs:
  40. """
  41. Converts a **kwargs to ordered *args
  42. """
  43. return InterpreterArgs(
  44. tuple(kwargs[name] for name in self._model.ordered_placeholder_names_to_be_interpreted
  45. if name in kwargs),
  46. tuple(kwargs[name] for name in self._additional_names
  47. if name in kwargs)
  48. )
  49. def forward(self, *args: torch.Tensor) -> torch.Tensor:
  50. """
  51. Forwards the model
  52. """
  53. inputs = self.convert_inputs_to_kwargs(*args)
  54. return self._model.get_categorical_probabilities(**inputs)