
tv_inception.py 8.4KB

from typing import Iterable, Optional, Callable, List

import torch
from torch import nn, Tensor
import torchvision

from ..interpreting.interpretable import CamInterpretableModel
from ..interpreting.relcam.relprop import RPProvider, RelProp
from ..interpreting.relcam import modules as M
from .lap_inception import (
    InceptionA,
    InceptionC,
    InceptionE,
    BasicConv2d,
    InceptionAux as BaseInceptionAux,
)


class InceptionB(nn.Module):
    """Grid-reduction block from torchvision's Inception v3 (35x35 -> 17x17).

    The channel concatenation is an explicit ``M.Cat`` module so relevance can
    be routed back through it during relevance propagation.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.maxpool(x)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat(outputs, 1)


@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):
    """Splits the relevance arriving at the concatenation, pushes it back
    through each branch of ``InceptionB`` in reverse order, and sums the
    per-branch relevance at the block input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha)

        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
        x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

        x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha)

        return x1 + x2 + x3


class InceptionD(nn.Module):
    """Grid-reduction block from torchvision's Inception v3 (17x17 -> 8x8).

    As in ``InceptionB``, the concatenation is an explicit ``M.Cat`` module so
    relevance can be routed back through it.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = self.maxpool(x)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat(outputs, 1)


@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):
    """Relevance propagation for ``InceptionD``, mirroring ``InceptionBRelProp``."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        branch3x3, branch7x7x3, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        x1 = RPProvider.get(self.module.maxpool)(branch_pool, alpha=alpha)

        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha)
        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha)
        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha)
        x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha)

        branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha)
        x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

        return x1 + x2 + x3


class InceptionAux(BaseInceptionAux):
    """Auxiliary classifier head.

    Rebuilds ``conv1`` as a 5x5 convolution mapping 128 -> 768 channels; the
    ``stddev`` attribute follows torchvision's convention for truncated-normal
    weight initialization.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__(in_channels, num_classes, conv_block=conv_block)
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]


class Inception3(CamInterpretableModel, torchvision.models.Inception3):
    """Torchvision Inception v3 backbone exposed as a CAM-interpretable model,
    with sigmoid main and auxiliary heads of ``n_classes`` units each."""

    def __init__(self, aux_weight: float, n_classes=1):
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False, init_weights=False,
            inception_blocks=[
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ])
        self.fc = nn.Sequential(
            nn.Linear(2048, n_classes),
            nn.Sigmoid()
        )
        self.AuxLogits.fc = nn.Sequential(
            nn.Linear(768, n_classes),
            nn.Sigmoid()
        )
        self.aux_weight = aux_weight

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        return [
            self.Mixed_7c.branch1x1,
            self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
            self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
            self.Mixed_7c.branch_pool
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        # Turn the single positive-class probability into a two-column
        # [negative, positive] distribution.
        p = self.forward(*inputs, **kwargs)['positive_class_probability']
        return torch.stack([1 - p, p], dim=1)


@RPProvider.register(Inception3)
class Inception3Mo4RelProp(RelProp[Inception3]):
    """Propagates relevance from the classifier head back to the input image,
    walking the Inception v3 layers in reverse order."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        if RPProvider.get(self.module.fc).Y.shape[1] == 1:
            # Single-unit sigmoid head: keep only the positive-class column.
            R = R[:, -1:]
        R = RPProvider.get(self.module.fc)(R, alpha=alpha)             # B 2048
        R = R.reshape_as(RPProvider.get(self.module.dropout).Y)        # B 2048 1 1
        R = RPProvider.get(self.module.dropout)(R, alpha=alpha)        # B 2048 1 1
        R = RPProvider.get(self.module.avgpool)(R, alpha=alpha)        # B 2048 8 8
        R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha)       # B 2048 8 8
        R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha)       # B 1280 8 8
        R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha)       # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha)       # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha)       # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha)       # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha)       # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha)       # B 288 35 35
        R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha)       # B 288 35 35
        R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha)       # B 256 35 35
        R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha)       # B 192 35 35
        R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha)       # B 192 71 71
        R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha)  # B 80 73 73
        R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha)  # B 64 73 73
        R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha)       # B 64 147 147
        R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha)  # B 32 147 147
        R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha)  # B 32 149 149
        R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha)  # B 3 299 299
        return R
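

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes this file is imported as part of its package, so the relative
# imports above resolve, and that the installed torchvision accepts the
# seven-element `inception_blocks` argument used in `Inception3.__init__`.
# The `aux_weight` value below is an arbitrary placeholder.
#
#     model = Inception3(aux_weight=0.4, n_classes=1)
#     model.eval()
#
#     x = torch.randn(2, 3, 299, 299)   # Inception v3 expects 299x299 inputs
#     with torch.no_grad():
#         out = torchvision.models.Inception3.forward(model, x)  # (2, 1) sigmoid outputs
#
# `get_categorical_probabilities` additionally assumes that the surrounding
# framework wraps `forward` so it returns a dict containing the key
# 'positive_class_probability'.
# ---------------------------------------------------------------------------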