
lap_resnet.py 1.7KB

from collections import OrderedDict
from typing import Dict, List
import typing

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def pool_factory(channels, sigmoid_scale=1.0):
    # Builds a LAP pooling layer with three attention heads and
    # discriminative attention enabled.
    return LAP(channels,
               sigmoid_scale=sigmoid_scale,
               n_attention_heads=3,
               discriminative_attention=True)


def adaptive_pool_factory(channels, sigmoid_scale=1.0):
    # Same configuration as pool_factory, but using the adaptive LAP variant.
    return AdaptiveLAP(channels,
                       sigmoid_scale=sigmoid_scale,
                       n_attention_heads=3,
                       discriminative_attention=True)


class CelebALAPResNet18(LapResNet):
    # LAP ResNet-18 ([2, 2, 2, 2] basic blocks) for binary classification of
    # a single CelebA attribute, identified by `tag`.

    def __init__(self, tag: str,
                 pool_factory: PoolFactory = pool_factory,
                 adaptive_factory: PoolFactory = adaptive_pool_factory,
                 sigmoid_scale: float = 1.0):
        super().__init__(LapBasicBlock, [2, 2, 2, 2],
                         pool_factory=pool_factory,
                         adaptive_factory=adaptive_factory,
                         sigmoid_scale=sigmoid_scale,
                         binary=True)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Returns a dictionary from additional `kwargs` names to their optionality """
        return OrderedDict({
            f'{self._tag}': True,
        })

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Picks the ground-truth tensor for this model's attribute tag out of
        # the keyword arguments and forwards it to the base class.
        y = gts[f'{self._tag}']
        return super().forward(x, y)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        # Exposes the base class's attention layers, dropping the
        # '4_overall' entry.
        res = super().attention_layers
        res.pop('4_overall')
        return res
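
A minimal usage sketch, assuming the relative imports resolve inside the package and that LapResNet.forward accepts an input batch plus the per-attribute ground-truth tensor, as the signatures above suggest. The attribute tag 'Smiling', the input resolution, and the batch size are illustrative assumptions, not values taken from this repository.

# Hypothetical usage sketch (not part of lap_resnet.py)
import torch

model = CelebALAPResNet18(tag='Smiling')    # 'Smiling' is an assumed CelebA attribute tag
x = torch.randn(4, 3, 224, 224)             # assumed input resolution
gt = torch.randint(0, 2, (4,)).float()      # binary labels for the tagged attribute
out = model(x, Smiling=gt)                  # forward routes gts['Smiling'] to the base class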