12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
import typing
from collections import OrderedDict
from typing import Dict, Iterable, List

import torch
import torchvision
from torch import nn
from torch.nn import functional as F

from ..tv_inception import Inception3
-
-
class CelebAORGInception(Inception3):
    r"""Single-logit (binary) Inception-v3 classifier for one CelebA attribute.

    The attribute is identified by ``tag``; the ground-truth tensor for the
    loss is passed to :meth:`forward` as a keyword argument of that same name.
    The backbone is the parent ``Inception3`` configured with ``n_classes=1``.
    """

    def __init__(self, tag: str, aux_weight: float):
        """
        Args:
            tag: name of the target attribute; also the keyword under which the
                ground truth is passed to :meth:`forward`.
            aux_weight: weight of the auxiliary-classifier loss, forwarded to
                the ``Inception3`` base class.
        """
        super().__init__(aux_weight, n_classes=1)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r"""Returns a dictionary from additional `kwargs` names to their optionality."""
        # `self._tag` is already a str, so no f-string wrapping is needed.
        return OrderedDict({self._tag: True})

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Runs the backbone and, if the tag's ground truth is supplied,
        attaches a class-weighted binary cross-entropy loss.

        Args:
            x: input image batch.
            **gts: optional ground-truth tensors; only the key equal to the
                configured tag is consumed.

        Returns:
            Dict with ``'positive_class_probability'`` (shape ``(B,)``) and,
            when the ground truth was provided, ``'loss'`` (scalar).
        """
        if self.training:
            # In training mode torchvision's Inception3 returns (main, aux)
            # outputs of shape B x 1 each.
            out, aux = torchvision.models.Inception3.forward(self, x)  # B 1
            out, aux = out.flatten(), aux.flatten()  # B
        else:
            out = torchvision.models.Inception3.forward(self, x).flatten()  # B
            aux = None
        # NOTE(review): `aux` is computed but never used in the loss below;
        # presumably the auxiliary loss is applied by the base class via
        # `aux_weight` -- confirm against `..tv_inception.Inception3`.

        output = dict()
        output['positive_class_probability'] = out

        if self._tag not in gts:
            # Inference-only call: no ground truth, hence no loss entry.
            return output

        gt = gts[self._tag]

        # Class-weighted loss: average the per-class BCE means so each class
        # present in the batch contributes equally, regardless of frequency.
        loss = torch.mean(torch.stack(tuple(
            F.binary_cross_entropy(out[gt == i], gt[gt == i]) for i in gt.unique()
        )))

        output['loss'] = loss

        return output

    @property
    def target_conv_layers(self) -> List[nn.Module]:
        """Final ``Mixed_7c`` convolution branches (e.g. for CAM-style attribution)."""
        return [
            self.Mixed_7c.branch1x1,
            self.Mixed_7c.branch3x3_2a, self.Mixed_7c.branch3x3_2b,
            self.Mixed_7c.branch3x3dbl_3a, self.Mixed_7c.branch3x3dbl_3b,
            self.Mixed_7c.branch_pool
        ]

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """Ordered placeholder names consumed positionally by :meth:`forward`."""
        return ['x']

    def get_categorical_probabilities(self, *inputs, **kwargs) -> torch.Tensor:
        """Returns a ``(B, 2)`` tensor of ``[negative, positive]`` probabilities."""
        p = self.forward(*inputs, **kwargs)['positive_class_probability']
        return torch.stack([1 - p, p], dim=1)
|