You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lap_inception.py 1.3KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from typing import Dict
  2. import torch
  3. from torch.nn import functional as F
  4. import torchvision
  5. from ..lap_inception import LAPInception
  6. class RSNALAPInception(LAPInception):
  7. def __init__(self, aux_weight: float, pool_factory, adaptive_pool_factory):
  8. super().__init__(aux_weight, n_classes=1, pool_factory=pool_factory, adaptive_pool_factory=adaptive_pool_factory)
  9. def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
  10. x = x.repeat_interleave(3, dim=1) # B 3 224 224
  11. if self.training:
  12. out, aux = torchvision.models.Inception3.forward(self, x) # B 1
  13. out, aux = out.flatten(), aux.flatten() # B
  14. else:
  15. out = torchvision.models.Inception3.forward(self, x).flatten() # B
  16. aux = None
  17. if y is not None:
  18. main_loss = F.binary_cross_entropy(out, y)
  19. if aux is not None:
  20. aux_loss = F.binary_cross_entropy(aux, y)
  21. loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
  22. else:
  23. loss = main_loss
  24. return {
  25. 'positive_class_probability': out,
  26. 'loss': loss
  27. }
  28. return {
  29. 'positive_class_probability': out
  30. }