import os
import torch
import json
import random
import pandas as pd
import numpy as np

from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUGN2ID_FILE


def calc_stat(numbers):
    # mean and population standard deviation of a list of numbers
    mu = sum(numbers) / len(numbers)
    sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
    return mu, sigma


def conf_inv(mu, sigma, n):
    # 95% confidence interval; 2.776 is the two-sided t-value for 4 degrees
    # of freedom, i.e. the hard-coded constant assumes n = 5 observations
    delta = 2.776 * sigma / (n ** 0.5)
    return mu - delta, mu + delta


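# Illustrative sketch (not part of the original module): calc_stat and conf_inv
# combined over five made-up cross-validation scores, matching the n = 5
# assumption behind the hard-coded t-value above.
def _demo_conf_inv():
    fold_scores = [0.81, 0.79, 0.84, 0.80, 0.82]  # hypothetical CV scores
    mu, sigma = calc_stat(fold_scores)
    low, high = conf_inv(mu, sigma, len(fold_scores))
    print('mean={:.3f}, 95% CI=({:.3f}, {:.3f})'.format(mu, low, high))

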
def arg_min(lst):
    # return the smallest value in lst together with its position
    m = float('inf')
    idx = 0
    for i, v in enumerate(lst):
        if v < m:
            m = v
            idx = i
    return m, idx


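# Illustrative sketch (the losses are made up): arg_min can pick the epoch with
# the lowest validation loss, e.g. for early stopping.
def _demo_arg_min():
    val_losses = [0.52, 0.41, 0.39, 0.44]  # hypothetical per-epoch losses
    best_loss, best_epoch = arg_min(val_losses)
    print('best loss {:.2f} at epoch {}'.format(best_loss, best_epoch))

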
def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
    # save the current best checkpoint as <epoch>.pkl, then delete all but the
    # `keep` most recent numbered checkpoints
    save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
    torch.save(state_dict, save_to)
    model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
    epochs = [int(os.path.splitext(f)[0]) for f in model_files if os.path.splitext(f)[0].isdigit()]
    outdated = sorted(epochs, reverse=True)[keep:]
    for n in outdated:
        os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))


def find_best_model(model_dir: str):
    # the highest-numbered checkpoint is the most recently saved best model
    model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
    epochs = [int(os.path.splitext(f)[0]) for f in model_files if os.path.splitext(f)[0].isdigit()]
    best_epoch = max(epochs)
    return os.path.join(model_dir, '{}.pkl'.format(best_epoch))


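# Illustrative sketch of the intended checkpointing pattern (not the original
# training loop): keep the three most recent "best" checkpoints while training,
# then reload the newest one; `model` stands for a hypothetical torch.nn.Module
# and `model_dir` for an existing directory.
def _demo_checkpointing(model, model_dir, best_epoch):
    save_best_model(model.state_dict(), model_dir, best_epoch, keep=3)
    best_path = find_best_model(model_dir)
    model.load_state_dict(torch.load(best_path))
    return model

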
def save_args(args, save_to: str):
    # dump the parsed arguments (e.g. an argparse namespace) as JSON
    args_dict = args.__dict__
    with open(save_to, 'w') as f:
        json.dump(args_dict, f, indent=2)


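# Illustrative sketch (the argument parser and file name are hypothetical, and
# OUTPUT_DIR from const is assumed to be an existing directory).
def _demo_save_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-3)
    args = parser.parse_args([])
    save_args(args, os.path.join(OUTPUT_DIR, 'args.json'))

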
def read_map(map_file, keep_str=False):
    d = {}
    print(map_file)
    with open(map_file, 'r') as f:
        f.readline()  # skip the header line
        for line in f:
            k, v = line.rstrip().split()
            if keep_str:
                d[k] = v
            else:
                d[k] = int(v)
    return d


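# Illustrative sketch of the layout read_map expects: a header line followed by
# whitespace-separated key/value pairs. The temporary file exists only for the
# demonstration.
def _demo_read_map():
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.tsv', delete=False) as f:
        f.write('cell\tid\n')
        f.write('cellA\t0\n')
        f.write('cellB\t1\n')
        path = f.name
    mapping = read_map(path)  # {'cellA': 0, 'cellB': 1}
    os.remove(path)
    return mapping

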
def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
    # one of the two rates is required; if both are given, test_rate takes precedence
    if train_rate is not None and (train_rate < 0 or train_rate > 1):
        raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
    elif test_rate is not None:
        if test_rate < 0 or test_rate > 1:
            raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
        train_rate = 1 - test_rate
    elif train_rate is None and test_rate is None:
        raise ValueError("Either train_rate or test_rate should be given.")
    evidence = list(range(n_samples))
    train_size = int(len(evidence) * train_rate)
    random.shuffle(evidence)
    train_indices = evidence[:train_size]
    test_indices = evidence[train_size:]
    return train_indices, test_indices

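
# Illustrative sketch (the sample count is made up): a 90/10 split of 100
# samples by index; the two index lists can then slice any aligned data.
def _demo_random_split():
    train_idx, test_idx = random_split_indices(100, test_rate=0.1)
    assert len(train_idx) == 90 and len(test_idx) == 10
    return train_idx, test_idx

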
# --------------------------------------------------------------------- Our Part:
def read_files():
    drug_name2drugbank_id_df = pd.read_csv(DRUGNAME_2_DRUGBANKID_FILE, sep=r'\s+')
    drug2id_df = pd.read_csv(DRUGN2ID_FILE, sep='\t')

    return {'drug_name2drugbank_id_df': drug_name2drugbank_id_df, 'drug2id_df': drug2id_df}


def get_index_by_name(drug_name, files_dict=None):
    if files_dict is None:
        files_dict = read_files()
    drug_name2drugbank_id_df = files_dict['drug_name2drugbank_id_df']
    drug2id_df = files_dict['drug2id_df']

    drug_bank_id = drug_name2drugbank_id_df[drug_name2drugbank_id_df['drug_name'] == drug_name].drug_bank_id.item()

    # position of the drug in the name-to-DrugBank-id table, used (negated) as a
    # fallback index when the drug has no node in drug2id_df
    negative_index = list(drug_name2drugbank_id_df['drug_name']).index(drug_name)

    row = drug2id_df[drug2id_df['DrugBank_id'] == drug_bank_id]
    if row.empty:
        drug_index = -negative_index
    else:
        drug_index = row.node_index.item()

    return drug_index


def get_FP_by_negative_index(index):
    # `index` is the non-positive scalar tensor produced by get_index_by_name;
    # negating it recovers the row position in the fingerprint table
    index = index.item()

    project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    drug2FP_file = os.path.join(project_path, 'drug/data/drug2FP_synergy.csv')
    drug2FP_df = pd.read_csv(drug2FP_file)

    # drop the first (identifier) column and keep only the fingerprint values
    array = np.array(list(drug2FP_df.iloc[-index])[1:])
    return torch.tensor(array, dtype=torch.float32)
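

# Illustrative sketch of how the helpers above chain together (the caller must
# supply a drug_name that exists in DRUGNAME_2_DRUGBANKID_FILE): a non-positive
# index means the drug has no node id, so its fingerprint row is returned instead.
def _demo_drug_lookup(drug_name):
    files_dict = read_files()
    idx = get_index_by_name(drug_name, files_dict)
    if idx <= 0:
        return get_FP_by_negative_index(torch.tensor(idx))
    return idx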