import os
import sys
import time

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, PROJ_DIR)

from drug.models import GCN
from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index, get_FP_by_negative_indices
from const import Drug2FP_FILE


class Connector(nn.Module):
    """Builds per-sample feature vectors for (drug1, drug2, cell line) triples.

    Non-negative drug indices are global node ids of the DDI subgraph: their
    features are rows of the GCN output. Negative drug indices fall back to a
    fingerprint lookup table loaded from ``Drug2FP_FILE``.
    """

    def __init__(self, gpu_id=None):
        super(Connector, self).__init__()
        self.gpu_id = gpu_id
        # Built lazily on the first forward() call, once the subgraph's
        # feature dimensionality is known.
        # NOTE(review): a submodule created inside forward() is not registered
        # when the optimizer is constructed — confirm its parameters are
        # actually being trained.
        self.gcn = None
        # Fingerprint table indexed by negative drug ids (via model.utils helpers).
        self.drug2FP_df = pd.read_csv(Drug2FP_FILE)

    def _drug_features(self, drug_idx, x, node_indices):
        """Assemble one feature row per entry of ``drug_idx``.

        ``drug_idx`` holds global node ids (>= 0, looked up in ``x`` through
        ``node_indices``) or negative ids (resolved from the fingerprint table).
        Returns a CPU tensor of shape (len(drug_idx), x.size(1)).
        """
        feat = torch.empty((len(drug_idx), x.size(1)))

        in_graph = (drug_idx >= 0)
        # Map each global id to its local row in x: node_indices[r] == id  =>  row r.
        rows = (node_indices.unsqueeze(-1) == drug_idx[in_graph]).nonzero(as_tuple=True)[0]
        feat[in_graph] = x[rows]

        out_of_graph = ~in_graph
        if out_of_graph.any():
            feat[out_of_graph] = get_FP_by_negative_indices(
                drug_idx[out_of_graph], self.drug2FP_df)
        return feat

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Return torch.cat([drug1_feat, drug2_feat, cell_feat], dim=1)."""
        if self.gcn is None:
            self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)

        # GCN over the batch subgraph; row i of x is the embedding of node n_id[i].
        x = self.gcn(subgraph.x, subgraph.edge_index)

        # Global node ids of the subgraph's nodes.
        node_indices = subgraph.n_id
        if self.gpu_id is not None:
            node_indices = node_indices.cuda(self.gpu_id)

        drug1_idx = torch.flatten(drug1_idx)
        drug2_idx = torch.flatten(drug2_idx)

        # Same lookup path for both drugs (the original handled drug2's
        # negative indices only in a slow per-element loop).
        drug1_feat = self._drug_features(drug1_idx, x, node_indices)
        drug2_feat = self._drug_features(drug2_idx, x, node_indices)

        if self.gpu_id is not None:
            drug1_feat = drug1_feat.cuda(self.gpu_id)
            drug2_feat = drug2_feat.cuda(self.gpu_id)

        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        return feat


class MLP(nn.Module):
    """Synergy head: Connector features -> 3 linear layers -> scalar score."""

    def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1),
        )
        self.connector = Connector(gpu_id)

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Build features via the Connector, then score them with the MLP."""
        feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
        out = self.layers(feat)
        return out

# other PRODeepSyn models have been deleted for now