from collections import OrderedDict
from typing import Dict, List
import typing

import torch

from ...modules import LAP, AdaptiveLAP
from ..lap_resnet import LapBasicBlock, LapResNet, PoolFactory


def pool_factory(channels, sigmoid_scale=1.0):
    """Build the default LAP pooling module for the given channel count.

    The module is configured with three attention heads and discriminative
    attention enabled.

    Args:
        channels: Channel count forwarded positionally to ``LAP``.
        sigmoid_scale: Forwarded to ``LAP`` as ``sigmoid_scale``.

    Returns:
        A new ``LAP`` instance.
    """
    return LAP(
        channels,
        discriminative_attention=True,
        n_attention_heads=3,
        sigmoid_scale=sigmoid_scale,
    )


def adaptive_pool_factory(channels, sigmoid_scale=1.0):
    """Build the default adaptive LAP pooling module for the given channel count.

    Uses the same configuration as ``pool_factory`` (three attention heads,
    discriminative attention), but instantiates ``AdaptiveLAP``.

    Args:
        channels: Channel count forwarded positionally to ``AdaptiveLAP``.
        sigmoid_scale: Forwarded to ``AdaptiveLAP`` as ``sigmoid_scale``.

    Returns:
        A new ``AdaptiveLAP`` instance.
    """
    return AdaptiveLAP(
        channels,
        discriminative_attention=True,
        n_attention_heads=3,
        sigmoid_scale=sigmoid_scale,
    )


class CelebALAPResNet18(LapResNet):
    """``LapResNet`` with a ResNet-18 block layout (``[2, 2, 2, 2]`` basic blocks).

    The base network is built with ``binary=True``. The ground-truth tensor
    consumed by ``forward`` is read from a keyword argument whose name is
    ``tag``. Presumably ``tag`` names a CelebA attribute (TODO confirm
    against callers).
    """

    def __init__(self, tag: str,
                 pool_factory: PoolFactory = pool_factory,
                 adaptive_factory: PoolFactory = adaptive_pool_factory,
                 sigmoid_scale: float = 1.0):
        """Create the network.

        Args:
            tag: Name of the keyword argument that ``forward`` reads the
                ground-truth tensor from.
            pool_factory: Factory for the LAP pooling modules; defaults to the
                module-level ``pool_factory``.
            adaptive_factory: Factory for the adaptive LAP pooling modules;
                defaults to ``adaptive_pool_factory``.
            sigmoid_scale: Forwarded unchanged to ``LapResNet``.
        """
        super().__init__(
            LapBasicBlock,
            [2, 2, 2, 2],
            binary=True,
            sigmoid_scale=sigmoid_scale,
            adaptive_factory=adaptive_factory,
            pool_factory=pool_factory,
        )
        # Assigned after the base initializer, matching the original order.
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""
        Returns a dictionary from additional `kwargs` names to their optionality
        """
        # NOTE(review): the tag is reported with value True, but ``forward``
        # indexes ``gts`` directly and raises KeyError when it is absent.
        # Confirm what True means to the consumer of this mapping.
        return OrderedDict([(f'{self._tag}', True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the base network, passing the tagged ground truth as ``y``.

        Args:
            x: Input tensor, forwarded unchanged to ``LapResNet.forward``.
            **gts: Keyword tensors. The one named by ``tag`` is required.

        Returns:
            Whatever ``LapResNet.forward(x, y)`` returns.

        Raises:
            KeyError: If no keyword argument named ``tag`` was supplied.
        """
        return super().forward(x, gts[f'{self._tag}'])

    @property
    def attention_layers(self) -> Dict[str, List[LAP]]:
        """Return the parent's attention-layer mapping without ``'4_overall'``.

        Raises:
            KeyError: If the parent mapping has no ``'4_overall'`` entry.
        """
        layers = super().attention_layers
        # NOTE(review): this deletes from the object returned by the parent
        # property. If that mapping is cached or shared, the deletion is
        # visible elsewhere; confirm it is built fresh on each access.
        del layers['4_overall']
        return layers