You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lap_inception.py 1.7KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import typing
  2. from typing import Dict, Iterable
  3. from collections import OrderedDict
  4. import torch
  5. from torch.nn import functional as F
  6. import torchvision
  7. from ..lap_inception import LAPInception
  class CelebALAPInception(LAPInception):
      """Single-output ``LAPInception`` trained on one binary target.

      The target is identified by ``tag`` (presumably a CelebA attribute, going
      by the class name). The same name is the keyword argument under which
      :meth:`forward` looks for that target's ground truth.
      """

      def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
          """
          :param tag: name of the ground-truth keyword argument read by ``forward``.
          :param aux_weight: passed through unchanged to ``LAPInception``.
          :param pool_factory: passed through unchanged to ``LAPInception``.
          :param adaptive_pool_factory: passed through unchanged to ``LAPInception``.
          """
          # A single output unit: one score per image for the binary target.
          super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory,
                           adaptive_pool_factory=adaptive_pool_factory)
          self._tag = tag

      @property
      def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
          r"""Map each extra ``forward`` keyword-argument name to its optionality.

          The only extra argument is the ground truth named by ``tag``. It is
          marked optional (``True``) because ``forward`` still returns its
          prediction, just without a loss, when that argument is absent.
          """
          return OrderedDict({
              f'{self._tag}': True,
          })

      def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
          """Predict the target and, when its ground truth is supplied, the loss.

          :param x: input image batch (the original note gives B x 3 x 224 x 224).
          :param gts: ground-truth tensors by name; only ``gts[tag]`` is read.
              It is flattened and cast to the output's dtype, so ``(B,)`` or
              ``(B, 1)`` targets of any numeric dtype are accepted.
          :return: ``{'positive_class_probability': out}`` with ``out`` the
              flattened network output of shape (B,), plus ``'loss'`` when
              ``gts`` contains ``tag``.
          """
          # x: B 3 224 224
          if self.training:
              # In training mode the parent forward yields (main output, aux output).
              # NOTE(review): the aux output is discarded and no aux term enters
              # the loss, even though ``aux_weight`` is handed to the parent.
              # Confirm this is intentional.
              out, _aux = torchvision.models.Inception3.forward(self, x)  # B 1
          else:
              out = torchvision.models.Inception3.forward(self, x)
          out = out.flatten()  # B

          output = dict()
          output['positive_class_probability'] = out

          key = f'{self._tag}'
          if key not in gts:
              return output

          # F.binary_cross_entropy needs a floating target shaped like its input,
          # so lay the target out exactly like ``out`` (flattened, same dtype).
          # NOTE(review): BCE also expects ``out`` to lie in [0, 1]. Presumably
          # the LAPInception head ends in a sigmoid; confirm.
          target = gts[key].flatten().to(out.dtype)

          # Class-balanced loss: average the BCE separately within each distinct
          # target value present in the batch, then average those per-class
          # means, so a rare class weighs as much as a frequent one.
          # (An empty batch raises, because torch.stack needs at least one tensor.)
          per_class_losses = [
              F.binary_cross_entropy(out[target == value], target[target == value])
              for value in target.unique()
          ]
          output['loss'] = torch.stack(per_class_losses).mean()
          return output

      # ----- Interpretation -----

      @property
      def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
          """
          :return: names of the ``forward`` inputs to interpret, in order.
              This is just the image input ``x``.
          """
          return ['x']