- from contextlib import ExitStack, contextmanager
- from typing import Callable, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar
- from uuid import UUID, uuid4
-
- import torch
- import torch.nn.functional as F
- from torch import nn
- import torchvision
-
- from . import modules as M
-
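- # Layer-wise relevance propagation (LRP-style) rules for common torch.nn layers,
- # together with a small registry (RPProvider) that attaches each rule to concrete
- # module instances via forward hooks and collects per-layer relevance maps.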
-
- TModule = TypeVar('TModule', bound=nn.Module)
- TType = TypeVar('TType', bound=type)
-
-
- def safe_divide(a, b):
- # Element-wise a / b that returns exactly 0 wherever b == 0: a tiny epsilon keeps
- # the intermediate division finite, and the final mask zeroes those entries out.
- return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())
-
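- # A hedged sanity check of safe_divide (values worked out by hand, not from a test):
- #
- #   safe_divide(torch.tensor([1., 2.]), torch.tensor([0., 4.]))  # -> tensor([0.0000, 0.5000])
- 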
-
- class classdict(Dict[type, TType], Generic[TType]):
- """Dict keyed by classes: looking up a class falls back to the nearest entry in its MRO."""
- 
- def __nearest_available_resolution(self, __k: Type[TModule]) -> Optional[Type[TModule]]:
- # Return the first class in the key's MRO that has an explicit entry, or None.
- return next((r for r in __k.mro() if dict.__contains__(self, r)), None)
-
- def __contains__(self, __k: Type[TModule]) -> bool:
- return self.__nearest_available_resolution(__k) is not None
-
- def __getitem__(self, __k: Type[TModule]) -> TType:
- r = self.__nearest_available_resolution(__k)
- if r is None:
- raise KeyError(f"{__k} not found!")
- return super().__getitem__(r)
-
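- # A hedged sketch of the lookup behaviour (CustomReLU is a hypothetical subclass):
- #
- #   class CustomReLU(nn.ReLU): ...
- #   d = classdict({nn.ReLU: 'relu-rule'})
- #   d[CustomReLU]    # -> 'relu-rule', resolved through CustomReLU.mro()
- #   nn.Tanh in d     # -> False: no entry anywhere in nn.Tanh's MRO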
-
-
- class RPProvider:
- """Registry mapping module types to RelProp rules; also owns the per-module RelProp instances and the captured relevance maps of target layers."""
- 
- __props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict()
- __instances: Dict[TModule, 'RelProp[TModule]'] = {}
- __target_results: Dict[nn.Module, torch.Tensor] = {}
-
- @classmethod
- def register(cls, *module_types: Type[TModule]):
- def decorator(prop_cls):
- cls.__props.update({
- module_type: prop_cls
- for module_type in module_types
- })
- return prop_cls
- return decorator
-
- @classmethod
- def create(cls, module: TModule) -> 'RelProp[TModule]':
- # Instantiate and cache the registered rule for this module's type (MRO-aware lookup).
- cls.__instances[module] = prop = cls.__props[type(module)](module)
- return prop
-
- @classmethod
- def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None:
- # Collapse the propagated relevance into a CAM: average R over the spatial dims,
- # weight the cached input activations, then sum over channels -> (N, 1, H, W).
- r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
- r_cam = prop.X * r_weight
- r_cam = torch.sum(r_cam, dim=1, keepdim=True)
- cls.__target_results[prop.module] = r_cam
-
- @classmethod
- @contextmanager
- def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]:
- # Attach forward hooks to every known module and CAM hooks to the target layers,
- # then yield the dict that those hooks fill in during the relevance pass.
- with ExitStack() as stack:
- cls.__target_results.clear()
- for prop in cls.__instances.values():
- stack.enter_context(prop.hook_module())
- for target in target_layers:
- stack.enter_context(cls.get(target).register_hook(cls.__hook))
- yield cls.__target_results
-
- @classmethod
- def propable(cls, module: TModule) -> bool:
- return type(module) in cls.__props
-
- @classmethod
- def get(cls, module: TModule) -> 'RelProp[TModule]':
- return cls.__instances[module]
-
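- # A minimal end-to-end sketch of the intended wiring (hedged; `model`, `x`, `target_layer`
- # and `initial_R` are hypothetical, and the top-level model is assumed to be a registered
- # type such as nn.Sequential):
- #
- #   for m in model.modules():
- #       if RPProvider.propable(m):
- #           RPProvider.create(m)                       # one RelProp per supported module
- #   with RPProvider.capture([target_layer]) as cams:   # forward hooks record X / Y
- #       y = model(x)
- #       RPProvider.get(model)(initial_R)               # backward relevance pass
- #       cam = cams[target_layer]                       # (N, 1, H, W) relevance map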
-
- @RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid)
- class RelProp(Generic[TModule]):
- """Base rule: caches the module's inputs/outputs at forward time and, by default, passes relevance through unchanged."""
- 
- Hook = Callable[['RelProp', torch.Tensor], None]
-
- def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
- # Cache detached, grad-enabled copies of the inputs plus the raw output, so that
- # rel() can re-run the layer and differentiate through it later.
- if isinstance(input[0], (list, tuple)):
- self.X = []
- for i in input[0]:
- x = i.detach()
- x.requires_grad = True
- self.X.append(x)
- else:
- self.X = input[0].detach()
- self.X.requires_grad = True
- self.Y = output
-
- def __init__(self, module: TModule) -> None:
- self.module = module
- self.__hooks: Dict[UUID, RelProp.Hook] = {}
-
- @contextmanager
- def hook_module(self):
- handle = self.module.register_forward_hook(self.__forward_hook)
- try:
- yield
- finally:
- handle.remove()
-
- @contextmanager
- def register_hook(self, hook: 'RelProp.Hook'):
- uuid = uuid4()
- self.__hooks[uuid] = hook
- try:
- yield
- finally:
- self.__hooks.pop(uuid)
-
- def grad(self, Z, X, S):
- # Vector-Jacobian product of Z w.r.t. X, weighted by S; the graph is kept alive so
- # the same forward can back several relevance terms.
- C = torch.autograd.grad(Z, X, S, retain_graph=True)
- return C
-
- def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- R = self.rel(R, alpha=alpha)
- for hook in self.__hooks.values():
- hook(self, R)
- return R
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- return R
-
-
- @RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize)
- class RelPropSimple(RelProp[TModule]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # Redistribute R in proportion to each input's contribution, via a gradient pass
- # through the re-executed forward.
- Z = self.module.forward(self.X)
- S = safe_divide(R, Z)
- C = self.grad(Z, self.X, S)
- 
- if not torch.is_tensor(self.X):
- # Multi-input modules (e.g. Add) get one relevance map per input.
- outputs = [x * c for x, c in zip(self.X, C)]
- else:
- outputs = self.X * C[0]
- return outputs
-
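- # In equation form (a hedged reading of the rule above): with Z = f(X) and S_j = R_j / Z_j,
- # each input receives R_i = X_i * sum_j S_j * dZ_j/dX_i, i.e. relevance is split in
- # proportion to each input's contribution to the outputs it feeds.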
-
- @RPProvider.register(nn.Flatten)
- class RelPropFlatten(RelProp[nn.Flatten]):
- 
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # Flatten is shape-only, so fold the relevance back into the cached input's shape.
- return R.reshape_as(self.X)
-
-
- @RPProvider.register(nn.AdaptiveAvgPool2d)
- class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # z+ style rule: only the positive part of the input is used to redistribute R.
- px = torch.clamp(self.X, min=0)
- 
- def f(x1):
- Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size)
- S1 = safe_divide(R, Z1)
- C1 = x1 * self.grad(Z1, x1, S1)[0]
- return C1
- 
- return f(px)
-
-
- @RPProvider.register(nn.ZeroPad2d)
- class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- Z = self.module.forward(self.X)
- S = safe_divide(R, Z)
- C = self.grad(Z, self.X, S)
- outputs = self.X * C[0]
- return outputs
-
-
- @RPProvider.register(M.Multiply)
- class MultiplyRelProp(RelProp[M.Multiply]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- x0 = torch.clamp(self.X[0], min=0)
- x1 = torch.clamp(self.X[1], min=0)
- x = [x0, x1]
- Z = self.module.forward(x)
- S = safe_divide(R, Z)
- C = self.grad(Z, x, S)
-
- outputs = []
- outputs.append(x[0] * C[0])
- outputs.append(x[1] * C[1])
-
- return outputs
-
-
- @RPProvider.register(M.Cat)
- class CatRelProp(RelProp[M.Cat]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- Z = self.module.forward(self.X, self.module.dim)
- S = safe_divide(R, Z)
- C = self.grad(Z, self.X, S)
-
- outputs = []
- for x, c in zip(self.X, C):
- outputs.append(x * c)
-
- return outputs
-
-
- @RPProvider.register(nn.Sequential)
- class SequentialRelProp(RelProp[nn.Sequential]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # Walk the children in reverse order and delegate to each child's registered rule.
- for m in reversed(self.module):
- R = RPProvider.get(m)(R, alpha=alpha)
- return R
-
-
- @RPProvider.register(nn.BatchNorm2d)
- class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]):
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # Fold the learned scale and the running-statistics normalisation into a single
- # multiplicative factor and redistribute R through it (alpha is unused here).
- X = self.X
- weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
- (self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.module.eps).pow(0.5))
- Z = X * weight + 1e-9
- S = R / Z
- Ca = S * weight
- return X * Ca
-
-
- @RPProvider.register(nn.Linear)
- class LinearRelProp(RelProp[nn.Linear]):
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # Alpha-beta rule: split weights and activations into positive / negative parts and
- # redistribute R through the excitatory (alpha) and inhibitory (beta) paths.
- beta = alpha - 1
- pw = torch.clamp(self.module.weight, min=0)
- nw = torch.clamp(self.module.weight, max=0)
- px = torch.clamp(self.X, min=0)
- nx = torch.clamp(self.X, max=0)
- 
- def f(w1, w2, x1, x2):
- Z1 = F.linear(x1, w1)
- Z2 = F.linear(x2, w2)
- Z = Z1 + Z2
- S = safe_divide(R, Z)
- C1 = x1 * self.grad(Z1, x1, S)[0]
- C2 = x2 * self.grad(Z2, x2, S)[0]
- return C1 + C2
- 
- activator_relevances = f(pw, nw, px, nx)
- inhibitor_relevances = f(nw, pw, px, nx)
- return alpha * activator_relevances - beta * inhibitor_relevances
-
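- # In formula form (a hedged reading of the code above): with z_ij^+ = x_i^+ w_ij^+ + x_i^- w_ij^-
- # and z_ij^- = x_i^+ w_ij^- + x_i^- w_ij^+, the rule computes
- #   R_i = sum_j [ alpha * z_ij^+ / sum_i' z_i'j^+  -  beta * z_ij^- / sum_i' z_i'j^- ] * R_j,
- # the usual alpha-beta decomposition into excitatory and inhibitory contributions.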
-
- @RPProvider.register(nn.Conv2d)
- class Conv2dRelProp(RelProp[nn.Conv2d]):
-
- def gradprop2(self, DY, weight):
- # Transposed convolution of DY with the given weights; output_padding is chosen so
- # the result matches the spatial size of the cached input.
- Z = self.module.forward(self.X)
- 
- output_padding = self.X.size()[2] - (
- (Z.size()[2] - 1) * self.module.stride[0] - 2 * self.module.padding[0] + self.module.kernel_size[0])
- 
- return F.conv_transpose2d(DY, weight, stride=self.module.stride, padding=self.module.padding, output_padding=output_padding)
-
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- if self.X.shape[1] == 3:
- # Input (RGB) layer: bounded rule using the per-image lower / upper activation
- # bounds L and H, so relevance respects the valid input range.
- pw = torch.clamp(self.module.weight, min=0)
- nw = torch.clamp(self.module.weight, max=0)
- X = self.X
- L = self.X * 0 + \
- torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
- keepdim=True)[0]
- H = self.X * 0 + \
- torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
- keepdim=True)[0]
- Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride, padding=self.module.padding) - \
- torch.conv2d(L, pw, bias=None, stride=self.module.stride, padding=self.module.padding) - \
- torch.conv2d(H, nw, bias=None, stride=self.module.stride,
- padding=self.module.padding) + 1e-9
-
- S = R / Za
- C = X * self.gradprop2(S, self.module.weight) - L * \
- self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
- R = C
- else:
- # Hidden layers: the same alpha-beta split as LinearRelProp, on positive and
- # negative weight / activation pairs.
- beta = alpha - 1
- pw = torch.clamp(self.module.weight, min=0)
- nw = torch.clamp(self.module.weight, max=0)
- px = torch.clamp(self.X, min=0)
- nx = torch.clamp(self.X, max=0)
- 
- def f(w1, w2, x1, x2):
- # Bias is left out: adding it to both Z1 and Z2 would count it twice and break
- # the decomposition Z1 + Z2 of the layer's output.
- Z1 = F.conv2d(x1, w1, bias=None, stride=self.module.stride,
- padding=self.module.padding, groups=self.module.groups)
- Z2 = F.conv2d(x2, w2, bias=None, stride=self.module.stride,
- padding=self.module.padding, groups=self.module.groups)
- Z = Z1 + Z2
- S = safe_divide(R, Z)
- C1 = x1 * self.grad(Z1, x1, S)[0]
- C2 = x2 * self.grad(Z2, x2, S)[0]
- return C1 + C2
-
- activator_relevances = f(pw, nw, px, nx)
- inhibitor_relevances = f(nw, pw, px, nx)
-
- R = alpha * activator_relevances - beta * inhibitor_relevances
- return R
-
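- # Hedged note on the two branches above: the 3-channel case treats the convolution as
- # the network's input layer and bounds each pixel by its per-image minimum L and
- # maximum H (a z^B-style rule); every other convolution reuses the alpha-beta split.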
-
- @RPProvider.register(nn.ConvTranspose2d)
- class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]):
- def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
- # z+ style rule: redistribute R through the positive weights and positive inputs only.
- pw = torch.clamp(self.module.weight, min=0)
- px = torch.clamp(self.X, min=0)
- 
- def f(w1, x1):
- Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride,
- padding=self.module.padding, output_padding=self.module.output_padding)
- S1 = safe_divide(R, Z1)
- C1 = x1 * self.grad(Z1, x1, S1)[0]
- return C1
- 
- return f(pw, px)