from collections import OrderedDict
from typing import Dict, List
import typing

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def pool_factory(channels: int, sigmoid_scale: float = 1.0) -> LAP:
    """Build a LAP pooling layer with 3 attention heads and discriminative attention."""
    return LAP(channels,
               sigmoid_scale=sigmoid_scale,
               n_attention_heads=3,
               discriminative_attention=True)


def adaptive_pool_factory(channels: int, sigmoid_scale: float = 1.0) -> AdaptiveLAP:
    """Build an AdaptiveLAP pooling layer with 3 attention heads and discriminative attention."""
    return AdaptiveLAP(channels,
                       sigmoid_scale=sigmoid_scale,
                       n_attention_heads=3,
                       discriminative_attention=True)
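# Usage sketch (illustrative, not executed): LapResNet is expected to call
# these factories with each stage's channel count; the channel values below
# are assumed examples, not fixed by this module.
#
#     pool = pool_factory(64)             # LAP over 64 channels, 3 heads
#     gap = adaptive_pool_factory(512)    # adaptive variant for the final stage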
class CelebALAPResNet18(LapResNet):
    """ResNet-18-style LapResNet (LapBasicBlock, layers [2, 2, 2, 2]) for binary
    classification of a single CelebA attribute identified by ``tag``."""

    def __init__(self, tag: str,
                 pool_factory: PoolFactory = pool_factory,
                 adaptive_factory: PoolFactory = adaptive_pool_factory,
                 sigmoid_scale: float = 1.0):
        super().__init__(LapBasicBlock, [2, 2, 2, 2],
                         pool_factory=pool_factory,
                         adaptive_factory=adaptive_factory,
                         sigmoid_scale=sigmoid_scale,
                         binary=True)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""Maps each additional ``forward`` kwarg name to its optionality."""
        return OrderedDict({self._tag: True})

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Ground-truth labels for the configured attribute arrive as a keyword
        # argument keyed by the tag, e.g. model(x, Smiling=y) for tag='Smiling'.
        y = gts[self._tag]
        return super().forward(x, y)

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        # Expose all attention layers except the final overall pooling layer.
        res = super().attention_layers
        res.pop('4_overall')
        return res
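# Usage sketch (illustrative only; the relative imports above mean this module
# must be used from within its package). The tag 'Smiling', batch size, and
# 224x224 input resolution are assumptions, not fixed by this module:
#
#     model = CelebALAPResNet18(tag='Smiling')
#     x = torch.randn(2, 3, 224, 224)              # dummy CelebA-sized batch
#     y = torch.randint(0, 2, (2,)).float()        # dummy binary labels
#     out = model(x, Smiling=y)                    # label kwarg keyed by tag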