You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

single_tag_org_inception.py 2.0KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import typing
  2. from typing import Dict, List, Iterable
  3. from collections import OrderedDict
  4. import torch
  5. from torch import nn
  6. from torch.nn import functional as F
  7. import torchvision
  8. from ..tv_inception import Inception3
  class CelebAORGInception(Inception3):
      r"""Single-output Inception-v3 model for one CelebA attribute tag.

      The network emits one value per sample, exposed under the key
      ``'positive_class_probability'``. When the ground truth for the configured
      tag is passed to :meth:`forward`, a class-balanced binary cross-entropy
      loss is also returned under ``'loss'``.
      """

      def __init__(self, tag: str, aux_weight: float):
          r"""
          Args:
              tag: name of the keyword argument of :meth:`forward` that carries
                  this model's ground truth.
              aux_weight: forwarded unchanged to the parent ``Inception3``
                  constructor together with ``n_classes=1``. Its meaning is
                  defined by that parent class (not visible here).
          """
          super().__init__(aux_weight, n_classes=1)
          self._tag = tag

      @property
      def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
          r""" Returns a dictionary from additional `kwargs` names to their optionality """
          return OrderedDict({
              f'{self._tag}': True,
          })

      def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
          r"""Run the network and, when ground truth is supplied, compute the loss.

          Args:
              x: input batch, passed straight to torchvision's
                  ``Inception3.forward``.
              **gts: optional ground-truth tensors keyed by tag name. Only the
                  entry named after this model's tag is read.

          Returns:
              A dict containing ``'positive_class_probability'`` (the network
              output flattened to one value per sample) and, if the tag's ground
              truth was given, ``'loss'`` (a scalar tensor).
          """
          # torchvision's forward is called directly, bypassing any override in
          # the parent class. It returns an ``InceptionOutputs`` namedtuple
          # ``(logits, aux_logits)`` when training with auxiliary logits enabled
          # (or when scripting), and the main output tensor alone otherwise.
          # Dispatch on the actual return type instead of on ``self.training``.
          result = torchvision.models.Inception3.forward(self, x)
          if isinstance(result, tuple):
              # NOTE(review): the auxiliary head's output is discarded, so
              # ``aux_weight`` does not contribute to the loss computed below.
              # Confirm this is intended.
              result = result[0]
          out = result.flatten()  # B

          output = {'positive_class_probability': out}

          key = f'{self._tag}'
          if key not in gts:
              return output

          # ``F.binary_cross_entropy`` needs the target to match the input in
          # shape and dtype, so flatten the ground truth and move it onto
          # ``out``'s dtype/device. This is a no-op for a float (B,) target.
          gt = gts[key].to(device=out.device, dtype=out.dtype).flatten()

          # Class-weighted loss: compute the BCE separately over the samples of
          # each distinct target value, then average those per-class losses so
          # every class present in the batch weighs equally.
          # NOTE(review): assumes ``out`` already lies in [0, 1], as
          # ``F.binary_cross_entropy`` requires. The head producing it is
          # defined in the parent class; confirm it ends with a sigmoid.
          per_class_losses = [
              F.binary_cross_entropy(out[gt == cls], gt[gt == cls])
              for cls in gt.unique()
          ]
          output['loss'] = torch.stack(per_class_losses).mean()
          return output

      @property
      def target_conv_layers(self) -> List[nn.Module]:
          r"""Branch modules of the final ``Mixed_7c`` block.

          Returns ``branch1x1``, ``branch3x3_2a``, ``branch3x3_2b``,
          ``branch3x3dbl_3a``, ``branch3x3dbl_3b`` and ``branch_pool``, in that
          order.
          """
          return [
              self.Mixed_7c.branch1x1,
              self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
              self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
              self.Mixed_7c.branch_pool
          ]

      @property
      def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
          r"""Names of the :meth:`forward` inputs to interpret; only ``'x'``."""
          return ['x']

      def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
          r"""Return per-sample two-class probabilities.

          Runs :meth:`forward` and stacks ``[1 - p, p]`` along dim 1, where ``p``
          is the flattened ``'positive_class_probability'``. The result has shape
          (B, 2), with column 0 for the negative class and column 1 for the
          positive class.
          """
          p = self.forward(*inputs, **kwargs)['positive_class_probability']
          return torch.stack([1 - p, p], dim=1)