12345678910111213141516171819202122232425 |
- from collections import OrderedDict
- from typing import Dict
- import typing
-
- import torch
-
- from ..tv_resnet import BasicBlock, ResNet
-
-
class CelebAORGResNet18(ResNet):
    """ResNet model whose ground-truth target is passed to ``forward`` by keyword.

    The network is built from ``BasicBlock`` with a ``[2, 2, 2, 2]`` stage
    layout and ``binary=True``. The name of the keyword argument holding the
    target is fixed at construction time via ``tag``.

    Args:
        tag: Name of the keyword argument that ``forward`` reads the target from.
    """

    def __init__(self, tag: str):
        super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)
        # Remembered so both additional_kwargs and forward agree on the kwarg name.
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""Map each extra ``forward`` keyword-argument name to its optionality flag.

        Exactly one entry is produced: the configured tag, flagged ``True``.
        """
        # NOTE(review): forward() indexes gts[tag] directly, so omitting the
        # kwarg raises KeyError; confirm how callers interpret this True flag.
        return OrderedDict([('{}'.format(self._tag), True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        r"""Delegate to the parent ``forward`` with the target taken from ``gts``.

        Args:
            x: Input passed through unchanged to the parent ``forward``.
            **gts: Keyword tensors; the one named by ``tag`` is used as the target.

        Returns:
            Whatever the parent ``ResNet.forward(x, target)`` returns.

        Raises:
            KeyError: If no keyword argument named by ``tag`` was supplied.
        """
        target_key = '{}'.format(self._tag)
        target = gts[target_key]
        return super().forward(x, target)
|