| @@ -20,7 +20,7 @@ DRUGN2ID_FILE = os.path.join(DRUG_DATA_DIR, 'DDI\DrugBank\\raw\drug2id.tsv') | |||
| DRUGNAME_2_DRUGBANKID_FILE = os.path.join(DRUG_DATA_DIR, 'drugname2drugbankid.tsv') | |||
| CELL_FEAT_FILE = os.path.join(CELL_DATA_DIR, 'cell_feat.npy') | |||
| CELL2ID_FILE = os.path.join(CELL_DATA_DIR, 'cell2id.tsv') | |||
| Drug2FP_FILE = os.path.join(PROJ_DIR, 'drug/data/drug2FP_synergy.csv') | |||
| @@ -3,6 +3,7 @@ import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import os | |||
| import sys | |||
| import pandas as pd | |||
| PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))) | |||
| @@ -10,6 +11,7 @@ sys.path.insert(0, PROJ_DIR) | |||
| from drug.models import GCN | |||
| from drug.datasets import DDInteractionDataset | |||
| from model.utils import get_FP_by_negative_index | |||
| from const import Drug2FP_FILE | |||
| @@ -19,6 +21,7 @@ class Connector(nn.Module): | |||
| super(Connector, self).__init__() | |||
| # self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id) | |||
| self.gcn = None | |||
| self.drug2FP_df = pd.read_csv(Drug2FP_FILE) | |||
| #Cell line features | |||
| # np.load('cell_feat.npy') | |||
| @@ -67,10 +70,10 @@ class Connector(nn.Module): | |||
| drug2_feat = drug2_feat.cuda(self.gpu_id) | |||
| for i, x in enumerate(drug1_idx): | |||
| if x < 0: | |||
| drug1_feat[i] = get_FP_by_negative_index(x) | |||
| drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) | |||
| for i, x in enumerate(drug2_idx): | |||
| if x < 0: | |||
| drug2_feat[i] = get_FP_by_negative_index(x) | |||
| drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) | |||
| feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||
| return feat | |||
| @@ -5,7 +5,7 @@ import random | |||
| import pandas as pd | |||
| import numpy as np | |||
| from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE | |||
| from const import DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE, Drug2FP_FILE | |||
| def calc_stat(numbers): | |||
| mu = sum(numbers) / len(numbers) | |||
| @@ -107,13 +107,13 @@ def get_index_by_name(drug_name, files_dict = None): | |||
| return drug_index | |||
| def get_FP_by_negative_index(index): | |||
| def get_FP_by_negative_index(index, drug2FP_df = None): | |||
| index = index.item() | |||
| project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |||
| drug2FP_file = os.path.join(project_path, 'drug/data/drug2FP_synergy.csv') | |||
| drug2FP_df = pd.read_csv(drug2FP_file) | |||
| if drug2FP_df.empty: | |||
| print("load Drug2FP_FILE") | |||
| drug2FP_df = pd.read_csv(Drug2FP_FILE) | |||
| array = np.array(list(drug2FP_df.iloc[-index])[1:]) | |||
| return torch.tensor(array, dtype=torch.float32) | |||