You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

interpreter.py 2.1KB

1 year ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """
  2. An abstract class to apply different interpretation algorithms on torch models.
  3. """
  4. from abc import abstractmethod, ABC
  5. from typing import Dict, List, Union
  6. import torch
  7. from torch import nn
  8. import numpy as np
  9. from . import InterpretableModel
  10. class Interpreter(ABC): # pylint: disable=too-few-public-methods
  11. """
  12. An abstract class to apply different interpretation algorithms on torch models.
  13. """
  14. def __init__(self, model: InterpretableModel):
  15. self._model = model
  16. def _get_all_leaf_children(self, module: nn.Module = None) -> List[nn.Module]:
  17. """
  18. Gets all leaf modules recursively
  19. """
  20. if module is None:
  21. module = self._model
  22. children: List[nn.Module] = []
  23. for child in module.children():
  24. if len(list(child.children())) == 0:
  25. children.append(child)
  26. else:
  27. children += self._get_all_leaf_children(child)
  28. return children
  29. @staticmethod
  30. def _get_one_hot_output(output: torch.Tensor,
  31. target: Union[int, torch.Tensor, np.ndarray] = None) -> torch.Tensor:
  32. batch_size = output.shape[0]
  33. if target is None:
  34. target = output.argmax(dim=1)
  35. one_hot_output = torch.zeros_like(output)
  36. one_hot_output[torch.arange(batch_size), target] = 1
  37. return one_hot_output
  38. def _get_target(self, target: Union[int, torch.Tensor], **inputs: torch.Tensor) -> Union[int, torch.Tensor]:
  39. if target is not None:
  40. return target
  41. return self._model.get_categorical_probabilities(**inputs).argmax(dim=1)
  42. @abstractmethod
  43. def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
  44. '''
  45. An abstract method to interpret given input and target class
  46. Returns a dictionary of interpretation results.
  47. '''
  48. def dynamic_threshold(self, x: np.ndarray) -> np.ndarray:
  49. """
  50. A function to dynamically threshold one sample.
  51. """
  52. return (x - (x.mean() + x.std())).clip(min=0)