
lap_resnet.py

from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
from typing_extensions import Protocol
import warnings

import torch
from torch import nn

from ..modules import LAP, AdaptiveLAP
from .tv_resnet import BasicBlock, Bottleneck, ResNet
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple
from ..interpreting.relcam import modules as M
from ..interpreting.attention_interpreter import AttentionInterpretableModel


class PoolFactory(Protocol):
    def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ...


lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
        LAP(channels, sigmoid_scale=sigmoid_scale)

adaptive_lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
        AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale)
class LapBasicBlock(BasicBlock):

    def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0):
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            # Rebuild the strided conv and downsample projection without stride;
            # spatial reduction is delegated to the LAP pool below.
            self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False)
            self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale)

        self.relu3 = nn.ReLU()
        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(x)   # B P H W
        out = self.bn1(out)   # B P H W
        out = self.relu(out)  # B P H W

        if self.downsample is not None:
            x = self.downsample(x)  # B P H W
            x = self.relu2(x)

        if self.pool is not None:
            # Concatenate the main and residual paths and downsample them with a single LAP.
            poolin = self.cat([out, x], 1)             # B 2P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 2P H/S W/S
            out, x = poolout.chunk(2, dim=1)           # B P H/S W/S

        out = self.conv2(out)  # B P H/S W/S
        out = self.bn2(out)    # B P H/S W/S

        out = self.add([out, x])  # B P H/S W/S
        out = self.relu3(out)     # B P H/S W/S

        return out
@RPProvider.register(LapBasicBlock)
class BlockRelProp(RelProp[LapBasicBlock]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Propagate relevance through the block in reverse order of forward().
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.relu2)(x, alpha=alpha)
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1
class LapBottleneck(Bottleneck):

    def __init__(self, pool_factory: PoolFactory,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 sigmoid_scale: float = 1.0) -> None:
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            # Rebuild the strided conv and downsample projection without stride;
            # spatial reduction is delegated to the LAP pool below.
            width = int(planes * (base_width / 64.)) * groups
            self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False)
            self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale)

        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x                                            # B I  H W
        out = self.conv1(x)                            # B P  H W
        out = self.bn1(out)                            # B P  H W
        out = self.relu(out)                           # B P  H W

        out = self.conv2(out)                          # B P  H W
        out = self.bn2(out)                            # B P  H W
        out = self.relu2(out)                          # B P  H W

        if self.downsample is not None:
            x = self.downsample(x)                     # B 4P H W

        if self.pool is not None:
            # Concatenate the main and residual paths and downsample them with a single LAP.
            poolin = self.cat([out, x], 1)             # B 5P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 5P H/S W/S
            out, x = poolout.split(                    # out: B P  H/S W/S
                [out.shape[1], x.shape[1]],            # x:   B 4P H/S W/S
                dim=1)

        out = self.conv3(out)     # B 4P H/S W/S
        out = self.bn3(out)       # B 4P H/S W/S

        out = self.add([out, x])  # B 4P H/S W/S
        out = self.relu3(out)     # B 4P H/S W/S

        return out
@RPProvider.register(LapBottleneck)
class BlockRP(RelProp[LapBottleneck]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

        out = RPProvider.get(self.module.bn3)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv3)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu2)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1
TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck)


class BlockFactory(Generic[TLapBlock]):
    """ Mimics a block class: exposes `expansion` and builds blocks with the
    pool factory and sigmoid scale already bound. """

    def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None:
        self.expansion = block.expansion
        self._block = block
        self._pool_factory = pool_factory
        self._sigmoid_scale = sigmoid_scale

    def __call__(self, *args, **kwargs) -> TLapBlock:
        return self._block(self._pool_factory, *args, **kwargs, sigmoid_scale=self._sigmoid_scale)


class Stride(nn.Module):
    """ Parameter-free 2x spatial subsampling, used in place of LAP on layers
    where LAP is disabled. """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, :, ::2, ::2]


@RPProvider.register(Stride)
class StrideRelProp(RelPropSimple[Stride]):
    pass
class LapResNet(ResNet, AttentionInterpretableModel):

    def __init__(self,
                 block: Type[TLapBlock],
                 layers: List[int],
                 pool_factory: PoolFactory = lap_factory,
                 lap_positions: List[int] = [2, 3, 4],
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 binary: bool = False):
        super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary)

        if adaptive_factory is not None:
            # Replace global average pooling with an adaptive pool built by adaptive_factory.
            self.avgpool = nn.Sequential(
                adaptive_factory(512, sigmoid_scale=sigmoid_scale),
            )

        # Layers not listed in lap_positions fall back to plain striding.
        for i in range(2, 5):
            if i not in lap_positions:
                warnings.warn(f'Putting stride on layer {i}')
                getattr(self, 'layer{}'.format(i))[0].pool = Stride()

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """ Attention layers, grouped for interpretation. """
        attention_groups = {
            '0_layer2': [self.layer2[0].pool],
            '1_layer3': [self.layer3[0].pool],
            '2_layer4': [self.layer4[0].pool],
            '3_avgpool': [self.avgpool[0]],
            '4_overall': [
                self.layer2[0].pool,
                self.layer3[0].pool,
                self.layer4[0].pool,
                self.avgpool[0],
            ]
        }

        assert all(
            all(
                isinstance(layer, LAP)
                for layer in attention_group)
            for attention_group in attention_groups.values()), \
            "Only LAP is supported for this interpretation method"

        return attention_groups
def lap_resnet18(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """ Constructs a LAP-ResNet-18 model. """
    return LapResNet(LapBasicBlock, [2, 2, 2, 2], pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions, binary=binary)


def lap_resnet50(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """ Constructs a LAP-ResNet-50 model. """
    return LapResNet(LapBottleneck, [3, 4, 6, 3], pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions, binary=binary)
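
A minimal usage sketch, not part of the file itself: it assumes this module sits inside its package so the relative imports resolve (the import path below is hypothetical), and the constructor arguments and 224x224 input size are illustrative only.

import torch
from mypackage.models.lap_resnet import lap_resnet18, adaptive_lap_factory  # hypothetical package path

# Build a LAP-ResNet-18 with an adaptive LAP head instead of average pooling.
model = lap_resnet18(adaptive_factory=adaptive_lap_factory, sigmoid_scale=1.0)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # illustrative input size

# Inspect the LAP layers grouped for attention-based interpretation.
for name, pools in model.attention_layers.items():
    print(name, [type(p).__name__ for p in pools])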