12345678910111213141516171819202122232425262728293031323334 |
- from typing import Dict, List
- import torch
-
- from torchvision import transforms
-
- from ...modules import LAP, AdaptiveLAP, ScaledSigmoid
- from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory
-
-
def pool_factory(channels, sigmoid_scale=1.0):
    """Build the default ``LAP`` module used by the models in this file.

    Args:
        channels: Forwarded to ``LAP`` as its first positional argument.
        sigmoid_scale: Forwarded to ``LAP`` and also used to build the
            hidden activation via ``ScaledSigmoid.get_factory``.

    Returns:
        A ``LAP`` instance configured with ``hidden_channels=[1000]``.
    """
    # The same scale drives both the LAP module and its hidden activation.
    activation_factory = ScaledSigmoid.get_factory(sigmoid_scale)
    return LAP(
        channels,
        hidden_channels=[1000],
        sigmoid_scale=sigmoid_scale,
        hidden_activation=activation_factory,
    )
-
-
class ImagenetLAPResNet50(LapResNet):
    """``LapResNet`` built from ``LapBottleneck`` blocks that normalizes its input.

    The input tensor is passed through a fixed ``transforms.Normalize`` before
    being handed to the parent ``forward``. Only the ``'2_layer4'`` entry of the
    parent's attention layers is exposed.
    """

    def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0):
        """Configure the parent network and the input normalization.

        Args:
            pool_factory: Factory forwarded to ``LapResNet``; defaults to the
                module-level ``pool_factory``.
            sigmoid_scale: Scale forwarded to ``LapResNet``.
        """
        # Block counts per stage — presumably the ResNet-50 layout; confirm
        # against LapResNet's constructor.
        blocks_per_stage = [3, 4, 6, 3]
        super().__init__(
            LapBottleneck,
            blocks_per_stage,
            pool_factory=pool_factory,
            sigmoid_scale=sigmoid_scale,
            lap_positions=[4],
        )

        # Per-channel mean / std handed to Normalize (the values look like the
        # usual ImageNet statistics — verify if the dataset changes).
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)
        self.normalize = transforms.Normalize(channel_mean, channel_std)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Normalize ``x`` and delegate to the parent ``forward`` with ``y`` unchanged."""
        return super().forward(self.normalize(x), y)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Return the parent's attention layers restricted to ``'2_layer4'``."""
        layer_key = '2_layer4'
        inherited = super().attention_layers
        return {layer_key: inherited[layer_key]}
|