
models.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Feed-forward network that scores a (drug, drug, cell line) triple."""

    def __init__(self, input_size: int, hidden_size: int):
        super(MLP, self).__init__()
        # Two hidden layers, each followed by ReLU and batch norm,
        # then a single-unit output head for the synergy score.
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor,
                cell_feat: torch.Tensor) -> torch.Tensor:
        # Concatenate the two drug embeddings and the cell-line embedding
        # along the feature dimension; input_size must match their combined width.
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], dim=1)
        out = self.layers(feat)
        return out


# Other PRODeepSyn models have been deleted for now.
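
# A minimal usage sketch, not part of the original file: the feature widths
# below are hypothetical; the only requirement is that input_size equals the
# concatenated width of the two drug feature vectors plus the cell-line vector.
if __name__ == "__main__":
    drug_dim, cell_dim, hidden = 256, 384, 512  # assumed dimensions, for illustration only
    model = MLP(input_size=2 * drug_dim + cell_dim, hidden_size=hidden)
    model.eval()  # use running stats in BatchNorm1d; training mode needs batch size > 1

    drug1 = torch.randn(4, drug_dim)
    drug2 = torch.randn(4, drug_dim)
    cell = torch.randn(4, cell_dim)

    scores = model(drug1, drug2, cell)  # shape: (4, 1), one synergy score per sample
    print(scores.shape)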