You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lap_inception.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. from typing import List, Optional, Callable, Any, Iterable, Dict
  2. import torch
  3. from torch import Tensor, nn
  4. import torchvision
  5. from ..modules.lap import LAP
  6. from ..interpreting.attention_interpreter import AttentionInterpretableModel
  7. from ..interpreting.interpretable import CamInterpretableModel
  8. from ..interpreting.relcam.relprop import RPProvider, RelProp
  9. from ..interpreting.relcam import modules as M
  class InceptionB(nn.Module):
      """Inception block made of two convolutional branches plus the raw input.

      The 3x3 branch (384 channels), the double-3x3 branch (96 channels) and
      the unmodified input are concatenated along dim 1 and then fed to the
      module produced by ``pool_factory``.

      Args:
          in_channels: Channel count of the input tensor.
          pool_factory: Callable invoked with the concatenated channel count
              (``384 + 96 + in_channels``); its result is kept as ``self.pool``.
          conv_block: Factory used for every convolution sub-block;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          pool_factory,
          conv_block: Optional[Callable[..., nn.Module]] = None
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block

          # Both branch-final 3x3 convolutions use stride=1 with padding=1
          # rather than a stride-2 convolution; any spatial reduction is
          # presumably left to self.pool -- TODO confirm against pool_factory.
          self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=1, padding=1)

          self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
          self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
          self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=1, padding=1)

          self.pool = pool_factory(384 + 96 + in_channels)
          self.cat = M.Cat()

      def _forward(self, x: Tensor) -> List[Tensor]:
          """Return ``[3x3 branch, double-3x3 branch, x]`` without merging them."""
          single = self.branch3x3(x)
          double = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
          return [single, double, x]

      def forward(self, x: Tensor) -> Tensor:
          """Concatenate the branch outputs on dim 1 and apply ``self.pool``."""
          merged = self.cat(self._forward(x), 1)
          return self.pool(merged)
  @RPProvider.register(InceptionB)
  class InceptionBRelProp(RelProp[InceptionB]):
      """Relevance propagation through an ``InceptionB`` block.

      Sub-modules are visited in the reverse of their forward order; the
      relevance reaching the block input through each of the three branches
      is summed.
      """

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` back to the block input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          block = self.module

          def back(layer, relevance):
              return RPProvider.get(layer)(relevance, alpha=alpha)

          R = back(block.pool, R)
          # The concatenation splits relevance into one part per input,
          # in forward order: 3x3 branch, double-3x3 branch, identity.
          rel_single, rel_double, rel_identity = back(block.cat, R)

          for layer_name in ('branch3x3dbl_3', 'branch3x3dbl_2', 'branch3x3dbl_1'):
              rel_double = back(getattr(block, layer_name), rel_double)
          rel_single = back(block.branch3x3, rel_single)

          return rel_identity + rel_double + rel_single
  class InceptionD(nn.Module):
      """Inception block with a 3x3 branch, a 7x7x3 branch and the raw input.

      The 3x3 branch (320 channels), the 7x7x3 branch (192 channels) and the
      unmodified input are concatenated along dim 1 and then fed to the
      module produced by ``pool_factory``.

      Args:
          in_channels: Channel count of the input tensor.
          pool_factory: Callable invoked with the concatenated channel count
              (``320 + 192 + in_channels``); its result is kept as ``self.pool``.
          conv_block: Factory used for every convolution sub-block;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          pool_factory,
          conv_block: Optional[Callable[..., nn.Module]] = None
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block

          # The branch-final 3x3 convolutions use stride=1 with padding=1
          # rather than a stride-2 convolution; any spatial reduction is
          # presumably left to self.pool -- TODO confirm against pool_factory.
          self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
          self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=1, padding=1)

          self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
          self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
          self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
          self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=1, padding=1)

          self.pool = pool_factory(320 + 192 + in_channels)
          self.cat = M.Cat()

      def _forward(self, x: Tensor) -> List[Tensor]:
          """Return ``[3x3 branch, 7x7x3 branch, x]`` without merging them."""
          short = self.branch3x3_2(self.branch3x3_1(x))

          deep = self.branch7x7x3_1(x)
          for stage in (self.branch7x7x3_2, self.branch7x7x3_3, self.branch7x7x3_4):
              deep = stage(deep)

          return [short, deep, x]

      def forward(self, x: Tensor) -> Tensor:
          """Concatenate the branch outputs on dim 1 and apply ``self.pool``."""
          merged = self.cat(self._forward(x), 1)
          return self.pool(merged)
  @RPProvider.register(InceptionD)
  class InceptionDRelProp(RelProp[InceptionD]):
      """Relevance propagation through an ``InceptionD`` block.

      Sub-modules are visited in the reverse of their forward order; the
      relevance reaching the block input through each of the three branches
      is summed.
      """

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` back to the block input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          block = self.module

          def back(layer, relevance):
              return RPProvider.get(layer)(relevance, alpha=alpha)

          R = back(block.pool, R)
          # One relevance part per concatenated input, in forward order:
          # 3x3 branch, 7x7x3 branch, identity.
          rel_short, rel_deep, rel_identity = back(block.cat, R)

          for layer_name in ('branch7x7x3_4', 'branch7x7x3_3',
                             'branch7x7x3_2', 'branch7x7x3_1'):
              rel_deep = back(getattr(block, layer_name), rel_deep)

          for layer_name in ('branch3x3_2', 'branch3x3_1'):
              rel_short = back(getattr(block, layer_name), rel_short)

          return rel_identity + rel_deep + rel_short
  def inception_b_maker(pool_factory):
      """Return an ``InceptionB`` factory with ``pool_factory`` already bound.

      The returned callable takes ``(in_channels, conv_block=None)`` and
      passes them, together with ``pool_factory``, to ``InceptionB``.
      """
      def build(in_channels, conv_block=None):
          return InceptionB(in_channels, pool_factory, conv_block)

      return build
  def inception_d_maker(pool_factory):
      """Return an ``InceptionD`` factory with ``pool_factory`` already bound.

      The returned callable takes ``(in_channels, conv_block=None)`` and
      passes them, together with ``pool_factory``, to ``InceptionD``.
      """
      def build(in_channels, conv_block=None):
          return InceptionD(in_channels, pool_factory, conv_block)

      return build
  class InceptionA(nn.Module):
      """Inception block with 1x1, 5x5, double-3x3 and pooled branches.

      The four branch outputs (64, 64, 96 and ``pool_features`` channels)
      are concatenated along dim 1.

      Args:
          in_channels: Channel count of the input tensor.
          pool_features: Output channels of the 1x1 convolution that follows
              the average pooling in the pooled branch.
          conv_block: Factory used for every convolution sub-block;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          pool_features: int,
          conv_block: Optional[Callable[..., nn.Module]] = None,
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block

          self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

          self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
          self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

          self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
          self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
          self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

          self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
          self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

          self.cat = M.Cat()

      def _forward(self, x: Tensor) -> List[Tensor]:
          """Return the four branch outputs in concatenation order."""
          out_1x1 = self.branch1x1(x)
          out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
          out_dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
          out_pool = self.branch_pool(self.avg_pool(x))
          return [out_1x1, out_5x5, out_dbl, out_pool]

      def forward(self, x: Tensor) -> Tensor:
          """Concatenate the branch outputs along dim 1."""
          return self.cat(self._forward(x), 1)
  @RPProvider.register(InceptionA)
  class InceptionARelProp(RelProp[InceptionA]):
      """Relevance propagation through an ``InceptionA`` block.

      Each branch is walked in the reverse of its forward order and the four
      relevance maps reaching the block input are summed.
      """

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` back to the block input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          block = self.module

          def back(layer, relevance):
              return RPProvider.get(layer)(relevance, alpha=alpha)

          def back_through(layer_names, relevance):
              for layer_name in layer_names:
                  relevance = back(getattr(block, layer_name), relevance)
              return relevance

          # One relevance part per concatenated branch, in forward order.
          rel_1x1, rel_5x5, rel_dbl, rel_pool = back(block.cat, R)

          from_pool = back_through(('branch_pool', 'avg_pool'), rel_pool)
          from_dbl = back_through(
              ('branch3x3dbl_3', 'branch3x3dbl_2', 'branch3x3dbl_1'), rel_dbl)
          from_5x5 = back_through(('branch5x5_2', 'branch5x5_1'), rel_5x5)
          from_1x1 = back(block.branch1x1, rel_1x1)

          return from_pool + from_dbl + from_5x5 + from_1x1
  class InceptionC(nn.Module):
      """Inception block with 1x1, 7x7, double-7x7 and pooled branches.

      The 7x7 kernels are factorised into (1, 7) and (7, 1) convolutions.
      Every branch ends with 192 channels; the four outputs are concatenated
      along dim 1.

      Args:
          in_channels: Channel count of the input tensor.
          channels_7x7: Width of the intermediate layers in both 7x7 branches.
          conv_block: Factory used for every convolution sub-block;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          channels_7x7: int,
          conv_block: Optional[Callable[..., nn.Module]] = None
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block
          width = channels_7x7

          # Keyword sets for the two factorised 7x7 orientations.
          horizontal = dict(kernel_size=(1, 7), padding=(0, 3))
          vertical = dict(kernel_size=(7, 1), padding=(3, 0))

          self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

          self.branch7x7_1 = make_conv(in_channels, width, kernel_size=1)
          self.branch7x7_2 = make_conv(width, width, **horizontal)
          self.branch7x7_3 = make_conv(width, 192, **vertical)

          self.branch7x7dbl_1 = make_conv(in_channels, width, kernel_size=1)
          self.branch7x7dbl_2 = make_conv(width, width, **vertical)
          self.branch7x7dbl_3 = make_conv(width, width, **horizontal)
          self.branch7x7dbl_4 = make_conv(width, width, **vertical)
          self.branch7x7dbl_5 = make_conv(width, 192, **horizontal)

          self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
          self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

          self.cat = M.Cat()

      def _forward(self, x: Tensor) -> List[Tensor]:
          """Return the four branch outputs in concatenation order."""
          out_1x1 = self.branch1x1(x)

          out_7x7 = self.branch7x7_1(x)
          for stage in (self.branch7x7_2, self.branch7x7_3):
              out_7x7 = stage(out_7x7)

          out_dbl = self.branch7x7dbl_1(x)
          for stage in (self.branch7x7dbl_2, self.branch7x7dbl_3,
                        self.branch7x7dbl_4, self.branch7x7dbl_5):
              out_dbl = stage(out_dbl)

          out_pool = self.branch_pool(self.avg_pool(x))

          return [out_1x1, out_7x7, out_dbl, out_pool]

      def forward(self, x: Tensor) -> Tensor:
          """Concatenate the branch outputs along dim 1."""
          return self.cat(self._forward(x), 1)
  @RPProvider.register(InceptionC)
  class InceptionCRelProp(RelProp[InceptionC]):
      """Relevance propagation through an ``InceptionC`` block.

      Each branch is walked in the reverse of its forward order and the four
      relevance maps reaching the block input are summed.
      """

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` back to the block input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          block = self.module

          def back_through(layer_names, relevance):
              for layer_name in layer_names:
                  layer = getattr(block, layer_name)
                  relevance = RPProvider.get(layer)(relevance, alpha=alpha)
              return relevance

          # One relevance part per concatenated branch, in forward order.
          rel_1x1, rel_7x7, rel_dbl, rel_pool = back_through(('cat',), R)

          from_pool = back_through(('branch_pool', 'avg_pool'), rel_pool)
          from_dbl = back_through(
              ('branch7x7dbl_5', 'branch7x7dbl_4', 'branch7x7dbl_3',
               'branch7x7dbl_2', 'branch7x7dbl_1'),
              rel_dbl)
          from_7x7 = back_through(
              ('branch7x7_3', 'branch7x7_2', 'branch7x7_1'), rel_7x7)
          from_1x1 = back_through(('branch1x1',), rel_1x1)

          return from_pool + from_dbl + from_7x7 + from_1x1
  class InceptionE(nn.Module):
      """Inception block whose 3x3 branches fan out into (1, 3) / (3, 1) pairs.

      Branches: a 1x1 convolution (320 channels), a 3x3 branch and a
      double-3x3 branch that each end in two parallel convolutions whose
      outputs are concatenated (2 x 384 channels), and an average-pooled 1x1
      branch (192 channels). The four results are concatenated along dim 1.

      Args:
          in_channels: Channel count of the input tensor.
          conv_block: Factory used for every convolution sub-block;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          conv_block: Optional[Callable[..., nn.Module]] = None
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block

          self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

          self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
          self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
          self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

          self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
          self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
          self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
          self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

          self.avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
          self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

          # cat1 / cat2 merge the paired heads of the 3x3 and double-3x3
          # branches; cat3 merges the four branch outputs.
          self.cat1 = M.Cat()
          self.cat2 = M.Cat()
          self.cat3 = M.Cat()

      def _forward(self, x: Tensor) -> List[Tensor]:
          """Return the four branch outputs in concatenation order."""
          out_1x1 = self.branch1x1(x)

          stem_3x3 = self.branch3x3_1(x)
          out_3x3 = self.cat1(
              [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

          stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
          out_dbl = self.cat2(
              [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

          out_pool = self.branch_pool(self.avg_pool(x))

          return [out_1x1, out_3x3, out_dbl, out_pool]

      def forward(self, x: Tensor) -> Tensor:
          """Concatenate the branch outputs along dim 1."""
          return self.cat3(self._forward(x), 1)
  @RPProvider.register(InceptionE)
  class InceptionERelProp(RelProp[InceptionE]):
      """Relevance propagation through an ``InceptionE`` block.

      Each branch is walked in the reverse of its forward order. Where a
      branch fans out into two parallel heads, the relevance of both heads
      is added before continuing down the shared stem. The four relevance
      maps reaching the block input are summed.
      """

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` back to the block input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          block = self.module

          def back(layer, relevance):
              return RPProvider.get(layer)(relevance, alpha=alpha)

          # One relevance part per concatenated branch, in forward order.
          rel_1x1, rel_3x3, rel_dbl, rel_pool = back(block.cat3, R)

          # Pooled branch.
          from_pool = back(block.avg_pool, back(block.branch_pool, rel_pool))

          # Double-3x3 branch: split over the two heads, merge, walk the stem.
          head_a, head_b = back(block.cat2, rel_dbl)
          head_a = back(block.branch3x3dbl_3a, head_a)
          head_b = back(block.branch3x3dbl_3b, head_b)
          stem = head_a + head_b
          stem = back(block.branch3x3dbl_2, stem)
          from_dbl = back(block.branch3x3dbl_1, stem)

          # 3x3 branch: same pattern with a single stem layer.
          head_a, head_b = back(block.cat1, rel_3x3)
          head_a = back(block.branch3x3_2a, head_a)
          head_b = back(block.branch3x3_2b, head_b)
          from_3x3 = back(block.branch3x3_1, head_a + head_b)

          # 1x1 branch.
          from_1x1 = back(block.branch1x1, rel_1x1)

          return from_pool + from_dbl + from_3x3 + from_1x1
  class InceptionAux(nn.Module):
      """Auxiliary classification head.

      Pipeline: 5x5 / stride-3 average pooling, a 1x1 convolution to 128
      channels, a 5x5 convolution (padding 2) to 768 channels, global average
      pooling to 1x1, flattening, and a linear layer to ``num_classes``.

      Args:
          in_channels: Channel count of the input tensor.
          num_classes: Output size of the final linear layer.
          conv_block: Factory used for the two convolution sub-blocks;
              ``BasicConv2d`` is used when ``None``.
      """

      def __init__(
          self,
          in_channels: int,
          num_classes: int,
          conv_block: Optional[Callable[..., nn.Module]] = None
      ) -> None:
          super().__init__()
          make_conv = BasicConv2d if conv_block is None else conv_block

          self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3)
          self.conv0 = make_conv(in_channels, 128, kernel_size=1)
          self.conv1 = make_conv(128, 768, kernel_size=5, padding=2)
          # Plain attribute attached to the module; not read anywhere in this
          # class (presumably consumed by an external weight initialiser).
          self.conv1.stddev = 0.01  # type: ignore[assignment]
          self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
          self.flatten = nn.Flatten()
          self.fc = nn.Linear(768, num_classes)
          self.fc.stddev = 0.001  # type: ignore[assignment]

      def forward(self, x: Tensor) -> Tensor:
          """Run the head and return a tensor of shape ``(N, num_classes)``."""
          stages = (
              self.avgpool1,  # 5x5 window, stride 3
              self.conv0,     # -> 128 channels, same spatial size
              self.conv1,     # -> 768 channels; padding=2 keeps the spatial size
              self.avgpool2,  # -> N x 768 x 1 x 1
              self.flatten,   # -> N x 768
              self.fc,        # -> N x num_classes
          )
          for stage in stages:
              x = stage(x)
          return x
  @RPProvider.register(InceptionAux)
  class InceptionAuxRelProp(RelProp[InceptionAux]):
      """Relevance propagation through the ``InceptionAux`` head."""

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` through the head's layers in reverse forward order.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          head = self.module
          reversed_stages = ('fc', 'flatten', 'avgpool2', 'conv1', 'conv0', 'avgpool1')
          for stage_name in reversed_stages:
              R = RPProvider.get(getattr(head, stage_name))(R, alpha=alpha)
          return R
  class BasicConv2d(nn.Module):
      """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and ReLU.

      Args:
          in_channels: Input channels of the convolution.
          out_channels: Output channels of the convolution and width of the
              batch-norm layer.
          **kwargs: Passed straight to ``nn.Conv2d`` (kernel size, stride,
              padding, ...).
      """

      def __init__(
          self,
          in_channels: int,
          out_channels: int,
          **kwargs: Any
      ) -> None:
          super().__init__()
          # No bias in the convolution: it is followed directly by batch norm.
          self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
          self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
          self.act = torch.nn.ReLU()

      def forward(self, x: Tensor) -> Tensor:
          """Apply convolution, batch normalisation and ReLU, in that order."""
          return self.act(self.bn(self.conv(x)))
  @RPProvider.register(BasicConv2d)
  class BasicConv2dRelProp(RelProp[BasicConv2d]):
      """Relevance propagation through a ``BasicConv2d`` block."""

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` through activation, batch norm and convolution.

          The layers are visited in the reverse of their forward order;
          ``alpha`` is forwarded unchanged to every propagator obtained from
          ``RPProvider.get``.
          """
          block = self.module
          for layer_name in ('act', 'bn', 'conv'):
              R = RPProvider.get(getattr(block, layer_name))(R, alpha=alpha)
          return R
  class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3):
      """Inception3 variant whose pooling layers come from caller-supplied factories.

      The torchvision ``Inception3`` constructor is called with this file's
      block classes. The stem convolutions and the two stem pooling layers
      are then replaced, and both the main and the auxiliary classifier are
      replaced by a single-unit linear layer followed by a sigmoid.

      Args:
          aux_weight: Stored as ``self.aux_weight``; not otherwise used in
              this class.
          n_classes: Accepted for interface compatibility but not read here;
              both heads always have one output unit.
          pool_factory: Callable taking a channel count and returning the
              pooling module for that position (stem pools and the pools
              inside the ``InceptionB`` / ``InceptionD`` blocks).
          adaptive_pool_factory: Optional callable taking a channel count
              (2048) and returning the final pooling module. When ``None``
              the ``avgpool`` created by torchvision is left untouched.
      """

      def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory):
          torchvision.models.Inception3.__init__(
              self,
              transform_input=False, init_weights=False,
              inception_blocks=[
                  BasicConv2d, InceptionA, inception_b_maker(pool_factory),
                  InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux
              ])

          # Stem: every 3x3 convolution is padded, and the two pooling layers
          # are built by pool_factory.
          self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1)
          self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
          self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
          self.maxpool1 = pool_factory(64)
          self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1,)
          self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1)
          self.maxpool2 = pool_factory(192)

          if adaptive_pool_factory is not None:
              # Wrapped in a Sequential so the pool is reachable as avgpool[0].
              self.avgpool = nn.Sequential(
                  adaptive_pool_factory(2048))

          # Binary heads: one sigmoid unit each.
          self.fc = nn.Sequential(
              nn.Linear(2048, 1),
              nn.Sigmoid()
          )
          self.AuxLogits.fc = nn.Sequential(
              nn.Linear(768, 1),
              nn.Sigmoid()
          )

          self.aux_weight = aux_weight

      @property
      def target_conv_layers(self) -> List[nn.Module]:
          """The six branch-final convolution blocks of ``Mixed_7c``."""
          return [
              self.Mixed_7c.branch1x1,
              self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
              self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
              self.Mixed_7c.branch_pool
          ]

      @property
      def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
          """Names of the inputs to interpret; only ``'x'``."""
          return ['x']

      def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
          """Return ``[1 - p, p]`` stacked on dim 1.

          ``p`` is the ``'positive_class_probability'`` entry of the mapping
          returned by ``self.forward``.
          """
          p = self.forward(*inputs, **kwargs)['positive_class_probability']
          return torch.stack([1 - p, p], dim=1)

      @property
      def attention_layers(self) -> Dict[str, List[LAP]]:
          """Pooling layers grouped by position, plus a group with all of them.

          The final pool is included (as ``'3_avgpool'`` and in ``'4_all'``)
          only when ``__init__`` replaced it, i.e. when ``self.avgpool`` is the
          ``nn.Sequential`` wrapper. Without an ``adaptive_pool_factory`` the
          torchvision pooling module is not indexable, and indexing it would
          raise ``TypeError``.
          """
          attention_groups = {
              '0_layer2': [self.maxpool2],
              '1_layer6': [self.Mixed_6a.pool],
              '2_layer7': [self.Mixed_7a.pool],
          }
          all_layers = [
              self.maxpool2,
              self.Mixed_6a.pool,
              self.Mixed_7a.pool,
          ]

          if isinstance(self.avgpool, nn.Sequential):
              head_pool = self.avgpool[0]
              attention_groups['3_avgpool'] = [head_pool]
              all_layers.append(head_pool)

          attention_groups['4_all'] = all_layers
          return attention_groups
  @RPProvider.register(LAPInception)
  class Inception3Mo4RelProp(RelProp[LAPInception]):
      """Relevance propagation through the whole ``LAPInception`` network."""

      # Layers visited after the classifier, in reverse forward order.
      _BACKWARD_ORDER = (
          'dropout', 'avgpool',
          'Mixed_7c', 'Mixed_7b', 'Mixed_7a',
          'Mixed_6e', 'Mixed_6d', 'Mixed_6c', 'Mixed_6b', 'Mixed_6a',
          'Mixed_5d', 'Mixed_5c', 'Mixed_5b',
          'maxpool2', 'Conv2d_4a_3x3', 'Conv2d_3b_1x1',
          'maxpool1', 'Conv2d_2b_3x3', 'Conv2d_2a_3x3', 'Conv2d_1a_3x3',
      )

      def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
          """Propagate ``R`` from the classifier output back to the input.

          ``alpha`` is forwarded unchanged to every sub-module propagator
          obtained from ``RPProvider.get``.
          """
          net = self.module

          # When the classifier's recorded ``Y`` has a single column, only the
          # last column of R is propagated (presumably the positive class of a
          # [negative, positive] pair -- TODO confirm against the caller).
          if RPProvider.get(net.fc).Y.shape[1] == 1:
              R = R[:, -1:]

          R = RPProvider.get(net.fc)(R, alpha=alpha)
          # The classifier works on flattened features; restore the shape of
          # the dropout layer's recorded ``Y`` before continuing.
          R = R.reshape_as(RPProvider.get(net.dropout).Y)

          for layer_name in self._BACKWARD_ORDER:
              R = RPProvider.get(getattr(net, layer_name))(R, alpha=alpha)
          return R