relprop.py
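"""Relevance-propagation (LRP-style) rules for common torch modules.

Each ``RelProp`` subclass is registered with ``RPProvider`` for one or more
module types and implements ``rel``, which redistributes relevance ``R`` from a
module's output back to its input(s).
"""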
from contextlib import ExitStack, contextmanager
from typing import Callable, Dict, Generic, Iterator, List, Tuple, Type, TypeVar
from uuid import UUID, uuid4

import torch
import torch.nn.functional as F
from torch import nn
import torchvision

from . import modules as M

TModule = TypeVar('TModule', bound=nn.Module)
TType = TypeVar('TType', bound=type)


def safe_divide(a, b):
    # Divide a by b, adding a small epsilon where b == 0 and zeroing the
    # result at those positions, so no division by zero can occur.
    return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())
class classdict(Dict[type, TType], Generic[TType]):
    """Dict keyed by classes; lookups fall back to the nearest base class in the MRO."""

    def __nearest_available_resolution(self, __k: Type[TModule]) -> Type[TModule]:
        return next((r for r in __k.mro() if dict.__contains__(self, r)), None)

    def __contains__(self, __k: Type[TModule]) -> bool:
        return self.__nearest_available_resolution(__k) is not None

    def __getitem__(self, __k: Type[TModule]) -> TType:
        r = self.__nearest_available_resolution(__k)
        if r is None:
            raise KeyError(f"{__k} not found!")
        return super().__getitem__(r)
class RPProvider:
    """Registry mapping module types to their relevance-propagation rules."""

    __props: Dict[Type[TModule], Type['RelProp[TModule]']] = classdict()
    __instances: Dict[TModule, 'RelProp[TModule]'] = {}
    __target_results: Dict[nn.Module, torch.Tensor] = {}

    @classmethod
    def register(cls, *module_types: Type[TModule]):
        # Class decorator: register a RelProp subclass for the given module types.
        def decorator(prop_cls):
            cls.__props.update({
                module_type: prop_cls
                for module_type in module_types
            })
            return prop_cls
        return decorator

    @classmethod
    def create(cls, module: TModule):
        cls.__instances[module] = prop = cls.__props[type(module)](module)
        return prop

    @classmethod
    def __hook(cls, prop: 'RelProp[TModule]', R: torch.Tensor) -> None:
        # Turn a target layer's relevance map into a CAM-like heatmap:
        # spatial mean as channel weights, then a weighted sum over channels.
        r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
        r_cam = prop.X * r_weight
        r_cam = torch.sum(r_cam, dim=1, keepdim=True)
        cls.__target_results[prop.module] = r_cam

    @classmethod
    @contextmanager
    def capture(cls, target_layers: List[nn.Module]) -> Iterator[Dict[nn.Module, torch.Tensor]]:
        # Hook every created prop for the duration of the context and collect
        # heatmaps for the requested target layers.
        with ExitStack() as stack:
            cls.__target_results.clear()
            for prop in cls.__instances.values():
                stack.enter_context(prop.hook_module())
            for target in target_layers:
                stack.enter_context(cls.get(target).register_hook(cls.__hook))
            yield cls.__target_results

    @classmethod
    def propable(cls, module: TModule) -> bool:
        return type(module) in cls.__props

    @classmethod
    def get(cls, module: TModule) -> 'RelProp[TModule]':
        return cls.__instances[module]
@RPProvider.register(nn.Identity, nn.ReLU, nn.LeakyReLU, nn.Dropout, nn.Sigmoid)
class RelProp(Generic[TModule]):
    """Base rule: pass relevance through unchanged (identity propagation)."""

    Hook = Callable[['RelProp', torch.Tensor], None]

    def __forward_hook(self, _, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
        # Cache detached, grad-enabled inputs and the raw output of the
        # wrapped module during the forward pass.
        if type(input[0]) in (list, tuple):
            self.X = []
            for i in input[0]:
                x = i.detach()
                x.requires_grad = True
                self.X.append(x)
        else:
            self.X = input[0].detach()
            self.X.requires_grad = True
        self.Y = output

    def __init__(self, module: TModule) -> None:
        self.module = module
        self.__hooks: Dict[UUID, RelProp.Hook] = {}

    @contextmanager
    def hook_module(self):
        handle = self.module.register_forward_hook(self.__forward_hook)
        try:
            yield
        finally:
            handle.remove()

    @contextmanager
    def register_hook(self, hook: 'RelProp.Hook'):
        uuid = uuid4()
        self.__hooks[uuid] = hook
        try:
            yield
        finally:
            self.__hooks.pop(uuid)

    def grad(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def __call__(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = self.rel(R, alpha=alpha)
        for hook in self.__hooks.values():
            hook(self, R)
        return R

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        return R
@RPProvider.register(nn.MaxPool2d, nn.AvgPool2d, M.Add, torchvision.transforms.transforms.Normalize)
class RelPropSimple(RelProp[TModule]):
    """Generic rule: redistribute relevance in proportion to each input's contribution."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)
        if not torch.is_tensor(self.X):
            # Multi-input modules (e.g. Add) cache a list of inputs.
            outputs = [self.X[0] * C[0], self.X[1] * C[1]]
        else:
            outputs = self.X * C[0]
        return outputs
@RPProvider.register(nn.Flatten)
class RelPropFlatten(RelProp[TModule]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Undo the flattening: reshape relevance back to the input's shape.
        return R.reshape_as(self.X)


@RPProvider.register(nn.AdaptiveAvgPool2d)
class AdaptiveAvgPool2dRelProp(RelProp[nn.AdaptiveAvgPool2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        px = torch.clamp(self.X, min=0)

        def f(x1):
            Z1 = F.adaptive_avg_pool2d(x1, self.module.output_size)
            S1 = safe_divide(R, Z1)
            C1 = x1 * self.grad(Z1, x1, S1)[0]
            return C1

        activator_relevances = f(px)
        out = activator_relevances
        return out
@RPProvider.register(nn.ZeroPad2d)
class ZeroPad2dRelProp(RelProp[nn.ZeroPad2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)
        outputs = self.X * C[0]
        return outputs


@RPProvider.register(M.Multiply)
class MultiplyRelProp(RelProp[M.Multiply]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Only the positive parts of the two operands are considered.
        x0 = torch.clamp(self.X[0], min=0)
        x1 = torch.clamp(self.X[1], min=0)
        x = [x0, x1]
        Z = self.module.forward(x)
        S = safe_divide(R, Z)
        C = self.grad(Z, x, S)
        outputs = [x[0] * C[0], x[1] * C[1]]
        return outputs
@RPProvider.register(M.Cat)
class CatRelProp(RelProp[M.Cat]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        Z = self.module.forward(self.X, self.module.dim)
        S = safe_divide(R, Z)
        C = self.grad(Z, self.X, S)
        outputs = []
        for x, c in zip(self.X, C):
            outputs.append(x * c)
        return outputs


@RPProvider.register(nn.Sequential)
class SequentialRelProp(RelProp[nn.Sequential]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Propagate relevance backwards through the children, last to first.
        for m in reversed(self.module):
            R = RPProvider.get(m)(R, alpha=alpha)
        return R
@RPProvider.register(nn.BatchNorm2d)
class BatchNorm2dRelProp(RelProp[nn.BatchNorm2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        X = self.X
        beta = 1 - alpha
        weight = self.module.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (self.module.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.module.eps).pow(0.5))
        Z = X * weight + 1e-9
        S = R / Z
        Ca = S * weight
        R = self.X * Ca
        return R
@RPProvider.register(nn.Linear)
class LinearRelProp(RelProp[nn.Linear]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Alpha-beta rule: split weights and inputs into positive and negative parts.
        beta = alpha - 1
        pw = torch.clamp(self.module.weight, min=0)
        nw = torch.clamp(self.module.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            Z = Z1 + Z2
            S = safe_divide(R, Z)
            C1 = x1 * self.grad(Z1, x1, S)[0]
            C2 = x2 * self.grad(Z2, x2, S)[0]
            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)
        out = alpha * activator_relevances - beta * inhibitor_relevances
        return out
@RPProvider.register(nn.Conv2d)
class Conv2dRelProp(RelProp[nn.Conv2d]):
    def gradprop2(self, DY, weight):
        Z = self.module.forward(self.X)
        output_padding = self.X.size()[2] - (
            (Z.size()[2] - 1) * self.module.stride[0] - 2 * self.module.padding[0] + self.module.kernel_size[0])
        return F.conv_transpose2d(DY, weight, stride=self.module.stride,
                                  padding=self.module.padding, output_padding=output_padding)

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        if self.X.shape[1] == 3:
            # First (RGB input) layer: bound the input by its per-sample min (L) and max (H).
            pw = torch.clamp(self.module.weight, min=0)
            nw = torch.clamp(self.module.weight, max=0)
            X = self.X
            L = self.X * 0 + \
                torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
                          keepdim=True)[0]
            H = self.X * 0 + \
                torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
                          keepdim=True)[0]
            Za = torch.conv2d(X, self.module.weight, bias=None, stride=self.module.stride, padding=self.module.padding) - \
                torch.conv2d(L, pw, bias=None, stride=self.module.stride, padding=self.module.padding) - \
                torch.conv2d(H, nw, bias=None, stride=self.module.stride, padding=self.module.padding) + 1e-9
            S = R / Za
            C = X * self.gradprop2(S, self.module.weight) - L * \
                self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
            R = C
        else:
            # Other conv layers: alpha-beta rule, as in LinearRelProp.
            beta = alpha - 1
            pw = torch.clamp(self.module.weight, min=0)
            nw = torch.clamp(self.module.weight, max=0)
            px = torch.clamp(self.X, min=0)
            nx = torch.clamp(self.X, max=0)

            def f(w1, w2, x1, x2):
                Z1 = F.conv2d(x1, w1, bias=self.module.bias, stride=self.module.stride,
                              padding=self.module.padding, groups=self.module.groups)
                Z2 = F.conv2d(x2, w2, bias=self.module.bias, stride=self.module.stride,
                              padding=self.module.padding, groups=self.module.groups)
                Z = Z1 + Z2
                S = safe_divide(R, Z)
                C1 = x1 * self.grad(Z1, x1, S)[0]
                C2 = x2 * self.grad(Z2, x2, S)[0]
                return C1 + C2

            activator_relevances = f(pw, nw, px, nx)
            inhibitor_relevances = f(nw, pw, px, nx)
            R = alpha * activator_relevances - beta * inhibitor_relevances
        return R
@RPProvider.register(nn.ConvTranspose2d)
class ConvTranspose2dRelProp(RelProp[nn.ConvTranspose2d]):
    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        pw = torch.clamp(self.module.weight, min=0)
        px = torch.clamp(self.X, min=0)

        def f(w1, x1):
            Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.module.stride,
                                    padding=self.module.padding, output_padding=self.module.output_padding)
            S1 = safe_divide(R, Z1)
            C1 = x1 * self.grad(Z1, x1, S1)[0]
            return C1

        activator_relevances = f(pw, px)
        R = activator_relevances
        return R
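

# Minimal usage sketch (illustrative only; ``model``, ``target_layer``, ``batch``
# and ``class_scores_to_relevance`` are hypothetical names, not defined here, and
# the top-level ``model`` type is assumed to be registered with RPProvider):
#
#     for module in model.modules():
#         if RPProvider.propable(module):
#             RPProvider.create(module)
#
#     with RPProvider.capture([target_layer]) as heatmaps:
#         output = model(batch)
#         R = class_scores_to_relevance(output)   # e.g. a one-hot on the target class
#         RPProvider.get(model)(R, alpha=1)       # propagate relevance back through the model
#     cam = heatmaps[target_layer]                # CAM-like map collected by the hook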