"""Drug-pair feature connector and MLP prediction head.

``Connector`` runs a GCN over the drug-drug interaction graph, looks up the
embeddings of two drugs and concatenates them with externally supplied
cell-line features.  ``MLP`` feeds that concatenated vector through a small
fully connected network that emits one value per sample.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import sys

# The project packages imported below are resolved relative to PROJ_DIR, so
# the path entry has to be inserted *before* those imports run.
PROJ_DIR = os.path.dirname(
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
sys.path.insert(0, PROJ_DIR)

from drug.models import GCN
from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index


def _override_with_fingerprints(drug_feat, drug_idx):
    """Overwrite, in place, the rows of ``drug_feat`` whose index is negative.

    For every position ``i`` with ``drug_idx[i] < 0`` the row ``drug_feat[i]``
    is replaced by ``get_FP_by_negative_index(drug_idx[i])``.
    NOTE(review): presumably a negative index marks a drug that is absent from
    the interaction graph and the helper returns its fingerprint vector with
    the same width as a GCN embedding -- confirm against ``model.utils``.

    Args:
        drug_feat: 2-D tensor of per-sample drug features; modified in place.
        drug_idx: 1-D tensor of drug indices aligned with ``drug_feat`` rows.

    Returns:
        ``drug_feat`` itself, for convenience.
    """
    # Build the mask once and visit only the negative positions instead of
    # comparing every element of the batch one by one in Python.
    negative_rows = torch.nonzero(drug_idx < 0).flatten().tolist()
    for row in negative_rows:
        drug_feat[row] = get_FP_by_negative_index(drug_idx[row])
    return drug_feat


class Connector(nn.Module):
    """Build the joint (drug 1, drug 2, cell line) feature vector."""

    def __init__(self, gpu_id=None):
        """Create the interaction dataset and the GCN that embeds its nodes.

        Args:
            gpu_id: forwarded unchanged to ``GCN`` as its third argument.
        """
        super(Connector, self).__init__()
        self.ddiDataset = DDInteractionDataset()
        self.gcn = GCN(
            self.ddiDataset.num_features,
            self.ddiDataset.num_features // 2,
            gpu_id,
        )

        # Cell line features
        # np.load('cell_feat.npy')

    def forward(self, drug1_idx, drug2_idx, cell_feat):
        """Return ``[drug1_feat, drug2_feat, cell_feat]`` concatenated on dim 1.

        Args:
            drug1_idx: tensor of indices of the first drug; flattened to 1-D
                before use.  Negative entries are gathered like any other
                index and then overwritten via ``get_FP_by_negative_index``.
            drug2_idx: same as ``drug1_idx`` for the second drug.
            cell_feat: 2-D tensor of cell-line features with one row per
                sample (it is concatenated along dim 1 with the drug rows).

        Returns:
            2-D tensor holding, per sample, both drug feature rows followed by
            the cell-line features.
        """
        # Fetch the graph once and read both fields from the same object.
        graph = self.ddiDataset.get()
        node_emb = self.gcn(graph.x, graph.edge_index)

        drug1_idx = torch.flatten(drug1_idx)
        drug2_idx = torch.flatten(drug2_idx)

        # Indexing with an index tensor yields new tensors, so the in-place
        # overrides below do not write back into ``node_emb``.
        drug1_feat = _override_with_fingerprints(node_emb[drug1_idx], drug1_idx)
        drug2_feat = _override_with_fingerprints(node_emb[drug2_idx], drug2_idx)

        return torch.cat([drug1_feat, drug2_feat, cell_feat], 1)


class MLP(nn.Module):
    """Three-layer perceptron on top of ``Connector`` features."""

    def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
        """Build the MLP stack and its feature connector.

        Args:
            input_size: width of the concatenated feature vector produced by
                ``Connector.forward`` (it feeds the first ``nn.Linear``).
            hidden_size: width of the first hidden layer; the second hidden
                layer uses ``hidden_size // 2``.
            gpu_id: forwarded unchanged to ``Connector``.
        """
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1),
        )
        self.connector = Connector(gpu_id)

    def forward(self, drug1_idx, drug2_idx, cell_feat):
        """Return the network output for each (drug1, drug2, cell line) sample.

        The arguments are passed straight to ``Connector.forward``; the result
        has a last dimension of 1, as produced by the final ``nn.Linear``.
        """
        # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor
        feat = self.connector(drug1_idx, drug2_idx, cell_feat)
        return self.layers(feat)

# other PRODeepSyn models have been deleted for now