from collections import OrderedDict
from typing import Dict
import typing

import torch

from ..tv_resnet import BasicBlock, ResNet


class CelebAORGResNet18(ResNet):
    """ResNet-18 variant whose forward pass takes one extra, named input.

    The network is built from ``BasicBlock`` with layer counts ``[2, 2, 2, 2]``
    and ``binary=True`` (the meaning of ``binary`` is defined by
    ``tv_resnet.ResNet``). At call time, the tensor passed as the keyword
    argument named by ``tag`` is looked up and handed to ``ResNet.forward``
    as its second positional argument.
    """

    def __init__(self, tag: str):
        """Create the model.

        Args:
            tag: Name of the keyword argument that ``forward`` reads the
                extra tensor from.
        """
        super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""
        Returns a dictionary from additional `kwargs` names to their optionality
        """
        # A single entry: the configured tag. Build the OrderedDict directly
        # from a key/value pair instead of from a temporary dict literal.
        return OrderedDict([(self._tag, True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the underlying ResNet on ``x`` and the tagged extra tensor.

        Args:
            x: Primary input tensor, passed through unchanged to
                ``ResNet.forward``.
            **gts: Keyword tensors; the one named ``self._tag`` is selected.

        Returns:
            Whatever ``ResNet.forward(x, y)`` returns (annotated here as a
            dict of tensors).

        Raises:
            KeyError: if no keyword argument named ``self._tag`` is supplied.
        """
        # NOTE(review): additional_kwargs maps the tag to True; if True means
        # "optional", this direct lookup contradicts it — confirm intent.
        y = gts[self._tag]
        return super().forward(x, y)