"""Relevance-CAM style interpreter driven by relevance propagation (``relprop``)."""
from typing import Dict, Union
import warnings

import numpy as np
import torch
import torch.nn.functional as F

from ..interpreter import Interpreter
from ..interpretable import CamInterpretableModel
from .relprop import RPProvider


class RelCamInterpreter(Interpreter):
    """Build class-activation maps from relevance captured at target conv layers.

    The relevance recorded at every layer in ``model.target_conv_layers`` is
    summed, min-max normalized per sample and bilinearly resized to the
    spatial size of the first interpreted input.
    """

    def __init__(self, model: CamInterpretableModel):
        """Register relevance propagation on every supported submodule of ``model``.

        Args:
            model: The model to interpret; its ``target_conv_layers`` are the
                layers whose relevance ``interpret`` captures.

        Modules that ``RPProvider`` cannot handle are skipped with a warning
        rather than rejected, so interpretation may still proceed.
        """
        super().__init__(model)
        # Layers whose relevance is captured during ``interpret``.
        self.__targets = model.target_conv_layers
        for name, module in model.named_modules():
            if not RPProvider.propable(module):
                # stacklevel=2 attributes the warning to the caller constructing the interpreter.
                warnings.warn(
                    f"Module {name} of type {type(module)} is not propable! Hope you know what you are doing!",
                    stacklevel=2,
                )
                continue
            # Presumably attaches relevance-propagation support to the module;
            # see relprop.RPProvider for the exact mechanism.
            RPProvider.create(module)

    @staticmethod
    def __normalize(x: torch.Tensor) -> torch.Tensor:
        """Min-max normalize each sample of a 4-D tensor to roughly ``[0, 1]``.

        The minimum and maximum are taken over all non-batch elements. The
        ``1e-6`` term prevents division by zero for constant maps.
        """
        fx = x.flatten(1)
        # Indexing with [:, None, None, None] broadcasts the per-sample stats over a 4-D x.
        minx = fx.min(dim=1).values[:, None, None, None]
        maxx = fx.max(dim=1).values[:, None, None, None]
        return (x - minx) / (1e-6 + maxx - minx)

    def interpret(self, labels: Union[int, torch.Tensor, np.ndarray], **inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute a relevance-CAM for ``labels`` on the given model inputs.

        Args:
            labels: Target class(es) used to build the one-hot output that is
                propagated back through the model.
            **inputs: Model inputs keyed by placeholder name. The first entry
                of ``ordered_placeholder_names_to_be_interpreted`` determines
                the output spatial size. This assumes a (B, C, H, W) layout;
                TODO confirm.

        Returns:
            ``{'default': relcam}``, where ``relcam`` is the normalized map
            resized to the input's trailing spatial dimensions.

        Raises:
            RuntimeError: If no relevance was captured at the target layers.
        """
        x_shape = inputs[self._model.ordered_placeholder_names_to_be_interpreted[0]].shape
        with RPProvider.capture(self.__targets) as Rs:
            z = self._model.get_categorical_probabilities(**inputs)
            one_hot_y = self._get_one_hot_output(z, labels)
            # Presumably runs relevance propagation from the one-hot output;
            # the capture context collects what reaches the target layers.
            RPProvider.get(self._model)(one_hot_y)
            # Read the captured relevance while the capture context is still active.
            result = list(Rs.values())

        if not result:
            raise RuntimeError(
                "RelCamInterpreter: no relevance was captured at the target layers; "
                "check model.target_conv_layers."
            )
        # torch.stack requires every captured map to share the same shape.
        relcam = torch.stack(result).sum(dim=0) if len(result) > 1 else result[0]
        relcam = self.__normalize(relcam)
        # Bilinear interpolation requires a 4-D (B, C, H, W) tensor.
        relcam = F.interpolate(relcam, size=x_shape[2:], mode='bilinear', align_corners=False)
        return {
            'default': relcam,
        }