from typing import Dict, List

import torch
from torchvision import transforms

from ...modules import LAP, AdaptiveLAP, ScaledSigmoid
from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory


def pool_factory(channels, sigmoid_scale=1.0):
    """Build a LAP pooling module for the given channel count.

    Args:
        channels: Number of input channels for the LAP module.
        sigmoid_scale: Scale passed both to the LAP module and to the
            ``ScaledSigmoid`` hidden-activation factory.

    Returns:
        A configured ``LAP`` instance with a single hidden layer of 1000
        units and a scaled-sigmoid hidden activation.
    """
    return LAP(
        channels,
        hidden_channels=[1000],
        sigmoid_scale=sigmoid_scale,
        hidden_activation=ScaledSigmoid.get_factory(sigmoid_scale))


class ImagenetLAPResNet50(LapResNet):
    """ResNet-50 with LAP attention pooling, configured for ImageNet input.

    Uses the ResNet-50 stage layout ``[3, 4, 6, 3]`` of ``LapBottleneck``
    blocks, with LAP attention inserted only at position 4
    (``lap_positions=[4]``). Inputs are normalized with the standard
    ImageNet channel mean/std before being passed to the base forward.
    """

    def __init__(self,
                 pool_factory: PoolFactory = pool_factory,
                 sigmoid_scale: float = 1.0):
        """
        Args:
            pool_factory: Factory producing the pooling/attention modules;
                defaults to this module's ``pool_factory``.
            sigmoid_scale: Forwarded to the base network's attention setup.
        """
        super().__init__(LapBottleneck, [3, 4, 6, 3],
                         pool_factory=pool_factory,
                         sigmoid_scale=sigmoid_scale,
                         lap_positions=[4])
        # Standard ImageNet normalization constants (per-channel mean/std).
        self.normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Normalize the input batch, then delegate to the base forward.

        Args:
            x: Input image batch; assumed to be RGB in ``[0, 1]`` so the
                ImageNet normalization applies — TODO confirm with callers.
            y: Second input forwarded unchanged to ``LapResNet.forward``
                (presumably labels/targets; semantics defined by the base).

        Returns:
            Whatever ``LapResNet.forward`` returns (a dict of tensors,
            per the annotation).
        """
        x = self.normalize(x)
        return super().forward(x, y)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Expose only the ``'2_layer4'`` attention layers from the base.

        The base class may register attention at several positions; this
        model restricts the reported attention layers to the single group
        matching its ``lap_positions=[4]`` configuration.
        """
        return {
            '2_layer4': super().attention_layers['2_layer4'],
        }