Browse Source

reduce epoch time by loading file outside of the function call

main
MahsaYazdani 1 year ago
parent
commit
7ae5c9d665
3 changed files with 11 additions and 8 deletions
  1. 1
    1
      predictor/const.py
  2. 5
    2
      predictor/model/models.py
  3. 5
    5
      predictor/model/utils.py

+ 1
- 1
predictor/const.py View File

@@ -20,7 +20,7 @@ DRUGN2ID_FILE = os.path.join(DRUG_DATA_DIR, 'DDI\DrugBank\\raw\drug2id.tsv')
DRUGNAME_2_DRUGBANKID_FILE = os.path.join(DRUG_DATA_DIR, 'drugname2drugbankid.tsv')
CELL_FEAT_FILE = os.path.join(CELL_DATA_DIR, 'cell_feat.npy')
CELL2ID_FILE = os.path.join(CELL_DATA_DIR, 'cell2id.tsv')
Drug2FP_FILE = os.path.join(PROJ_DIR, 'drug/data/drug2FP_synergy.csv')




+ 5
- 2
predictor/model/models.py View File

@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import os
import sys
import pandas as pd

PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))

@@ -10,6 +11,7 @@ sys.path.insert(0, PROJ_DIR)
from drug.models import GCN
from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index
from const import Drug2FP_FILE



@@ -19,6 +21,7 @@ class Connector(nn.Module):
super(Connector, self).__init__()
# self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
self.gcn = None
self.drug2FP_df = pd.read_csv(Drug2FP_FILE)
#Cell line features
# np.load('cell_feat.npy')
@@ -67,10 +70,10 @@ class Connector(nn.Module):
drug2_feat = drug2_feat.cuda(self.gpu_id)
for i, x in enumerate(drug1_idx):
if x < 0:
drug1_feat[i] = get_FP_by_negative_index(x)
drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
for i, x in enumerate(drug2_idx):
if x < 0:
drug2_feat[i] = get_FP_by_negative_index(x)
drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
return feat


+ 5
- 5
predictor/model/utils.py View File

@@ -5,7 +5,7 @@ import random
import pandas as pd
import numpy as np

from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE
from const import DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE, Drug2FP_FILE

def calc_stat(numbers):
mu = sum(numbers) / len(numbers)
@@ -107,13 +107,13 @@ def get_index_by_name(drug_name, files_dict = None):

return drug_index

def get_FP_by_negative_index(index):
def get_FP_by_negative_index(index, drug2FP_df = None):

index = index.item()

project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
drug2FP_file = os.path.join(project_path, 'drug/data/drug2FP_synergy.csv')
drug2FP_df = pd.read_csv(drug2FP_file)
if drug2FP_df.empty:
print("load Drug2FP_FILE")
drug2FP_df = pd.read_csv(Drug2FP_FILE)

array = np.array(list(drug2FP_df.iloc[-index])[1:])
return torch.tensor(array, dtype=torch.float32)

Loading…
Cancel
Save