|
123456789101112131415161718192021222324252627282930313233343536373839 |
from typing import Dict, Optional

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory
- from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory
-
-
def get_pool_factory(discriminative_attention=True, *, hidden_channels=None):
    """Return a factory that builds ``LAP`` pooling modules.

    Args:
        discriminative_attention: Forwarded unchanged as
            ``discriminative_attention`` to every ``LAP`` the returned
            factory builds.
        hidden_channels: Sequence passed (as a list) to ``LAP`` as
            ``hidden_channels``. ``None`` (the default) keeps the original
            hard-coded ``[8]``.

    Returns:
        A callable ``pool_factory(channels, sigmoid_scale=1.0)`` returning
        ``LAP(channels, hidden_channels=..., sigmoid_scale=sigmoid_scale,
        discriminative_attention=discriminative_attention)``.
    """
    if hidden_channels is None:
        hidden_channels = [8]
    # Snapshot the caller's sequence so mutating it later cannot change
    # factories that were already handed out.
    hidden = list(hidden_channels)

    def pool_factory(channels, sigmoid_scale=1.0):
        # Give every LAP its own list object. The original per-call ``[8]``
        # literal did the same, so no two modules share one list.
        return LAP(channels,
                   hidden_channels=list(hidden),
                   sigmoid_scale=sigmoid_scale,
                   discriminative_attention=discriminative_attention)

    return pool_factory
-
-
def get_adaptive_pool_factory(discriminative_attention=True):
    """Build a callable that constructs ``AdaptiveLAP`` modules on demand.

    Args:
        discriminative_attention: Captured once here and forwarded to every
            ``AdaptiveLAP`` created by the returned callable.

    Returns:
        ``adaptive_pool_factory(channels, sigmoid_scale=1.0)``, which returns
        ``AdaptiveLAP(channels, sigmoid_scale=sigmoid_scale,
        discriminative_attention=discriminative_attention)``.
    """
    def adaptive_pool_factory(channels, sigmoid_scale=1.0):
        # Gather the keyword options first, then forward them in one call.
        options = dict(sigmoid_scale=sigmoid_scale,
                       discriminative_attention=discriminative_attention)
        return AdaptiveLAP(channels, **options)

    return adaptive_pool_factory
-
-
class RSNALAPResNet18(LapResNet):
    """``LapResNet`` with the ResNet-18 layout and a channel-tiling input step.

    The model is built from ``LapBasicBlock`` with ``[2, 2, 2, 2]`` blocks per
    stage. It always passes ``binary=True`` to ``LapResNet``, which
    presumably selects a binary-classification head (confirm in
    ``lap_resnet``).
    """

    def __init__(self, pool_factory: PoolFactory = get_pool_factory(),
                 adaptive_factory: PoolFactory = get_adaptive_pool_factory(),
                 sigmoid_scale: float = 1.0):
        """Configure the underlying ``LapResNet``.

        Args:
            pool_factory: Factory forwarded to ``LapResNet`` as
                ``pool_factory``.
            adaptive_factory: Factory forwarded to ``LapResNet`` as
                ``adaptive_factory``.
            sigmoid_scale: Forwarded to ``LapResNet`` as ``sigmoid_scale``.

        The default factories are created once, when the class is defined,
        and are shared by every instance. Each factory call still builds a
        new module.
        """
        super().__init__(LapBasicBlock, [2, 2, 2, 2],
                         pool_factory=pool_factory,
                         adaptive_factory=adaptive_factory,
                         sigmoid_scale=sigmoid_scale,
                         binary=True)

    def forward(self, x: torch.Tensor,
                y: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Tile the channel axis three times, then run ``LapResNet.forward``.

        Args:
            x: Input batch. Assumed to be single-channel ``(B, 1, H, W)``,
                e.g. ``(B, 1, 224, 224)``; TODO confirm against the data
                pipeline. Each channel is repeated 3 times along dim 1, so a
                1-channel input becomes 3-channel. A 3-channel input would
                become 9-channel.
            y: Optional targets, forwarded unchanged to ``LapResNet.forward``.

        Returns:
            The result of ``LapResNet.forward(x, y)``.
        """
        x = x.repeat_interleave(3, dim=1)  # B 3 224 224 (for a B 1 224 224 input)
        return super().forward(x, y)
|