@@ -20,7 +20,7 @@ DRUGN2ID_FILE = os.path.join(DRUG_DATA_DIR, 'DDI\DrugBank\\raw\drug2id.tsv') | |||
DRUGNAME_2_DRUGBANKID_FILE = os.path.join(DRUG_DATA_DIR, 'drugname2drugbankid.tsv') | |||
CELL_FEAT_FILE = os.path.join(CELL_DATA_DIR, 'cell_feat.npy') | |||
CELL2ID_FILE = os.path.join(CELL_DATA_DIR, 'cell2id.tsv') | |||
Drug2FP_FILE = os.path.join(PROJ_DIR, 'drug/data/drug2FP_synergy.csv') | |||
@@ -3,6 +3,7 @@ import torch.nn as nn | |||
import torch.nn.functional as F | |||
import os | |||
import sys | |||
import pandas as pd | |||
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))) | |||
@@ -10,6 +11,7 @@ sys.path.insert(0, PROJ_DIR) | |||
from drug.models import GCN | |||
from drug.datasets import DDInteractionDataset | |||
from model.utils import get_FP_by_negative_index | |||
from const import Drug2FP_FILE | |||
@@ -19,6 +21,7 @@ class Connector(nn.Module): | |||
super(Connector, self).__init__() | |||
# self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id) | |||
self.gcn = None | |||
self.drug2FP_df = pd.read_csv(Drug2FP_FILE) | |||
#Cell line features | |||
# np.load('cell_feat.npy') | |||
@@ -67,10 +70,10 @@ class Connector(nn.Module): | |||
drug2_feat = drug2_feat.cuda(self.gpu_id) | |||
for i, x in enumerate(drug1_idx): | |||
if x < 0: | |||
drug1_feat[i] = get_FP_by_negative_index(x) | |||
drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) | |||
for i, x in enumerate(drug2_idx): | |||
if x < 0: | |||
drug2_feat[i] = get_FP_by_negative_index(x) | |||
drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) | |||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||
return feat | |||
@@ -5,7 +5,7 @@ import random | |||
import pandas as pd | |||
import numpy as np | |||
from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE | |||
from const import DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE, Drug2FP_FILE | |||
def calc_stat(numbers): | |||
mu = sum(numbers) / len(numbers) | |||
@@ -107,13 +107,13 @@ def get_index_by_name(drug_name, files_dict = None): | |||
return drug_index | |||
def get_FP_by_negative_index(index): | |||
def get_FP_by_negative_index(index, drug2FP_df = None): | |||
index = index.item() | |||
project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |||
drug2FP_file = os.path.join(project_path, 'drug/data/drug2FP_synergy.csv') | |||
drug2FP_df = pd.read_csv(drug2FP_file) | |||
if drug2FP_df.empty: | |||
print("load Drug2FP_FILE") | |||
drug2FP_df = pd.read_csv(Drug2FP_FILE) | |||
array = np.array(list(drug2FP_df.iloc[-index])[1:]) | |||
return torch.tensor(array, dtype=torch.float32) |