import typing
from typing import Dict, Iterable
from collections import OrderedDict

import torch
from torch.nn import functional as F
import torchvision

from ..lap_inception import LAPInception


class CelebALAPInception(LAPInception):
    r""" LAP Inception model for a single binary CelebA attribute (hence `n_classes=1`) """

    def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
        super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory,
                         adaptive_pool_factory=adaptive_pool_factory)
        self._tag = tag  # name of the kwarg that carries the ground-truth labels

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Maps each additional `kwargs` name accepted by `forward` to whether it is optional """
        return OrderedDict({
            self._tag: True,
        })
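
    # For example, for a model built with tag='Smiling' (a hypothetical tag),
    # this property yields OrderedDict([('Smiling', True)]): the ground-truth
    # kwarg is registered as optional, and `forward` skips the loss when it
    # is absent.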

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        # x: (B, 3, 224, 224)

        if self.training:
            # In training mode Inception3 also returns the auxiliary head's
            # output; both tensors are (B, 1) since n_classes=1.
            out, aux = torchvision.models.Inception3.forward(self, x)
            out, aux = out.flatten(), aux.flatten()  # (B,)
        else:
            out = torchvision.models.Inception3.forward(self, x).flatten()  # (B,)
            aux = None

        output = dict()
        output['positive_class_probability'] = out

        if self._tag not in gts:
            return output

        gt = gts[self._tag]

        # Class-weighted loss: average the per-class BCE means so that the
        # positive and negative classes contribute equally to the loss,
        # regardless of how imbalanced the batch is.
        loss = torch.mean(torch.stack(tuple(
            F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
        )))

        output['loss'] = loss

        return output
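
    # Typical call patterns (sketch; 'Smiling' is a hypothetical tag value):
    #   model(x)                     -> {'positive_class_probability': probs}
    #   model(x, Smiling=gt_labels)  -> same dict plus a 'loss' entry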

    # INTERPRETATION

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        :return: the ordered names of the input placeholders to be interpreted (here only the input image `x`)
        """
        return ['x']
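

# --- Illustration (not part of the original module) -------------------------
# A minimal sketch of the class-weighted loss used in `forward` above, with
# made-up numbers: on an imbalanced batch (9 negatives, 1 positive) plain BCE
# is dominated by the majority class, while averaging the per-class BCE terms
# weighs both classes equally. Uses only the imports at the top of this file;
# note that the relative import above keeps this module from being executed
# directly, so paste the body into a scratch script to try it.
if __name__ == '__main__':
    out = torch.tensor([0.1] * 9 + [0.4])  # predicted P(positive) per sample
    gt = torch.tensor([0.0] * 9 + [1.0])   # 9 negatives, 1 positive

    plain = F.binary_cross_entropy(out, gt)       # ~0.186, majority-dominated
    weighted = torch.mean(torch.stack(tuple(      # ~0.511, classes weighted equally
        F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
    )))
    print(f'plain BCE: {plain:.4f}  class-weighted BCE: {weighted:.4f}')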