|
|
@@ -4,13 +4,13 @@ import torch.nn.functional as F |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
import time |
|
|
|
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))) |
|
|
|
|
|
|
|
sys.path.insert(0, PROJ_DIR) |
|
|
|
from drug.models import GCN |
|
|
|
from drug.datasets import DDInteractionDataset |
|
|
|
from model.utils import get_FP_by_negative_index |
|
|
|
from model.utils import get_FP_by_negative_index, get_FP_by_negative_indices |
|
|
|
from const import Drug2FP_FILE |
|
|
|
|
|
|
|
|
|
|
@@ -55,25 +55,48 @@ class Connector(nn.Module): |
|
|
|
#drug2_feat = x[drug2_idx] |
|
|
|
drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) |
|
|
|
drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) |
|
|
|
for index, element in enumerate(drug1_idx): |
|
|
|
x_element = element |
|
|
|
if element >= 0: |
|
|
|
x_element = (node_indices == element).nonzero().squeeze() |
|
|
|
drug1_feat[index] = (x[x_element]) |
|
|
|
for index, element in enumerate(drug2_idx): |
|
|
|
x_element = element |
|
|
|
if element >= 0: |
|
|
|
x_element = (node_indices == element).nonzero().squeeze() |
|
|
|
drug2_feat[index] = (x[x_element]) |
|
|
|
|
|
|
|
print("x shape: ", x.size()) |
|
|
|
print("node_indices: ", node_indices.size()) |
|
|
|
|
|
|
|
start_time = time.time() |
|
|
|
# for index, element in enumerate(drug1_idx): |
|
|
|
# x_element = element |
|
|
|
# if element >= 0: |
|
|
|
# x_element = (node_indices == element).nonzero().squeeze() |
|
|
|
# drug1_feat[index] = (x[x_element]) |
|
|
|
# for index, element in enumerate(drug2_idx): |
|
|
|
# x_element = element |
|
|
|
# if element >= 0: |
|
|
|
# x_element = (node_indices == element).nonzero().squeeze() |
|
|
|
# drug2_feat[index] = (x[x_element]) |
|
|
|
|
|
|
|
mask_positive = (drug1_idx >= 0) |
|
|
|
x_elements_positive = (node_indices.unsqueeze(-1) == drug1_idx[mask_positive]).nonzero(as_tuple=True)[0] |
|
|
|
drug1_feat[mask_positive] = x[x_elements_positive] |
|
|
|
|
|
|
|
mask_negative = ~mask_positive |
|
|
|
drug1_feat[mask_negative] = get_FP_by_negative_indices(drug1_idx[mask_negative], self.drug2FP_df) |
|
|
|
|
|
|
|
mask_positive = (drug2_idx >= 0) |
|
|
|
x_elements_positive = (node_indices.unsqueeze(-1) == drug2_idx[mask_positive]).nonzero(as_tuple=True)[0] |
|
|
|
drug2_feat[mask_positive] = x[x_elements_positive] |
|
|
|
|
|
|
|
if self.gpu_id is not None: |
|
|
|
drug1_feat = drug1_feat.cuda(self.gpu_id) |
|
|
|
drug2_feat = drug2_feat.cuda(self.gpu_id) |
|
|
|
|
|
|
|
print("first: ", time.time() - start_time) |
|
|
|
start_time = time.time() |
|
|
|
|
|
|
|
for i, x in enumerate(drug1_idx): |
|
|
|
if x < 0: |
|
|
|
drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) |
|
|
|
for i, x in enumerate(drug2_idx): |
|
|
|
if x < 0: |
|
|
|
drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) |
|
|
|
print("second: ", time.time() - start_time) |
|
|
|
|
|
|
|
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) |
|
|
|
return feat |
|
|
|
|
|
|
@@ -95,7 +118,9 @@ class MLP(nn.Module): |
|
|
|
|
|
|
|
# prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch |
|
|
|
def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph): |
|
|
|
start = time.time() |
|
|
|
feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph) |
|
|
|
print("Connector forward time: ", time.time() - start) |
|
|
|
out = self.layers(feat) |
|
|
|
return out |
|
|
|
|