You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lap_resnet.py 1.4KB

123456789101112131415161718192021222324252627282930313233343536373839
from typing import Dict, Optional

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory
  5. def get_pool_factory(discriminative_attention=True):
  6. def pool_factory(channels, sigmoid_scale=1.0):
  7. return LAP(channels,
  8. hidden_channels=[8],
  9. sigmoid_scale=sigmoid_scale,
  10. discriminative_attention=discriminative_attention)
  11. return pool_factory
  12. def get_adaptive_pool_factory(discriminative_attention=True):
  13. def adaptive_pool_factory(channels, sigmoid_scale=1.0):
  14. return AdaptiveLAP(channels,
  15. sigmoid_scale=sigmoid_scale,
  16. discriminative_attention=discriminative_attention)
  17. return adaptive_pool_factory
  18. class RSNALAPResNet18(LapResNet):
  19. def __init__(self, pool_factory: PoolFactory = get_pool_factory(),
  20. adaptive_factory: PoolFactory = get_adaptive_pool_factory(),
  21. sigmoid_scale: float = 1.0):
  22. super().__init__(LapBasicBlock, [2, 2, 2, 2],
  23. pool_factory=pool_factory,
  24. adaptive_factory=adaptive_factory,
  25. sigmoid_scale=sigmoid_scale,
  26. binary=True)
  27. def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
  28. x = x.repeat_interleave(3, dim=1) # B 3 224 224
  29. return super().forward(x, y)