from typing import Dict, Union
import warnings

import numpy as np
import torch
import torch.nn.functional as F

from ..interpreter import Interpreter
from ..interpretable import CamInterpretableModel
from .relprop import RPProvider


class RelCamInterpreter(Interpreter):
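    """Interpreter that builds relevance-based class activation maps (Rel-CAM).

    Relevance is captured at the model's target convolutional layers via
    RPProvider, aggregated across those layers, min-max normalized per sample,
    and upsampled to the spatial size of the interpreted input.
    """
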
    def __init__(self, model: CamInterpretableModel):
        super().__init__(model)
        self.__targets = model.target_conv_layers

        # Attach a relevance-propagation rule to every supported submodule.
        for name, module in model.named_modules():
            if not RPProvider.propable(module):
                warnings.warn(
                    f"Module {name} of type {type(module)} is not propable! "
                    "Hope you know what you are doing!")
                continue
            RPProvider.create(module)

    @staticmethod
    def __normalize(x: torch.Tensor) -> torch.Tensor:
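        """Min-max normalizes each sample of a 4D tensor to the [0, 1] range,
        with a small epsilon in the denominator for numerical stability."""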
        fx = x.flatten(1)
        minx = fx.min(dim=1).values[:, None, None, None]
        maxx = fx.max(dim=1).values[:, None, None, None]
        return (x - minx) / (1e-6 + maxx - minx)

    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray],
                  **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
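        """Computes Rel-CAM heatmaps of the given labels for the given inputs.

        Returns a single-entry dictionary mapping 'default' to the heatmap,
        resized to the spatial dimensions of the interpreted input.
        """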
        x_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape

        # Forward pass, then propagate relevance for the requested labels,
        # capturing the relevance that reaches each target conv layer.
        with RPProvider.capture(self.__targets) as Rs:
            z = self._model.get_categorical_probabilities(**inputs)
            one_hot_y = self._get_one_hot_output(z, labels)
            RPProvider.get(self._model)(one_hot_y)
            result = list(Rs.values())

        # Aggregate relevance across target layers, normalize per sample, and
        # upsample to the input resolution.
        relcam = torch.stack(result).sum(dim=0) if len(result) > 1 else result[0]
        relcam = self.__normalize(relcam)
        relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear',
                               align_corners=False)
        return {
            'default': relcam,
        }
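
# A minimal usage sketch (the model class and placeholder name below are
# hypothetical; any concrete CamInterpretableModel subclass from this project
# would work the same way):
#
#     model = MyCamInterpretableModel()
#     interpreter = RelCamInterpreter(model)
#     heatmaps = interpreter.interpret(labels, x=batch_of_images)
#     relcam = heatmaps['default']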