import typing
from typing import Dict, Iterable
from collections import OrderedDict
import torch
from torch.nn import functional as F
import torchvision
from ..lap_inception import LAPInception
class CelebALAPInception(LAPInception):
    r"""Single-output (``n_classes=1``) LAPInception trained with a class-balanced BCE loss.

    The ground-truth tensor is looked up in ``forward``'s keyword arguments
    under the name given by ``tag`` (presumably a CelebA attribute name).
    """

    def __init__(self, tag: str, aux_weight: float, pool_factory, adaptive_pool_factory):
        r"""
        :param tag: name of the keyword argument of ``forward`` that carries the
            ground truth; also the sole key reported by ``additional_kwargs``.
        :param aux_weight: forwarded unchanged to ``LAPInception``.
        :param pool_factory: forwarded unchanged to ``LAPInception``.
        :param adaptive_pool_factory: forwarded unchanged to ``LAPInception``.
        """
        super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory)
        self._tag = tag

    @property
    def additional_kwargs(self) -> typing.OrderedDict[str, bool]:
        r""" Returns a dictionary from additional `kwargs` names to their optionality """
        # The ground truth is optional: forward() skips the loss when it is absent.
        return OrderedDict([(self._tag, True)])

    def forward(self, x: torch.Tensor, **gts: torch.Tensor) -> Dict[str, torch.Tensor]:
        r"""Run the Inception backbone and, when ground truth is supplied, compute the loss.

        :param x: input batch (B x 3 x 224 x 224 per the original shape note).
        :param gts: optional ground-truth tensors keyed by name; only the entry
            named by this model's tag is read.
        :return: dict with ``'positive_class_probability'`` (the backbone output
            flattened to length B) and, if the tag's ground truth is present,
            ``'loss'``.
        """
        # Inception3.forward is called directly, bypassing LAPInception.forward.
        # torchvision returns an InceptionOutputs(logits, aux_logits) namedtuple only
        # while training with aux_logits enabled (or when scripted); otherwise it
        # returns the logits tensor alone. Unpacking a bare tensor would iterate over
        # its batch dimension, so the primary output is selected explicitly.
        result = torchvision.models.Inception3.forward(self, x)  # B 1
        main_out = result[0] if isinstance(result, tuple) else result
        out = main_out.flatten()  # B
        # NOTE(review): auxiliary logits are not added to the loss below, although
        # aux_weight is passed to LAPInception -- confirm this is intended.

        output = dict()
        # NOTE(review): named a probability; BCE below requires values in [0, 1],
        # so presumably the LAPInception head already applies a sigmoid -- confirm.
        output['positive_class_probability'] = out
        if self._tag not in gts:
            return output
        gt = gts[self._tag]

        # Class-weighted loss: one mean BCE per label value present in the batch,
        # then averaged, so every present class contributes equally regardless of
        # how many samples it has.
        # F.binary_cross_entropy needs a floating target with the input's dtype;
        # masks are still built from the original labels.
        target = gt.to(out.dtype)
        loss = torch.mean(torch.stack(tuple(
            F.binary_cross_entropy(out[gt == i], target[gt == i]) for i in gt.unique()
        )))
        output['loss'] = loss
        return output

    # --- Interpretation ---

    @property
    def ordered_placeholder_names_to_be_interpreted(self) -> Iterable[str]:
        """
        :return: input module for interpretation
        """
        return ['x']
    