from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
from typing_extensions import Protocol
import warnings

import torch
from torch import nn

from ..modules import LAP, AdaptiveLAP
from .tv_resnet import BasicBlock, Bottleneck, ResNet
from ..interpreting.relcam.relprop import RPProvider, RelProp, RelPropSimple
from ..interpreting.relcam import modules as M
from ..interpreting.attention_interpreter import AttentionInterpretableModel


class PoolFactory(Protocol):
    """Callable that builds a LAP pooling module for a given channel count."""

    def __call__(self, channels: int, sigmoid_scale: float = 1.0) -> LAP: ...


lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
    LAP(channels, sigmoid_scale=sigmoid_scale)

adaptive_lap_factory: PoolFactory = \
    lambda channels, sigmoid_scale=1.0: \
    AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale)


class LapBasicBlock(BasicBlock):
    """BasicBlock variant in which the spatial stride is replaced by a LAP
    pooling layer applied jointly to the residual and shortcut branches."""

    def __init__(self, pool_factory: PoolFactory, inplanes, planes, stride=1,
                 downsample=None, groups=1, base_width=64, dilation=1,
                 norm_layer=None, sigmoid_scale: float = 1.0):
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            # Drop the stride from the main and shortcut convolutions; the LAP
            # module below downsamples both branches at once.
            self.conv1 = nn.Conv2d(inplanes, planes, 3, padding=1, bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes, 1, bias=False)
            self.pool = pool_factory(planes * 2, sigmoid_scale=sigmoid_scale)

        self.relu3 = nn.ReLU()
        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(x)   # B P H W
        out = self.bn1(out)   # B P H W
        out = self.relu(out)  # B P H W

        if self.downsample is not None:
            x = self.downsample(x)  # B P H W
            x = self.relu2(x)

        if self.pool is not None:
            poolin = self.cat([out, x], 1)             # B 2P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 2P H/S W/S
            out, x = poolout.chunk(2, dim=1)           # B P H/S W/S

        out = self.conv2(out)     # B P H/S W/S
        out = self.bn2(out)       # B P H/S W/S
        out = self.add([out, x])  # B P H/S W/S
        out = self.relu3(out)     # B P H/S W/S

        return out


@RPProvider.register(LapBasicBlock)
class BlockRelProp(RelProp[LapBasicBlock]):
    """Relevance propagation through a LapBasicBlock, mirroring its forward
    pass in reverse order."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.relu2)(x, alpha=alpha)
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1


class LapBottleneck(Bottleneck):
    """Bottleneck variant in which the spatial stride is replaced by a LAP
    pooling layer applied jointly to the residual and shortcut branches."""

    def __init__(self, pool_factory: PoolFactory,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 sigmoid_scale: float = 1.0) -> None:
        super().__init__(inplanes, planes, stride=stride, downsample=downsample,
                         groups=groups, base_width=base_width, dilation=dilation,
                         norm_layer=norm_layer)
        self.pool = None
        if stride != 1:
            assert downsample is not None
            width = int(planes * (base_width / 64.)) * groups
            # Drop the stride from the main and shortcut convolutions; the LAP
            # module below downsamples both branches at once.
            self.conv2 = nn.Conv2d(width, width, 3, padding=1,
                                   bias=False)
            self.downsample[0] = nn.Conv2d(inplanes, planes * self.expansion, 1, bias=False)
            self.pool = pool_factory(planes * (self.expansion + 1), sigmoid_scale=sigmoid_scale)

        self.cat = M.Cat()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x                       B I H W
        out = self.conv1(x)     # B P H W
        out = self.bn1(out)     # B P H W
        out = self.relu(out)    # B P H W

        out = self.conv2(out)   # B P H W
        out = self.bn2(out)     # B P H W
        out = self.relu2(out)   # B P H W

        if self.downsample is not None:
            x = self.downsample(x)  # B 4P H W

        if self.pool is not None:
            poolin = self.cat([out, x], 1)             # B 5P H W
            poolout: torch.Tensor = self.pool(poolin)  # B 5P H/S W/S
            out, x = poolout.split(                    # B P H/S W/S
                [out.shape[1], x.shape[1]],            # B 4P H/S W/S
                dim=1)

        out = self.conv3(out)     # B 4P H/S W/S
        out = self.bn3(out)       # B 4P H/S W/S
        out = self.add([out, x])  # B 4P H/S W/S
        out = self.relu3(out)     # B 4P H/S W/S

        return out


@RPProvider.register(LapBottleneck)
class BlockRP(RelProp[LapBottleneck]):
    """Relevance propagation through a LapBottleneck, mirroring its forward
    pass in reverse order."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        out = RPProvider.get(self.module.relu3)(R, alpha=alpha)
        out, x = RPProvider.get(self.module.add)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn3)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv3)(out, alpha=alpha)

        if self.module.pool is not None:
            poolout = torch.cat([out, x], dim=1)
            poolin = RPProvider.get(self.module.pool)(poolout, alpha=alpha)
            out, x = RPProvider.get(self.module.cat)(poolin, alpha=alpha)

        if self.module.downsample is not None:
            x = RPProvider.get(self.module.downsample)(x, alpha=alpha)

        out = RPProvider.get(self.module.relu2)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn2)(out, alpha=alpha)
        out = RPProvider.get(self.module.conv2)(out, alpha=alpha)

        out = RPProvider.get(self.module.relu)(out, alpha=alpha)
        out = RPProvider.get(self.module.bn1)(out, alpha=alpha)
        x1 = RPProvider.get(self.module.conv1)(out, alpha=alpha)

        return x + x1


TLapBlock = TypeVar('TLapBlock', LapBasicBlock, LapBottleneck)


class BlockFactory(Generic[TLapBlock]):
    """Binds a pool factory and sigmoid scale to a block class so the result
    can be used as a drop-in block constructor by ResNet."""

    def __init__(self, block: Type[TLapBlock], pool_factory: PoolFactory,
                 sigmoid_scale: float = 1.0) -> None:
        self.expansion = block.expansion
        self._block = block
        self._pool_factory = pool_factory
        self._sigmoid_scale = sigmoid_scale

    def __call__(self, *args, **kwargs) -> TLapBlock:
        return self._block(self._pool_factory, *args, **kwargs,
                           sigmoid_scale=self._sigmoid_scale)


class Stride(nn.Module):
    """Plain stride-2 subsampling, used in place of LAP on disabled layers."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, :, ::2, ::2]


@RPProvider.register(Stride)
class StrideRelProp(RelPropSimple[Stride]):
    pass


class LapResNet(ResNet, AttentionInterpretableModel):

    def __init__(self, block: Type[TLapBlock], layers: List[int],
                 pool_factory: PoolFactory = lap_factory,
                 lap_positions: List[int] = [2, 3, 4],
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 binary: bool = False):
        super().__init__(BlockFactory(block, pool_factory, sigmoid_scale=sigmoid_scale),
                         layers, binary=binary)

        if adaptive_factory is not None:
            self.avgpool = nn.Sequential(
                adaptive_factory(512, sigmoid_scale=sigmoid_scale),
            )

        for i in range(2, 5):
            if i not in lap_positions:
                warnings.warn(f'Putting stride on layer {i}')
                getattr(self, 'layer{}'.format(i))[0].pool = Stride()

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Named groups of LAP attention layers used for interpretation."""
        attention_groups = {
            '0_layer2': [self.layer2[0].pool],
            '1_layer3': [self.layer3[0].pool],
            '2_layer4': [self.layer4[0].pool],
            '3_avgpool': [self.avgpool[0]],
            '4_overall': [
                self.layer2[0].pool,
                self.layer3[0].pool,
                self.layer4[0].pool,
                self.avgpool[0],
            ]
        }

        assert all(
            all(
                isinstance(layer, LAP)
                for layer in attention_group)
            for attention_group in attention_groups.values()), \
            "Only LAP is supported for this interpretation method"

        return attention_groups


def lap_resnet18(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Constructs a LAP-ResNet-18 model."""
    return LapResNet(LapBasicBlock, [2, 2, 2, 2],
                     pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory,
                     sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions,
                     binary=binary)


def lap_resnet50(pool_factory: PoolFactory = lap_factory,
                 adaptive_factory: Optional[PoolFactory] = None,
                 sigmoid_scale: float = 1.0,
                 lap_positions: List[int] = [2, 3, 4],
                 binary: bool = False) -> LapResNet:
    """Constructs a LAP-ResNet-50 model."""
    return LapResNet(LapBottleneck, [3, 4, 6, 3],
                     pool_factory=pool_factory,
                     adaptive_factory=adaptive_factory,
                     sigmoid_scale=sigmoid_scale,
                     lap_positions=lap_positions,
                     binary=binary)
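

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the original API surface): build a
# LAP-ResNet-18 with an adaptive LAP head and run a dummy forward pass.
# Assumptions: the file is executed inside its package (the relative imports
# above require it), the backbone accepts 3-channel 224x224 inputs, and
# AdaptiveLAP is a LAP subclass so the `attention_layers` assertion passes.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    model = lap_resnet18(adaptive_factory=adaptive_lap_factory)
    model.eval()

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))  # assumed input resolution
    print('output shape:', tuple(out.shape))

    # Named attention groups exposed for the attention-based interpreter.
    print('attention groups:', list(model.attention_layers.keys()))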