"""LAP-ResNet-18 model and pooling-factory helpers for the RSNA setup.

NOTE(review): the single-channel input assumption below is inferred from the
channel repeat in ``RSNALAPResNet18.forward``; confirm against the data
pipeline.
"""
from typing import Dict, Optional

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def get_pool_factory(discriminative_attention=True):
    """Return a factory that builds ``LAP`` pooling modules.

    Args:
        discriminative_attention: Forwarded unchanged to every ``LAP`` the
            returned factory constructs.

    Returns:
        A callable ``(channels, sigmoid_scale=1.0) -> LAP``. Each call builds a
        new ``LAP(channels, hidden_channels=[8], sigmoid_scale=...,
        discriminative_attention=...)``.
    """

    def pool_factory(channels, sigmoid_scale=1.0):
        # The closure captures only ``discriminative_attention``; a new module
        # (and a new ``[8]`` list) is created on every call, so the factory
        # itself carries no mutable state.
        return LAP(
            channels,
            hidden_channels=[8],
            sigmoid_scale=sigmoid_scale,
            discriminative_attention=discriminative_attention,
        )

    return pool_factory


def get_adaptive_pool_factory(discriminative_attention=True):
    """Return a factory that builds ``AdaptiveLAP`` pooling modules.

    Args:
        discriminative_attention: Forwarded unchanged to every ``AdaptiveLAP``
            the returned factory constructs.

    Returns:
        A callable ``(channels, sigmoid_scale=1.0) -> AdaptiveLAP``. Each call
        builds a new ``AdaptiveLAP(channels, sigmoid_scale=...,
        discriminative_attention=...)``.
    """

    def adaptive_pool_factory(channels, sigmoid_scale=1.0):
        return AdaptiveLAP(
            channels,
            sigmoid_scale=sigmoid_scale,
            discriminative_attention=discriminative_attention,
        )

    return adaptive_pool_factory


class RSNALAPResNet18(LapResNet):
    """``LapResNet`` in the ResNet-18 layout (``LapBasicBlock``, ``[2, 2, 2, 2]``).

    The base class is always built with ``binary=True``.
    NOTE(review): presumably this selects a single-logit / binary head; confirm
    in ``LapResNet``.
    """

    def __init__(
        self,
        pool_factory: PoolFactory = get_pool_factory(),
        adaptive_factory: PoolFactory = get_adaptive_pool_factory(),
        sigmoid_scale: float = 1.0,
    ):
        """Build the network.

        Args:
            pool_factory: Factory passed to ``LapResNet`` as ``pool_factory``.
                Defaults to ``get_pool_factory()``.
            adaptive_factory: Factory passed to ``LapResNet`` as
                ``adaptive_factory``. Defaults to
                ``get_adaptive_pool_factory()``.
            sigmoid_scale: Forwarded to ``LapResNet``.

        The default factories are evaluated once, when the class is defined,
        and shared by all instances. That is safe because they are stateless
        closures that construct a new module on every call.
        """
        super().__init__(
            LapBasicBlock,
            [2, 2, 2, 2],
            pool_factory=pool_factory,
            adaptive_factory=adaptive_factory,
            sigmoid_scale=sigmoid_scale,
            binary=True,
        )

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Repeat input channels 3x, then run ``LapResNet.forward``.

        Args:
            x: Input batch. Presumably ``(B, 1, H, W)`` single-channel images,
                given the ``B 3 224 224`` shape note after the repeat. TODO:
                confirm against the data loader.
            y: Optional targets, passed through unchanged to
                ``LapResNet.forward``.

        Returns:
            The result of ``LapResNet.forward(x, y)``, annotated as a
            ``Dict[str, torch.Tensor]``.
        """
        # repeat_interleave along dim 1 maps C channels to 3*C channels, so a
        # 1-channel input becomes 3 channels (presumably what the LapResNet
        # stem expects).
        x = x.repeat_interleave(3, dim=1)  # B 3 224 224
        return super().forward(x, y)