You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

thyroid_ml_model.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import torch
  2. import torchvision
  3. from torch import nn
  4. class ThyroidClassificationModel(nn.Module):
  5. def __init__(self, base_model):
  6. super().__init__()
  7. self.base_model = base_model
  8. self.classifier = nn.Sequential(
  9. nn.Linear(1000, 500),
  10. nn.BatchNorm1d(500),
  11. nn.ReLU(),
  12. nn.Linear(500, 100),
  13. nn.BatchNorm1d(100),
  14. nn.ReLU(),
  15. nn.Linear(100, 2),
  16. nn.BatchNorm1d(2),
  17. nn.Softmax(dim=-1)
  18. )
  19. self._is_inception3 = type(base_model) == torchvision.models.inception.Inception3
  20. if self._is_inception3:
  21. self.classifier2 = nn.Sequential(
  22. nn.Linear(1000, 500),
  23. nn.BatchNorm1d(500),
  24. nn.ReLU(),
  25. nn.Linear(500, 100),
  26. nn.BatchNorm1d(100),
  27. nn.ReLU(),
  28. nn.Linear(100, 2),
  29. nn.BatchNorm1d(2),
  30. nn.Softmax(dim=-1)
  31. )
  32. def forward(self, x, validate=False):
  33. output = self.base_model(x.float())
  34. if self._is_inception3 and not validate:
  35. return self.classifier(output[0]), self.classifier2(output[1])
  36. return self.classifier(output)
  37. def save_model(self, path):
  38. torch.save(self.state_dict(), path)
  39. def load_model(self, path):
  40. self.load_state_dict(torch.load(path))
  41. self.eval()
  42. return self