from contextlib import ExitStack, contextmanager
from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar
from uuid import UUID, uuid4

import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from . import modules as M

TModule = TypeVar('TModule', bound=nn.Module)
TType = TypeVar('TType', bound=type)


def safe_divide(a, b):
    """Elementwise a / b; zero entries of b are stabilised with 1e-9 and masked out of the result."""
    return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())


class classdict(Dict[type, TType], Generic[TType]):
    """Dict keyed by types that resolves lookups through the key's MRO,
    so a rule registered for a base class also covers its subclasses."""

    def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]:
        return next((r for r in __k.mro() if dict.__contains__(self, r)), None)

    def __contains__(self, __k: Type[TModule]) -> bool:
        return self.__nearest_available_resolution(__k) is not None

    def __getitem__(self, __k: Type[TModule]) -> TType:
        r = self.__nearest_available_resolution(__k)
        if r is None:
            raise KeyError(f"{__k} not found!")
        return super().__getitem__(r)


class RPProvider:
    """Registry mapping module types to relevance-propagation rules and holding
    the per-module rule instances created for a concrete model."""

    __props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict()
    __instances: Dict[TModule, 'RelProp[TModule]'] = {}
    __target_results: Dict[nn.Module, torch.Tensor] = {}

    @classmethod
    def register(cls, *module_types: Type[TModule]):
        """Class decorator that registers a RelProp subclass as the rule for the given module types."""
        def decorator(prop_cls):
            cls.__props.update({
                module_type: prop_cls
                for module_type in module_types
            })
            return prop_cls
        return decorator

    @classmethod
    def create(cls, module: TModule):
        """Instantiate (and remember) the registered rule for a concrete module."""
        cls.__instances[module] = prop = cls.__props[type(module)](module)
        return prop

    @classmethod
    def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None:
        # Collapse the layer's relevance into a CAM-like (N, 1, H, W) map.
        r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
        r_cam = prop.X * r_weight
        r_cam = torch.sum(r_cam, dim=1, keepdim=True)
        cls.__target_results[prop.module] = r_cam

    @classmethod
    @contextmanager
    def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]:
        """Arm the forward hooks of every created rule and collect relevance maps
        for the requested target layers while relevance is propagated."""
        with ExitStack() as stack:
            cls.__target_results.clear()
            for prop in cls.__instances.values():
                stack.enter_context(prop.hook_module())
            for target in target_layers:
                stack.enter_context(cls.get(target).register_hook(cls.__hook))
            yield cls.__target_results

    @classmethod
    def propable(cls, module: TModule) -> bool:
        return type(module) in cls.__props

    @classmethod
    def get(cls, module: TModule) -> 'RelProp[TModule]':
        return cls.__instances[module]


@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid)
class RelProp(Generic[TModule]):
    """Base rule: records the module's input during the forward pass and, by default,
    passes relevance through unchanged (identity rule)."""

    Hook = Callable[['RelProp', torch.Tensor], None]

    def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
        # Keep a detached, grad-enabled copy of the input so rules can re-run the
        # forward computation and differentiate it during relevance propagation.
        if type(input[0]) in (list, tuple):
            self.X = []
            for i in input[0]:
                x = i.detach()
                x.requires_grad = True
                self.X.append(x)
        else:
            self.X = input[0].detach()
            self.X.requires_grad = True
        self.Y = output

    def __init__(self, module: TModule) -> None:
        self.module = module
        self.__hooks: Dict[UUID, RelProp.Hook] = {}

    @contextmanager
    def hook_module(self):
        handle = self.module.register_forward_hook(self.__forward_hook)
        try:
            yield
        finally:
            handle.remove()

    @contextmanager
    def register_hook(self, hook: 'RelProp.Hook'):
        uuid = uuid4()
        self.__hooks[uuid] = hook
        try:
            yield
        finally:
            self.__hooks.pop(uuid)

    def grad(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = self.rel(R, alpha=alpha)
        for hook in self.__hooks.values():
            hook(self, R)
        return R

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        return R
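
# --- Illustrative sketch (not part of the original module) --------------------------
# How a project-specific layer would get its own rule. `MyGate` and `MyGateRelProp`
# are hypothetical names used only for this example; thanks to `classdict`, subclasses
# of already-registered types (e.g. a custom nn.ReLU subclass) need no extra
# registration.
#
#   class MyGate(nn.Module):
#       def forward(self, x):
#           return x * torch.sigmoid(x)
#
#   @RPProvider.register(MyGate)
#   class MyGateRelProp(RelProp[MyGate]):
#       def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
#           Z = self.module.forward(self.X)            # re-run the forward pass on the stored input
#           S = safe_divide(R, Z)
#           return self.X * self.grad(Z, self.X, S)[0]
# -------------------------------------------------------------------------------------
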

@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize)
class RelPropSimple(RelProp[TModule]):
    """Generic z-rule: redistribute relevance to the inputs in proportion to their
    contribution to the module's output."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)

        if not torch.is_tensor(self.X):
            # Multi-input modules (e.g. M.Add) receive a list of relevances back.
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * C[0]
        return outputs


@RPProvider.register(nn.Flatten)
class RelPropFlatten(RelProp[TModule]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Undo the flattening so relevance regains the input's spatial shape.
        return R.reshape_as(self.X)


@RPProvider.register(nn.AdaptiveAvgPool2d)
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        px = torch.clamp(self.X, min=0)

        def f(x1):
            Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size)
            S1 = safe_divide(R, Z1)
            C1 = x1 * self.grad(Z1, x1, S1)[0]
            return C1

        activator_relevances = f(px)
        out = activator_relevances
        return out


@RPProvider.register(nn.ZeroPad2d)
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)
        outputs = self.X * C[0]
        return outputs


@RPProvider.register(M.Multiply)
class MultiplyRelProp(RelProp[M.Multiply]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Only the positive parts of both factors take part in the redistribution.
        x0 = torch.clamp(self.X[0], min=0)
        x1 = torch.clamp(self.X[1], min=0)
        x = [x0, x1]
        Z = self.module.forward(x)
        S = safe_divide(R, Z)
        C = self.grad(Z, x, S)

        outputs = []
        outputs.append(x[0] * C[0])
        outputs.append(x[1] * C[1])
        return outputs


@RPProvider.register(M.Cat)
class CatRelProp(RelProp[M.Cat]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X, self.module.dim)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)

        outputs = []
        for x, c in zip(self.X, C):
            outputs.append(x * c)
        return outputs


@RPProvider.register(nn.Sequential)
class SequentialRelProp(RelProp[nn.Sequential]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Propagate relevance back through the children in reverse order.
        for m in reversed(self.module):
            R = RPProvider.get(m)(R, alpha=alpha)
        return R


@RPProvider.register(nn.BatchNorm2d)
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        X = self.X
        beta = 1 - alpha
        weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) +
             self.module.eps).pow(0.5))
        Z = X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        R = self.X * (Ca)
        return R


@RPProvider.register(nn.Linear)
class LinearRelProp(RelProp[nn.Linear]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # LRP-alpha/beta rule: positive and negative contributions are propagated
        # separately and recombined with weights alpha and beta = alpha - 1.
        beta = alpha - 1
        pw = torch.clamp(self.module.weight, min=0)
        nw = torch.clamp(self.module.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            Z = Z1 + Z2
            S = safe_divide(R, Z)
            C1 = x1 * self.grad(Z1, x1, S)[0]
            C2 = x2 * self.grad(Z2, x2, S)[0]
            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)

        out = alpha * activator_relevances - beta * inhibitor_relevances
        return out
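
# Illustrative sanity check (not part of the original module): with alpha=1 the
# alpha/beta rule above is (approximately) conservative, i.e. the total relevance
# entering a bias-free linear layer matches the total relevance assigned to its
# input. The helper name and shapes below are assumptions for this sketch only.
def _check_linear_relevance_conservation() -> None:
    torch.manual_seed(0)
    layer = nn.Linear(8, 4, bias=False)
    prop = RPProvider.create(layer)
    with prop.hook_module():
        y = layer(torch.randn(2, 8))   # forward hook stores a grad-enabled copy of the input
    R_out = torch.rand_like(y)
    R_in = prop(R_out, alpha=1)        # propagate relevance back to the layer input
    # The two sums agree up to the 1e-9 stabiliser used in safe_divide.
    print(R_out.sum().item(), R_in.sum().item())
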

@RPProvider.register(nn.Conv2d)
class Conv2dRelProp(RelProp[nn.Conv2d]):
    def gradprop2(self, DY, weight):
        Z = self.module.forward(self.X)

        output_padding = self.X.size()[2] - (
            (Z.size()[2] - 1) * self.module.stride[0] -
            2 * self.module.padding[0] + self.module.kernel_size[0])

        return F.conv_transpose2d(DY, weight, stride=self.module.stride,
                                  padding=self.module.padding, output_padding=output_padding)

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        if self.X.shape[1] == 3:
            # First (pixel) layer: z^B rule using the per-image lower/upper bounds L and H.
            pw = torch.clamp(self.module.weight, min=0)
            nw = torch.clamp(self.module.weight, max=0)
            X = self.X
            L = self.X * 0 + \
                torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0],
                                    dim=2, keepdim=True)[0], dim=3, keepdim=True)[0]
            H = self.X * 0 + \
                torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0],
                                    dim=2, keepdim=True)[0], dim=3, keepdim=True)[0]
            Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride,
                              padding=self.module.padding) - \
                torch.conv2d(L, pw, bias=None, stride=self.module.stride,
                             padding=self.module.padding) - \
                torch.conv2d(H, nw, bias=None, stride=self.module.stride,
                             padding=self.module.padding) + 1e-9

            S = R / Za
            C = X * self.gradprop2(S, self.module.weight) - L * \
                self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
            R = C
        else:
            # Hidden layers: LRP-alpha/beta rule, as in LinearRelProp.
            beta = alpha - 1
            pw = torch.clamp(self.module.weight, min=0)
            nw = torch.clamp(self.module.weight, max=0)
            px = torch.clamp(self.X, min=0)
            nx = torch.clamp(self.X, max=0)

            def f(w1, w2, x1, x2):
                Z1 = F.conv2d(x1, w1, bias=self.module.bias, stride=self.module.stride,
                              padding=self.module.padding, groups=self.module.groups)
                Z2 = F.conv2d(x2, w2, bias=self.module.bias, stride=self.module.stride,
                              padding=self.module.padding, groups=self.module.groups)
                Z = Z1 + Z2
                S = safe_divide(R, Z)
                C1 = x1 * self.grad(Z1, x1, S)[0]
                C2 = x2 * self.grad(Z2, x2, S)[0]
                return C1 + C2

            activator_relevances = f(pw, nw, px, nx)
            inhibitor_relevances = f(nw, pw, px, nx)

            R = alpha * activator_relevances - beta * inhibitor_relevances
        return R


@RPProvider.register(nn.ConvTranspose2d)
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # z+ rule: only positive weights and positive inputs contribute.
        pw = torch.clamp(self.module.weight, min=0)
        px = torch.clamp(self.X, min=0)

        def f(w1, x1):
            Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride,
                                    padding=self.module.padding,
                                    output_padding=self.module.output_padding)
            S1 = safe_divide(R, Z1)
            C1 = x1 * self.grad(Z1, x1, S1)[0]
            return C1

        activator_relevances = f(pw, px)
        R = activator_relevances
        return R
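
# --- Illustrative end-to-end usage (not part of the original module) ----------------
# A minimal sketch, assuming a plain nn.Sequential classifier, of how the registry is
# typically driven: create a rule per supported module, run a forward pass inside
# `capture`, seed the relevance with a one-hot of the predicted class, and read the
# CAM-like map collected for the target layer. All names below are example-only.
def _example_relevance_map() -> torch.Tensor:
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    ).eval()
    target = model[0]

    # One rule instance per supported module; nn.Sequential itself is registered above.
    for m in model.modules():
        if RPProvider.propable(m):
            RPProvider.create(m)

    x = torch.randn(1, 3, 32, 32)
    with RPProvider.capture([target]) as results:
        y = model(x)                                   # forward hooks record each module's input
        one_hot = torch.zeros_like(y).scatter_(1, y.argmax(dim=1, keepdim=True), 1.0)
        RPProvider.get(model)(one_hot)                 # back-propagate relevance through the model
    return results[target]                            # (1, 1, 32, 32) map produced by RPProvider's hook
# -------------------------------------------------------------------------------------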