You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 3.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import os
  2. import torch
  3. import json
  4. import random
  5. import pandas as pd
  6. import numpy as np
  7. from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DRUG2ID_FILE
  8. project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
  9. drug2FP_file = os.path.join(project_path, 'drug/data/drug2FP_synergy.csv')
  10. #drug2FP_df = pd.read_csv(drug2FP_file)
  11. def calc_stat(numbers):
  12. mu = sum(numbers) / len(numbers)
  13. sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
  14. return mu, sigma
  15. def conf_inv(mu, sigma, n):
  16. delta = 2.776 * sigma / (n ** 0.5) # 95%
  17. return mu - delta, mu + delta
  18. def arg_min(lst):
  19. m = float('inf')
  20. idx = 0
  21. for i, v in enumerate(lst):
  22. if v < m:
  23. m = v
  24. idx = i
  25. return m, idx
  26. def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
  27. save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  28. torch.save(state_dict, save_to)
  29. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  30. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  31. outdated = sorted(epochs, reverse=True)[keep:]
  32. for n in outdated:
  33. os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))
  34. def find_best_model(model_dir: str):
  35. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  36. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  37. best_epoch = max(epochs)
  38. return os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  39. def save_args(args, save_to: str):
  40. args_dict = args.__dict__
  41. with open(save_to, 'w') as f:
  42. json.dump(args_dict, f, indent=2)
  43. def read_map(map_file, keep_str = False):
  44. d = {}
  45. print(map_file)
  46. with open(map_file, 'r') as f:
  47. f.readline()
  48. for line in f:
  49. k, v = line.rstrip().split("\t")
  50. if keep_str:
  51. d[k] = v
  52. else:
  53. d[k] = int(v)
  54. return d
  55. def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
  56. if train_rate is not None and (train_rate < 0 or train_rate > 1):
  57. raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
  58. elif test_rate is not None:
  59. if test_rate < 0 or test_rate > 1:
  60. raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
  61. train_rate = 1 - test_rate
  62. elif train_rate is None and test_rate is None:
  63. raise ValueError("Either train_rate or test_rate should be given.")
  64. evidence = list(range(n_samples))
  65. train_size = int(len(evidence) * train_rate)
  66. random.shuffle(evidence)
  67. train_indices = evidence[:train_size]
  68. test_indices = evidence[train_size:]
  69. return train_indices, test_indices
  70. # --------------------------------------------------------------------- Our Part:
  71. def read_files():
  72. #drug_name2drugbank_id_df = pd.read_csv(DRUGNAME_2_DRUGBANKID_FILE, sep='\s+')
  73. drug2id_df = pd.read_csv(DRUG2ID_FILE , sep='\t')
  74. return {'drug2id_df': drug2id_df}
  75. def get_index_by_name(drug_name, files_dict = None):
  76. if files_dict == None:
  77. files_dict = read_files()
  78. drug2id_df = files_dict['drug2id_df']
  79. row = drug2id_df[drug2id_df['drug_name'] == drug_name]
  80. drug_index = row.id.item()
  81. return drug_index
  82. def get_FP_by_negative_index(index):
  83. index = index.item()
  84. #array = np.array(list(drug2FP_df.iloc[-index])[1:])
  85. #return torch.tensor(array, dtype=torch.float32)
  86. def get_DTI_data(dti_feat_file):
  87. df = pd.read_csv(dti_feat_file)
  88. DTI = {}
  89. for row in df.iterrows():
  90. DTI[row[1][0]] = torch.FloatTensor(row[1][1:])
  91. return DTI
  92. def get_FP_data(df_feat_file):
  93. df = pd.read_csv(df_feat_file)
  94. DTI = {}
  95. for row in df.iterrows():
  96. DTI[row[1][0]] = torch.FloatTensor(row[1][1:])
  97. return DTI