from typing import List, Optional, Callable, Any, Iterable, Dict

import torch
from torch import Tensor, nn
import torchvision

from ..modules.lap import LAP
from ..interpreting.attention_interpreter import AttentionInterpretableModel
from ..interpreting.interpretable import CamInterpretableModel
from ..interpreting.relcam.relprop import RPProvider, RelProp
from ..interpreting.relcam import modules as M

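# The blocks below mirror torchvision's Inception v3 building blocks, with two
# changes that make the network compatible with attention pooling and
# layer-wise relevance propagation:
#   * the stride-2 reduction convolutions in InceptionB/InceptionD are replaced
#     by stride-1 convolutions, and the spatial reduction is delegated to a
#     pooling module built by `pool_factory` (expected to return a LAP module,
#     judging by the `attention_layers` typing further below);
#   * concatenations and activations are explicit nn.Modules (M.Cat, nn.ReLU),
#     and each block registers a RelProp rule with RPProvider so relevance can
#     be propagated back through it.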
class InceptionB(nn.Module):
    """Inception v3 reduction block B, with the stride-2 convolutions replaced
    by stride-1 convolutions followed by a `pool_factory` pooling module."""

    def __init__(
        self,
        in_channels: int,
        pool_factory: Callable[[int], nn.Module],
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d

        # torchvision uses conv_block(in_channels, 384, kernel_size=3, stride=2) here;
        # the stride-1 variant keeps the resolution and lets self.pool do the reduction.
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=1, padding=1)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        # torchvision: conv_block(96, 96, kernel_size=3, stride=2)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=1, padding=1)

        self.pool = pool_factory(384 + 96 + in_channels)

        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        outputs = [branch3x3, branch3x3dbl, x]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.pool(self.cat(outputs, 1))

@RPProvider.register(InceptionB)
class InceptionBRelProp(RelProp[InceptionB]):
    """Propagates relevance backwards through InceptionB: first through the pool
    and the concatenation, then through each branch, summing the contributions
    that reach the block input."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = RPProvider.get(self.module.pool)(R, alpha=alpha)
        branch3x3, branch3x3dbl, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha)

        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
        x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

        x3 = RPProvider.get(self.module.branch3x3)(branch3x3, alpha=alpha)

        return x1 + x2 + x3

class InceptionD(nn.Module):
    """Inception v3 reduction block D, with the stride-2 convolutions replaced
    by stride-1 convolutions followed by a `pool_factory` pooling module."""

    def __init__(
        self,
        in_channels: int,
        pool_factory: Callable[[int], nn.Module],
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d

        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        # torchvision: conv_block(192, 320, kernel_size=3, stride=2)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=1, padding=1)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        # torchvision: conv_block(192, 192, kernel_size=3, stride=2)
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=1, padding=1)

        self.pool = pool_factory(320 + 192 + in_channels)

        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        outputs = [branch3x3, branch7x7x3, x]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.pool(self.cat(outputs, 1))

@RPProvider.register(InceptionD)
class InceptionDRelProp(RelProp[InceptionD]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = RPProvider.get(self.module.pool)(R, alpha=alpha)
        branch3x3, branch7x7x3, x1 = RPProvider.get(self.module.cat)(R, alpha=alpha)

        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_4)(branch7x7x3, alpha=alpha)
        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_3)(branch7x7x3, alpha=alpha)
        branch7x7x3 = RPProvider.get(self.module.branch7x7x3_2)(branch7x7x3, alpha=alpha)
        x2 = RPProvider.get(self.module.branch7x7x3_1)(branch7x7x3, alpha=alpha)

        branch3x3 = RPProvider.get(self.module.branch3x3_2)(branch3x3, alpha=alpha)
        x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

        return x1 + x2 + x3

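# torchvision.models.Inception3 instantiates its blocks as `inception_b(in_channels)`
# and `inception_d(in_channels)`, so these makers return callables with that
# signature while closing over the given pool_factory.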
def inception_b_maker(pool_factory):
    return (
        lambda in_channels, conv_block=None:
        InceptionB(in_channels, pool_factory, conv_block))


def inception_d_maker(pool_factory):
    return (
        lambda in_channels, conv_block=None:
        InceptionD(in_channels, pool_factory, conv_block))

class InceptionA(nn.Module):
    """torchvision's Inception v3 block A, with pooling and concatenation kept
    as explicit modules so relevance propagation can hook them."""

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.avg_pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat(outputs, 1)

@RPProvider.register(InceptionA)
class InceptionARelProp(RelProp[InceptionA]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        branch1x1, branch5x5, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
        x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_3)(branch3x3dbl, alpha=alpha)
        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
        x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

        branch5x5 = RPProvider.get(self.module.branch5x5_2)(branch5x5, alpha=alpha)
        x3 = RPProvider.get(self.module.branch5x5_1)(branch5x5, alpha=alpha)

        x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

        return x1 + x2 + x3 + x4

class InceptionC(nn.Module):
    """torchvision's Inception v3 block C (factorised 7x7 convolutions)."""

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

        self.cat = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = self.avg_pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat(outputs, 1)

@RPProvider.register(InceptionC)
class InceptionCRelProp(RelProp[InceptionC]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        branch1x1, branch7x7, branch7x7dbl, branch_pool = RPProvider.get(self.module.cat)(R, alpha=alpha)

        branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
        x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

        branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_5)(branch7x7dbl, alpha=alpha)
        branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_4)(branch7x7dbl, alpha=alpha)
        branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_3)(branch7x7dbl, alpha=alpha)
        branch7x7dbl = RPProvider.get(self.module.branch7x7dbl_2)(branch7x7dbl, alpha=alpha)
        x2 = RPProvider.get(self.module.branch7x7dbl_1)(branch7x7dbl, alpha=alpha)

        branch7x7 = RPProvider.get(self.module.branch7x7_3)(branch7x7, alpha=alpha)
        branch7x7 = RPProvider.get(self.module.branch7x7_2)(branch7x7, alpha=alpha)
        x3 = RPProvider.get(self.module.branch7x7_1)(branch7x7, alpha=alpha)

        x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

        return x1 + x2 + x3 + x4

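# InceptionE concatenates inside both 3x3 branches and again at the output, so
# it keeps three distinct Cat modules (cat1/cat2/cat3); presumably each
# concatenation needs its own module instance so RPProvider can track its saved
# activations separately during relevance propagation.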
class InceptionE(nn.Module):
    """torchvision's Inception v3 block E (expanded filter-bank branches)."""

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

        self.cat1 = M.Cat()
        self.cat2 = M.Cat()
        self.cat3 = M.Cat()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = self.cat1(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = self.cat2(branch3x3dbl, 1)

        branch_pool = self.avg_pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat3(outputs, 1)

@RPProvider.register(InceptionE)
class InceptionERelProp(RelProp[InceptionE]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        branch1x1, branch3x3, branch3x3dbl, branch_pool = RPProvider.get(self.module.cat3)(R, alpha=alpha)

        branch_pool = RPProvider.get(self.module.branch_pool)(branch_pool, alpha=alpha)
        x1 = RPProvider.get(self.module.avg_pool)(branch_pool, alpha=alpha)

        rel_3a, rel_3b = RPProvider.get(self.module.cat2)(branch3x3dbl, alpha=alpha)
        rel_3a = RPProvider.get(self.module.branch3x3dbl_3a)(rel_3a, alpha=alpha)
        rel_3b = RPProvider.get(self.module.branch3x3dbl_3b)(rel_3b, alpha=alpha)
        branch3x3dbl = rel_3a + rel_3b
        branch3x3dbl = RPProvider.get(self.module.branch3x3dbl_2)(branch3x3dbl, alpha=alpha)
        x2 = RPProvider.get(self.module.branch3x3dbl_1)(branch3x3dbl, alpha=alpha)

        rel_2a, rel_2b = RPProvider.get(self.module.cat1)(branch3x3, alpha=alpha)
        rel_2a = RPProvider.get(self.module.branch3x3_2a)(rel_2a, alpha=alpha)
        rel_2b = RPProvider.get(self.module.branch3x3_2b)(rel_2b, alpha=alpha)
        branch3x3 = rel_2a + rel_2b
        x3 = RPProvider.get(self.module.branch3x3_1)(branch3x3, alpha=alpha)

        x4 = RPProvider.get(self.module.branch1x1)(branch1x1, alpha=alpha)

        return x1 + x2 + x3 + x4

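# Auxiliary classifier head of Inception v3, fed from the 17x17x768 feature map;
# torchvision only routes activations through it when aux_logits is enabled and
# the model is in training mode.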
class InceptionAux(nn.Module):

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d

        self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5, padding=2)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = self.avgpool1(x)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 5 x 5
        # Adaptive average pooling
        x = self.avgpool2(x)
        # N x 768 x 1 x 1
        x = self.flatten(x)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x

@RPProvider.register(InceptionAux)
class InceptionAuxRelProp(RelProp[InceptionAux]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = RPProvider.get(self.module.fc)(R, alpha=alpha)
        R = RPProvider.get(self.module.flatten)(R, alpha=alpha)
        R = RPProvider.get(self.module.avgpool2)(R, alpha=alpha)
        R = RPProvider.get(self.module.conv1)(R, alpha=alpha)
        R = RPProvider.get(self.module.conv0)(R, alpha=alpha)
        R = RPProvider.get(self.module.avgpool1)(R, alpha=alpha)
        return R

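# torchvision's BasicConv2d applies F.relu functionally; here the ReLU is kept
# as a submodule (`self.act`) so the relevance-propagation rule below can look
# it up through RPProvider.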
class BasicConv2d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any,
    ) -> None:
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)

@RPProvider.register(BasicConv2d)
class BasicConv2dRelProp(RelProp[BasicConv2d]):

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        R = RPProvider.get(self.module.act)(R, alpha=alpha)
        R = RPProvider.get(self.module.bn)(R, alpha=alpha)
        R = RPProvider.get(self.module.conv)(R, alpha=alpha)
        return R

class LAPInception(AttentionInterpretableModel, CamInterpretableModel, torchvision.models.Inception3):
    """Inception v3 backbone whose max-pooling and stride-2 reduction stages are
    replaced by pooling modules built by `pool_factory`, exposed through the
    attention- and CAM-interpretability interfaces."""

    def __init__(self, aux_weight: float, n_classes, pool_factory, adaptive_pool_factory):
        torchvision.models.Inception3.__init__(
            self,
            transform_input=False, init_weights=False,
            inception_blocks=[
                BasicConv2d, InceptionA, inception_b_maker(pool_factory),
                InceptionC, inception_d_maker(pool_factory), InceptionE, InceptionAux,
            ])

        # Stem: the max-pool layers of the torchvision stem are replaced by
        # pool_factory modules, and the convolutions are padded.
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = pool_factory(64)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1)
        self.maxpool2 = pool_factory(192)

        if adaptive_pool_factory is not None:
            self.avgpool = nn.Sequential(
                adaptive_pool_factory(2048))

        # Both heads output a single sigmoid probability (binary classification).
        self.fc = nn.Sequential(
            nn.Linear(2048, 1),
            nn.Sigmoid(),
        )
        self.AuxLogits.fc = nn.Sequential(
            nn.Linear(768, 1),
            nn.Sigmoid(),
        )
        self.aux_weight = aux_weight

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        return [
            self.Mixed_7c.branch1x1,
            self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
            self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
            self.Mixed_7c.branch_pool,
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        p = self.forward(*inputs, **kwargs)['positive_class_probability']
        return torch.stack([1 - p, p], dim=1)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        attention_groups = {
            '0_layer2': [self.maxpool2],
            '1_layer6': [self.Mixed_6a.pool],
            '2_layer7': [self.Mixed_7a.pool],
            '3_avgpool': [self.avgpool[0]],
            '4_all': [
                self.maxpool2,
                self.Mixed_6a.pool,
                self.Mixed_7a.pool,
                self.avgpool[0]],
        }

        return attention_groups

@RPProvider.register(LAPInception)
class Inception3Mo4RelProp(RelProp[LAPInception]):
    """Whole-network relevance propagation for LAPInception, from the classifier
    head back to the input image (approximate shapes noted per step)."""

    def rel(self, R: torch.Tensor, alpha: float = 1) -> torch.Tensor:
        # With a single-unit (binary) head, keep only the positive-class column.
        if RPProvider.get(self.module.fc).Y.shape[1] == 1:
            R = R[:, -1:]
        R = RPProvider.get(self.module.fc)(R, alpha=alpha)               # B 2048
        R = R.reshape_as(RPProvider.get(self.module.dropout).Y)          # B 2048 1 1
        R = RPProvider.get(self.module.dropout)(R, alpha=alpha)          # B 2048 1 1
        R = RPProvider.get(self.module.avgpool)(R, alpha=alpha)          # B 2048 8 8

        R = RPProvider.get(self.module.Mixed_7c)(R, alpha=alpha)         # B 2048 8 8
        R = RPProvider.get(self.module.Mixed_7b)(R, alpha=alpha)         # B 1280 8 8
        R = RPProvider.get(self.module.Mixed_7a)(R, alpha=alpha)         # B 768 17 17

        R = RPProvider.get(self.module.Mixed_6e)(R, alpha=alpha)         # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6d)(R, alpha=alpha)         # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6c)(R, alpha=alpha)         # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6b)(R, alpha=alpha)         # B 768 17 17
        R = RPProvider.get(self.module.Mixed_6a)(R, alpha=alpha)         # B 288 35 35

        R = RPProvider.get(self.module.Mixed_5d)(R, alpha=alpha)         # B 288 35 35
        R = RPProvider.get(self.module.Mixed_5c)(R, alpha=alpha)         # B 256 35 35
        R = RPProvider.get(self.module.Mixed_5b)(R, alpha=alpha)         # B 192 35 35

        R = RPProvider.get(self.module.maxpool2)(R, alpha=alpha)         # B 192 71 71
        R = RPProvider.get(self.module.Conv2d_4a_3x3)(R, alpha=alpha)    # B 80 73 73
        R = RPProvider.get(self.module.Conv2d_3b_1x1)(R, alpha=alpha)    # B 64 73 73

        R = RPProvider.get(self.module.maxpool1)(R, alpha=alpha)         # B 64 147 147
        R = RPProvider.get(self.module.Conv2d_2b_3x3)(R, alpha=alpha)    # B 32 147 147
        R = RPProvider.get(self.module.Conv2d_2a_3x3)(R, alpha=alpha)    # B 32 149 149
        R = RPProvider.get(self.module.Conv2d_1a_3x3)(R, alpha=alpha)    # B 3 299 299
        return R
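

# A minimal construction sketch (hypothetical factory signatures; the real
# factories are defined elsewhere in this code base):
#
#     pool_factory = lambda channels: LAP(channels)           # assumed LAP signature
#     adaptive_pool_factory = lambda channels: LAP(channels)  # assumed
#     model = LAPInception(aux_weight=0.4, n_classes=2,
#                          pool_factory=pool_factory,
#                          adaptive_pool_factory=adaptive_pool_factory)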