123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- from typing import Iterable, Optional, Callable, List
-
- import torch
- from torch import nn, Tensor
- import torchvision
-
- from ..interpreting.interpretable import CamInterpretableModel
- from ..interpreting.relcam.relprop import RPProvider, RelProp
- from ..interpreting.relcam import modules as M
- from .lap_inception import (
- InceptionA,
- InceptionC,
- InceptionE,
- BasicConv2d,
- InceptionAux as BaseInceptionAux,
- )
-
-
class InceptionB(nn.Module):
    """Inception reduction block: a strided 3x3 branch, a double-3x3 branch,
    and a max-pool branch, concatenated along the channel dimension.

    The concatenation goes through an explicit ``M.Cat`` module so that
    relevance propagation can hook it.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionB, self).__init__()
        # Fall back to the project's conv building block when none is given.
        if conv_block is None:
            conv_block = BasicConv2d

        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Run each branch on *x* and return the per-branch feature maps."""
        single = self.branch3x3(x)

        double = self.branch3x3dbl_1(x)
        double = self.branch3x3dbl_2(double)
        double = self.branch3x3dbl_3(double)

        pooled = self.maxpool(x)

        return [single, double, pooled]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate all branch outputs along dim 1 (channels)."""
        return self.cat(self._forward(x), 1)
-
-
@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):
    """Relevance propagation for :class:`InceptionB`.

    Splits the incoming relevance across the three concatenated branches,
    walks each branch's layers in reverse, and sums the relevances that
    arrive back at the shared input.
    """

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Undo the channel concatenation: one relevance map per branch.
        rel_3x3, rel_dbl, rel_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        # Pooling branch.
        from_pool = RPProvider.get(self.module.maxpool)(rel_pool, alpha=alpha)

        # Double-3x3 branch, layers traversed in reverse forward order.
        rel_dbl = RPProvider.get(self.module.branch3x3dbl_3)(rel_dbl, alpha=alpha)
        rel_dbl = RPProvider.get(self.module.branch3x3dbl_2)(rel_dbl, alpha=alpha)
        from_dbl = RPProvider.get(self.module.branch3x3dbl_1)(rel_dbl, alpha=alpha)

        # Single strided-3x3 branch.
        from_3x3 = RPProvider.get(self.module.branch3x3)(rel_3x3, alpha=alpha)

        return from_pool + from_dbl + from_3x3
-
-
class InceptionD(nn.Module):
    """Inception reduction block: a strided 3x3 branch, a 7x7-then-3x3 branch
    (factorized 7x7 as 1x7 + 7x1), and a max-pool branch, concatenated on the
    channel dimension via an explicit ``M.Cat`` for relevance propagation.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionD, self).__init__()
        # Fall back to the project's conv building block when none is given.
        if conv_block is None:
            conv_block = BasicConv2d

        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Run each branch on *x* and return the per-branch feature maps."""
        narrow = self.branch3x3_1(x)
        narrow = self.branch3x3_2(narrow)

        wide = self.branch7x7x3_1(x)
        wide = self.branch7x7x3_2(wide)
        wide = self.branch7x7x3_3(wide)
        wide = self.branch7x7x3_4(wide)

        pooled = self.maxpool(x)

        return [narrow, wide, pooled]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate all branch outputs along dim 1 (channels)."""
        return self.cat(self._forward(x), 1)
-
-
@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):
    """Relevance propagation for :class:`InceptionD`.

    Splits the incoming relevance across the three concatenated branches,
    walks each branch's layers in reverse, and sums the relevances that
    arrive back at the shared input.
    """

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # Undo the channel concatenation: one relevance map per branch.
        rel_3x3, rel_7x7x3, rel_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        # Pooling branch.
        from_pool = RPProvider.get(self.module.maxpool)(rel_pool, alpha=alpha)

        # Factorized-7x7 branch, layers traversed in reverse forward order.
        rel_7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(rel_7x7x3, alpha=alpha)
        rel_7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(rel_7x7x3, alpha=alpha)
        rel_7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(rel_7x7x3, alpha=alpha)
        from_7x7x3 = RPProvider.get(self.module.branch7x7x3_1)(rel_7x7x3, alpha=alpha)

        # Strided-3x3 branch.
        rel_3x3 = RPProvider.get(self.module.branch3x3_2)(rel_3x3, alpha=alpha)
        from_3x3 = RPProvider.get(self.module.branch3x3_1)(rel_3x3, alpha=alpha)

        return from_pool + from_7x7x3 + from_3x3
-
-
class InceptionAux(BaseInceptionAux):
    """Auxiliary classifier head that swaps the base block's ``conv1``
    for a 5x5 convolution mapping 128 -> 768 channels.

    The base constructor is called with the caller's ``conv_block`` as-is
    (possibly ``None``) so the base keeps its own default resolution; the
    replacement layer is built afterwards.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__(in_channels, num_classes, conv_block=conv_block)
        block_factory = BasicConv2d if conv_block is None else conv_block
        # Override conv1 built by the base class with the 5x5 variant.
        self.conv1 = block_factory(128, 768, kernel_size=5)
        # stddev is read by torchvision's weight-init routine for aux heads.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
-
-
class Inception3(CamInterpretableModel, torchvision.models.Inception3):
    """Inception-v3 backbone wired with this file's interpretable blocks
    and sigmoid classification heads.

    Args:
        aux_weight: Weight of the auxiliary-head loss term (stored for use
            by the training code).
        n_classes: Number of output units of both heads (default 1).
    """

    def __init__(self, aux_weight: float, n_classes=1):
        # Initialize the torchvision backbone directly, substituting the
        # relprop-aware block implementations defined in this module.
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False,
            init_weights=False,
            inception_blocks=[
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux,
            ])

        # Replace both heads with sigmoid outputs. Creation order (main fc
        # first, then aux fc) is kept deliberately: nn.Linear draws from the
        # global RNG at construction time.
        self.fc = nn.Sequential(
            nn.Linear(2048, n_classes),
            nn.Sigmoid(),
        )
        self.AuxLogits.fc = nn.Sequential(
            nn.Linear(768, n_classes),
            nn.Sigmoid(),
        )
        self.aux_weight = aux_weight

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """Conv layers of the final Inception block used as CAM targets."""
        last_block = self.Mixed_7c
        return [
            last_block.branch1x1,
            last_block.branch3x3_2a, last_block.branch3x3_2b,
            last_block.branch3x3dbl_3a, last_block.branch3x3dbl_3b,
            last_block.branch_pool,
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """Names of the forward inputs interpretation maps are computed for."""
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """Return a (B, 2) tensor of [negative, positive] class probabilities.

        NOTE(review): relies on ``forward`` returning a dict containing
        'positive_class_probability' — presumably provided by the
        CamInterpretableModel mixin; confirm against its definition.
        """
        pos = self.forward(*inputs, **kwargs)['positive_class_probability']
        return torch.stack([1 - pos, pos], dim=1)
-
-
-
@RPProvider.register(Inception3)
class Inception3Mo4RelProp(RelProp[Inception3]):
    """Relevance propagation through the whole Inception3 backbone.

    Undoes the classification head, then walks every backbone layer in the
    exact reverse of the forward order, finishing with relevance at the
    input resolution (B x 3 x 299 x 299).
    """

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        net = self.module

        # Single-logit head: keep only the positive-class column so the
        # relevance width matches the fc layer's recorded output.
        if RPProvider.get(net.fc).Y.shape[1] == 1:
            R = R[:, -1:]
        R = RPProvider.get(net.fc)(R, alpha=alpha)              # B 2048
        # Restore the spatial singleton dims dropped before the fc layer.
        R = R.reshape_as(RPProvider.get(net.dropout).Y)         # B 2048 1 1

        # Remaining layers, listed in reverse forward order; the loop
        # preserves the exact propagation sequence of the original code.
        reversed_pipeline = [
            net.dropout,                                        # B 2048 1 1
            net.avgpool,                                        # B 2048 8 8
            net.Mixed_7c,                                       # B 2048 8 8
            net.Mixed_7b,                                       # B 1280 8 8
            net.Mixed_7a,                                       # B 768 17 17
            net.Mixed_6e,                                       # B 768 17 17
            net.Mixed_6d,                                       # B 768 17 17
            net.Mixed_6c,                                       # B 768 17 17
            net.Mixed_6b,                                       # B 768 17 17
            net.Mixed_6a,                                       # B 288 35 35
            net.Mixed_5d,                                       # B 288 35 35
            net.Mixed_5c,                                       # B 256 35 35
            net.Mixed_5b,                                       # B 192 35 35
            net.maxpool2,                                       # B 192 71 71
            net.Conv2d_4a_3x3,                                  # B 80 73 73
            net.Conv2d_3b_1x1,                                  # B 64 73 73
            net.maxpool1,                                       # B 64 147 147
            net.Conv2d_2b_3x3,                                  # B 32 147 147
            net.Conv2d_2a_3x3,                                  # B 32 149 149
            net.Conv2d_1a_3x3,                                  # B 3 299 299
        ]
        for layer in reversed_pipeline:
            R = RPProvider.get(layer)(R, alpha=alpha)
        return R
|