You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lap_resnet.py 1.1KB

12345678910111213141516171819202122232425262728293031323334
  1. from typing import Dict, List
  2. import torch
  3. from torchvision import transforms
  4. from ...modules import LAP, AdaptiveLAP, ScaledSigmoid
  5. from ..lap_resnet import LapBottleneck, LapResNet, PoolFactory
  6. def pool_factory(channels, sigmoid_scale=1.0):
  7. return LAP(channels,
  8. hidden_channels=[1000],
  9. sigmoid_scale=sigmoid_scale,
  10. hidden_activation=ScaledSigmoid.get_factory(sigmoid_scale))
  11. class ImagenetLAPResNet50(LapResNet):
  12. def __init__(self, pool_factory: PoolFactory = pool_factory, sigmoid_scale: float = 1.0):
  13. super().__init__(LapBottleneck, [3, 4, 6, 3],
  14. pool_factory=pool_factory,
  15. sigmoid_scale=sigmoid_scale,
  16. lap_positions=[4])
  17. self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
  18. def forward(self, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]:
  19. x = self.normalize(x)
  20. return super().forward(x, y)
  21. @property
  22. def attention_layers(self) -> Dict[str, List[LAP]]:
  23. return {
  24. '2_layer4': super().attention_layers['2_layer4'],
  25. }