You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import os
  5. import sys
  6. PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
  7. sys.path.insert(0, PROJ_DIR)
  8. from drug.models import GCN
  9. from drug.datasets import DDInteractionDataset
  10. class Connector(nn.Module):
  11. def __init__(self):
  12. super(Connector, self).__init__()
  13. self.ddiDataset = DDInteractionDataset()
  14. self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2)
  15. #Cell line features
  16. # np.load('cell_feat.npy')
  17. def forward(self, drug1_idx, drug2_idx, cell_feat):
  18. x = self.ddiDataset.get().x
  19. edge_index = self.ddiDataset.get().edge_index
  20. x = self.gcn(x, edge_index)
  21. drug1_feat = x[drug1_idx]
  22. drug2_feat = x[drug2_idx]
  23. feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
  24. return feat
  25. class MLP(nn.Module):
  26. def __init__(self, input_size: int, hidden_size: int):
  27. super(MLP, self).__init__()
  28. self.layers = nn.Sequential(
  29. nn.Linear(input_size, hidden_size),
  30. nn.ReLU(),
  31. nn.BatchNorm1d(hidden_size),
  32. nn.Linear(hidden_size, hidden_size // 2),
  33. nn.ReLU(),
  34. nn.BatchNorm1d(hidden_size // 2),
  35. nn.Linear(hidden_size // 2, 1)
  36. )
  37. self.connector = Connector()
  38. def forward(self, drug1_idx, drug2_idx, cell_feat): # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor
  39. feat = self.connector(drug1_idx, drug2_idx, cell_feat)
  40. out = self.layers(feat)
  41. return out
  42. # other PRODeepSyn models have been deleted for now