123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import os
- import sys
-
-
# Project root: the directory two levels above the directory containing this
# file (``abspath`` collapses the ``..`` to the parent directory, and the outer
# ``dirname`` steps up once more).
PROJ_DIR = os.path.dirname(
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)

# Put the project root first on the import path so the ``drug`` package
# imported right after this statement resolves from this checkout.
sys.path.insert(0, PROJ_DIR)
- from drug.models import GCN
- from drug.datasets import DDInteractionDataset
-
-
class Connector(nn.Module):
    """Assemble the per-sample feature vector fed to the prediction head.

    A ``DDInteractionDataset`` is created at construction time and a ``GCN``
    is run over its graph on every forward pass. The resulting rows for the
    two requested drugs are concatenated with the caller-supplied cell-line
    features.
    """

    def __init__(self):
        super().__init__()

        self.ddiDataset = DDInteractionDataset()
        # Second GCN argument is half the dataset's feature count
        # (presumably the embedding width — confirm against drug.models.GCN).
        self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2)

        # TODO: load the cell line features here, e.g. np.load('cell_feat.npy')

    def forward(self, drug1_idx, drug2_idx, cell_feat):
        """Return ``[drug1 embedding | drug2 embedding | cell_feat]`` per sample.

        Args:
            drug1_idx: Index (or batch of indices) selecting the first drug's
                row in the GCN output.
            drug2_idx: Index (or batch of indices) selecting the second drug's
                row in the GCN output.
            cell_feat: Cell-line features; must be concatenable with the drug
                embeddings along dimension 1.

        Returns:
            The three feature groups concatenated along dimension 1.
        """
        # Fetch the graph once so node features and edges are guaranteed to
        # come from the same object (and ``get()`` is not paid for twice).
        # NOTE(review): nothing here moves the graph tensors to the module's
        # device — confirm the dataset handles that when running on GPU.
        graph = self.ddiDataset.get()
        node_embeddings = self.gcn(graph.x, graph.edge_index)

        drug1_feat = node_embeddings[drug1_idx]
        drug2_feat = node_embeddings[drug2_idx]
        return torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
-
-
class MLP(nn.Module):
    """Three-layer perceptron head on top of a feature connector.

    The connector turns ``(drug1_idx, drug2_idx, cell_feat)`` into one feature
    tensor; the stack then maps ``input_size -> hidden_size ->
    hidden_size // 2 -> 1``, with ReLU followed by batch normalisation after
    each of the first two linear layers.

    Args:
        input_size: Width of the feature tensor produced by the connector.
        hidden_size: Width of the first hidden layer; the second hidden layer
            uses ``hidden_size // 2``.
        connector: Optional callable/module with the signature
            ``(drug1_idx, drug2_idx, cell_feat) -> features``. When omitted a
            fresh ``Connector()`` is built, which is the original behaviour.
    """

    def __init__(self, input_size: int, hidden_size: int, connector=None):
        super().__init__()
        half_hidden = hidden_size // 2
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, half_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(half_hidden),
            nn.Linear(half_hidden, 1),
        )

        # Registered after ``layers`` so parameter/state-dict ordering matches
        # the original definition.
        self.connector = Connector() if connector is None else connector

    def forward(self, drug1_idx, drug2_idx, cell_feat):
        """Build features through the connector and return the head's output.

        Args:
            drug1_idx: Forwarded unchanged to the connector.
            drug2_idx: Forwarded unchanged to the connector.
            cell_feat: Forwarded unchanged to the connector.

        Returns:
            Output of the layer stack; its last dimension has size 1.
        """
        # Earlier revisions took drug1_feat / drug2_feat tensors directly;
        # the connector now derives them from the indices.
        feat = self.connector(drug1_idx, drug2_idx, cell_feat)
        return self.layers(feat)
-
-
# NOTE: the other PRODeepSyn models have been removed from this module for now.
|