You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

org_resnet.py 693B

12345678910111213141516171819202122232425
  1. from collections import OrderedDict
  2. from typing import Dict
  3. import typing
  4. import torch
  5. from ..tv_resnet import BasicBlock, ResNet
  6. class CelebAORGResNet18(ResNet):
  7. def __init__(self, tag: str):
  8. super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)
  9. self._tag = tag
  10. @property
  11. def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
  12. r""" Returns a dictionary from additional `kwargs` names to their optionality """
  13. return OrderedDict({
  14. f'{self._tag}': True,
  15. })
  16. def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
  17. y = gts[f'{self._tag}']
  18. return super().forward(x, y)