You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

org_inception.py 1.2KB

123456789101112131415161718192021222324252627282930313233343536373839
from typing import Dict, Optional

import torch
import torchvision
from torch.nn import functional as F

from ..tv_inception import Inception3
  6. class RSNAORGInception(Inception3):
  7. def __init__(self, aux_weight: float):
  8. super().__init__(aux_weight, n_classes=1)
  9. def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> Dict[str, torch.Tensor]:
  10. x = x.repeat_interleave(3, dim=1) # B 3 224 224
  11. if self.training:
  12. out, aux = torchvision.models.Inception3.forward(self, x) # B 1
  13. out, aux = out.flatten(), aux.flatten() # B
  14. else:
  15. out = torchvision.models.Inception3.forward(self, x).flatten() # B
  16. aux = None
  17. if y is not None:
  18. main_loss = F.binary_cross_entropy(out, y)
  19. if aux is not None:
  20. aux_loss = F.binary_cross_entropy(aux, y)
  21. loss = (main_loss + self.aux_weight * aux_loss) / (1 + self.aux_weight)
  22. else:
  23. loss = main_loss
  24. return {
  25. 'positive_class_probability': out,
  26. 'loss': loss
  27. }
  28. return {
  29. 'positive_class_probability': out
  30. }