123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import os
- import sys
- import pandas as pd
-
# Root directory that gets prepended to ``sys.path`` so the project-local
# imports that follow (``drug``, ``model``, ``const``) can be found.
# It is ``dirname`` of the normalised ``<this file's dir>/..`` path, i.e. the
# grandparent of the directory containing this file.
# NOTE(review): if this file lives inside ``model/``, this may be one level
# above the directory that holds ``model/`` -- confirm the intended layout.
PROJ_DIR = os.path.dirname(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
)
sys.path.insert(0, PROJ_DIR)
- from drug.models import GCN
- from drug.datasets import DDInteractionDataset
- from model.utils import get_FP_by_negative_index
- from const import Drug2FP_FILE
-
-
-
class Connector(nn.Module):
    """Build the per-sample feature matrix fed to the MLP head.

    The result is ``cat([drug1_feat, drug2_feat, cell_feat], dim=1)``:

    * a non-negative drug index is a global node id; its row is the GCN
      embedding of the subgraph node whose ``subgraph.n_id`` equals that id;
    * a negative drug index is resolved with ``get_FP_by_negative_index``
      against the table loaded from ``Drug2FP_FILE``
      (assumes the returned fingerprint has the GCN output width -- TODO confirm).
    """

    def __init__(self, gpu_id=None):
        # Initialise nn.Module first: assigning submodules/parameters before
        # it has run raises, so keep all attribute assignment after it.
        super().__init__()
        self.gpu_id = gpu_id
        # Built lazily in forward(), once the input feature width is known.
        self.gcn = None
        self.drug2FP_df = pd.read_csv(Drug2FP_FILE)

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Return the concatenated drug-pair + cell-line features.

        Args:
            drug1_idx, drug2_idx: tensors of drug indices; they are flattened,
                so one index per sample in any shape is accepted.
            cell_feat: 2-D tensor whose rows align with the flattened indices.
            subgraph: graph batch exposing ``x``, ``edge_index``, ``n_id``
                (global id of each local node) and ``num_features``.

        Raises:
            RuntimeError: a non-negative drug index is not in ``subgraph.n_id``.
        """
        x = subgraph.x
        if self.gcn is None:
            # Create the GCN on the input's device so a model that was already
            # moved to the GPU does not end up with a CPU-resident GCN.
            # NOTE(review): parameters created here do not exist yet if an
            # optimizer is built before the first forward pass -- confirm the
            # training loop accounts for that.
            self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2).to(x.device)
        x = self.gcn(x, subgraph.edge_index)

        node_ids = subgraph.n_id
        if self.gpu_id is not None:
            node_ids = node_ids.cuda(self.gpu_id)

        drug1_feat = self._drug_features(torch.flatten(drug1_idx), x, node_ids)
        drug2_feat = self._drug_features(torch.flatten(drug2_idx), x, node_ids)
        return torch.cat([drug1_feat, drug2_feat, cell_feat], 1)

    def _drug_features(self, drug_idx, x, node_ids):
        """Gather one feature row per entry of the 1-D index tensor ``drug_idx``."""
        # Same dtype as the embeddings; width taken from x's shape so an empty
        # x does not break the allocation.
        feat = x.new_empty((len(drug_idx), x.size(1)))
        if self.gpu_id is not None:
            feat = feat.cuda(self.gpu_id)
        for row, idx in enumerate(drug_idx):
            if idx < 0:
                # Negative ids never index x (that could go out of range);
                # they come from the fingerprint table instead.
                feat[row] = get_FP_by_negative_index(idx, self.drug2FP_df)
                continue
            matches = (node_ids == idx).nonzero(as_tuple=True)[0]
            if matches.numel() == 0:
                raise RuntimeError(
                    f"drug index {int(idx)} is not a node of the subgraph (n_id)"
                )
            feat[row] = x[matches[0]]
        return feat
-
-
class MLP(nn.Module):
    """Predict one scalar per sample from the features built by ``Connector``.

    The head is two ``Linear -> ReLU -> BatchNorm1d`` stages (widths
    ``hidden_size`` and ``hidden_size // 2``) followed by a ``Linear`` to 1.
    ``gpu_id`` is forwarded to the ``Connector`` only.
    """

    def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
        super().__init__()
        half_size = hidden_size // 2
        # Module order and attribute names form the state_dict keys
        # ("layers.0.weight", ...); keep them stable for saved checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, half_size),
            nn.ReLU(),
            nn.BatchNorm1d(half_size),
            nn.Linear(half_size, 1),
        )
        self.connector = Connector(gpu_id)

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Build the pair features for this batch and score them with the head."""
        return self.layers(
            self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
        )
-
-
- # other PRODeepSyn models have been deleted for now
|