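"""LAP-based ResNet variants.

Defines LapBasicBlock and LapBottleneck (ResNet blocks whose strided
downsampling is replaced by a LAP pooling layer shared between the residual
and identity branches), the corresponding relevance-propagation rules, and
the lap_resnet18 / lap_resnet50 constructors.
"""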
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
from typing_extensions import Protocol
import warnings

import torch
from torch import nn

from ..modules import LAP, AdaptiveLAP
from .tv_resnet import BasicBlock, Bottleneck, ResNet
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple
from ..interpreting.relcam import modules as M
from ..interpreting.attention_interpreter import AttentionInterpretableModel


class PoolFactory(Protocol):
    def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ...


lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
    LAP(channels, sigmoid_scale=sigmoid_scale)


adaptive_lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
    AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale)


class LapBasicBlock(BasicBlock):
    """BasicBlock variant in which strided downsampling is replaced by a LAP
    pooling layer applied jointly to the residual and identity branches."""

    def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, sigmoid_scale: float = 1.0):
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)

        self.pool = None
        if stride != 1:
            assert downsample is not None

            # Drop the stride from both branches; the LAP layer below takes
            # over the spatial downsampling.
            self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False)
            self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale)

        self.relu3 = nn.ReLU()
        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(x)   # B P H W
        out = self.bn1(out)   # B P H W
        out = self.relu(out)  # B P H W

        if self.downsample is not None:
            x = self.downsample(x)  # B P H W
            x = self.relu2(x)

        if self.pool is not None:
            poolin = self.cat([out, x], 1)             # B 2P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 2P H/S W/S
            out, x = poolout.chunk(2, dim=1)           # B P H/S W/S

        out = self.conv2(out)  # B P H/S W/S
        out = self.bn2(out)    # B P H/S W/S

        out = self.add([out, x])  # B P H/S W/S
        out = self.relu3(out)     # B P H/S W/S

        return out


@RPProvider.register(LapBasicBlock)
class BlockRelProp(RelProp[LapBasicBlock]):
    """Relevance propagation rule for LapBasicBlock: traverses the block in
    reverse order of the forward pass."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.relu2)(x, alpha=alpha)
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1


class LapBottleneck(Bottleneck):
    """Bottleneck variant in which strided downsampling is replaced by a LAP
    pooling layer applied jointly to the residual and identity branches."""

    def __init__(self, pool_factory: PoolFactory,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 sigmoid_scale: float = 1.0) -> None:
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)

        self.pool = None
        if stride != 1:
            assert downsample is not None

            # Drop the stride from both branches; the LAP layer below takes
            # over the spatial downsampling.
            width = int(planes * (base_width / 64.)) * groups
            self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False)
            self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale)

        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x  # B I H W

        out = self.conv1(x)    # B P H W
        out = self.bn1(out)    # B P H W
        out = self.relu(out)   # B P H W

        out = self.conv2(out)  # B P H W
        out = self.bn2(out)    # B P H W
        out = self.relu2(out)  # B P H W

        if self.downsample is not None:
            x = self.downsample(x)  # B 4P H W

        if self.pool is not None:
            poolin = self.cat([out, x], 1)             # B 5P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 5P H/S W/S
            out, x = poolout.split(                    # out: B P  H/S W/S
                [out.shape[1], x.shape[1]],            # x:   B 4P H/S W/S
                dim=1)

        out = self.conv3(out)  # B 4P H/S W/S
        out = self.bn3(out)    # B 4P H/S W/S

        out = self.add([out, x])  # B 4P H/S W/S
        out = self.relu3(out)     # B 4P H/S W/S

        return out


@RPProvider.register(LapBottleneck)
class BlockRP(RelProp[LapBottleneck]):
    """Relevance propagation rule for LapBottleneck: traverses the block in
    reverse order of the forward pass."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)

        out = RPProvider.get(self.module.bn3)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv3)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu2)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1


TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck)


class BlockFactory(Generic[TLapBlock]):
    """Binds a pool factory and sigmoid scale to a LAP block class so that it
    can be instantiated with the standard ResNet block signature."""

    def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory, sigmoid_scale: float = 1.0) -> None:
        self.expansion = block.expansion
        self._block = block
        self._pool_factory = pool_factory
        self._sigmoid_scale = sigmoid_scale

    def __call__(self, *args, **kwargs) -> TLapBlock:
        return self._block(self._pool_factory, *args, **kwargs, sigmoid_scale=self._sigmoid_scale)


class Stride(nn.Module):
    """Plain strided subsampling (every other pixel), used in place of LAP on
    layers that are not listed in lap_positions."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, :, ::2, ::2]


@RPProvider.register(Stride)
class StrideRelProp(RelPropSimple[Stride]):
    pass


class LapResNet(ResNet, AttentionInterpretableModel):

    def __init__(self,
                 block: Type[TLapBlock],
                 layers: List[int],
                 pool_factory: PoolFactory = lap_factory,
                 lap_positions: List[int] = [2, 3, 4],
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 binary: bool = False):
        super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale), layers, binary=binary)

        if adaptive_factory is not None:
            self.avgpool = nn.Sequential(
                adaptive_factory(512, sigmoid_scale=sigmoid_scale),
            )

        for i in range(2, 5):
            if i not in lap_positions:
                warnings.warn(f'Layer {i} is not in lap_positions; using a plain stride instead of LAP')
                getattr(self, f'layer{i}')[0].pool = Stride()

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """
        Dictionary mapping attention-group names to the LAP layers that make
        up each group.
        """
        attention_groups = {
            '0_layer2': [self.layer2[0].pool],
            '1_layer3': [self.layer3[0].pool],
            '2_layer4': [self.layer4[0].pool],
            '3_avgpool': [self.avgpool[0]],
            '4_overall': [
                self.layer2[0].pool,
                self.layer3[0].pool,
                self.layer4[0].pool,
                self.avgpool[0],
            ],
        }

        assert all(
            all(
                isinstance(layer, LAP)
                for layer in attention_group)
            for attention_group in attention_groups.values()), \
            "Only LAP is supported for this interpretation method"

        return attention_groups


def lap_resnet18(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Constructs a LAP-ResNet-18 model."""
    return LapResNet(LapBasicBlock, [2, 2, 2, 2], pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions, binary=binary)


def lap_resnet50(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Constructs a LAP-ResNet-50 model."""
    return LapResNet(LapBottleneck, [3, 4, 6, 3], pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory, sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions, binary=binary)
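

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the library API). Assumptions: the
# tv_resnet.ResNet base follows the usual torchvision convention of 3-channel
# image input, and 224x224 is only an illustrative input size. Because of the
# relative imports above, run this in package context (python -m ...).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = lap_resnet18()
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)
    # LAP layers used for attention interpretation, grouped by name:
    print({name: len(layers) for name, layers in model.attention_layers.items()})
    # LAP can also be restricted to specific layers, e.g.:
    #   lap_resnet18(adaptive_factory=adaptive_lap_factory, lap_positions=[3, 4])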