Browse Source

Reduce the time needed to build FastSynergyDataset

main
MahsaYazdani 1 year ago
parent
commit
1f6e67e818
2 changed files with 17 additions and 10 deletions
  1. 5
    3
      predictor/model/datasets.py
  2. 12
    7
      predictor/model/utils.py

+ 5
- 3
predictor/model/datasets.py View File

@@ -4,7 +4,7 @@ import torch
import random

from torch.utils.data import Dataset
from .utils import read_map, get_index_by_name
from .utils import read_map, read_files, get_index_by_name

class FastTensorDataLoader:
"""
@@ -65,14 +65,16 @@ class FastSynergyDataset(Dataset):
self.train = train
valid_drugs = set(self.drug2id.keys())
valid_cells = set(self.cell2id.keys())
files_dict = read_files()

with open(synergy_score_file, 'r') as f:
f.readline()
for line in f:
drug1, drug2, cellname, score, fold = line.rstrip().split('\t')
if drug1 in valid_drugs and drug2 in valid_drugs and cellname in valid_cells:
if int(fold) in use_folds:
drug1_id = get_index_by_name(drug1)
drug2_id = get_index_by_name(drug2)
drug1_id = get_index_by_name(drug1,files_dict)
drug2_id = get_index_by_name(drug2,files_dict)
sample = [
# TODO: specify drug_feat
# drug1_feat + drug2_feat + cell_feat + score

+ 12
- 7
predictor/model/utils.py View File

@@ -5,6 +5,8 @@ import random
import pandas as pd
import numpy as np

from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE

def calc_stat(numbers):
mu = sum(numbers) / len(numbers)
sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
@@ -80,20 +82,23 @@ def random_split_indices(n_samples, train_rate: float = None, test_rate: float =
return train_indices, test_indices

# --------------------------------------------------------------------- Our Part:
def read_files():
    """Load the drug lookup tables once so callers can reuse them.

    Returns:
        dict: two pandas DataFrames keyed by name:
            ``'drug_name2drugbank_id_df'`` -- read from
            ``DRUGNAME_2_DRUGBANKID_FILE`` using any run of whitespace as
            the column separator;
            ``'drug2id_df'`` -- read from ``DRUGN2ID_FILE`` as a
            tab-separated file.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence and
    # triggers a DeprecationWarning/SyntaxWarning on recent Python versions.
    # The regex passed to pandas is unchanged.
    drug_name2drugbank_id_df = pd.read_csv(DRUGNAME_2_DRUGBANKID_FILE, sep=r'\s+')
    drug2id_df = pd.read_csv(DRUGN2ID_FILE, sep='\t')

    return {
        'drug_name2drugbank_id_df': drug_name2drugbank_id_df,
        'drug2id_df': drug2id_df,
    }

def get_index_by_name(drug_name):
project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
def get_index_by_name(drug_name, files_dict = None):
drugname2drugbankid_file = os.path.join(project_path, 'drug/data/drugname2drugbankid.tsv')
drug_name2drugbank_id_df = pd.read_csv(drugname2drugbankid_file, sep='\s+')
if files_dict == None:
files_dict = read_files()
drug_name2drugbank_id_df = files_dict['drug_name2drugbank_id_df']
drug2id_df = files_dict['drug2id_df']

drug_bank_id = drug_name2drugbank_id_df[drug_name2drugbank_id_df['drug_name'] == drug_name].drug_bank_id.item()

negative_index = list(drug_name2drugbank_id_df['drug_name']).index(drug_name)

drug2id_file = os.path.join(project_path, 'drug/data/DDI/DrugBank/raw/', 'drug2id.tsv')
drug2id_df = pd.read_csv(drug2id_file , sep='\t')
row = drug2id_df[drug2id_df['DrugBank_id'] == drug_bank_id]
if row.empty:
drug_index = - negative_index

Loading…
Cancel
Save