You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 3.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import os
  2. import torch
  3. import json
  4. import random
  5. import pandas as pd
  6. def calc_stat(numbers):
  7. mu = sum(numbers) / len(numbers)
  8. sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
  9. return mu, sigma
  10. def conf_inv(mu, sigma, n):
  11. delta = 2.776 * sigma / (n ** 0.5) # 95%
  12. return mu - delta, mu + delta
  13. def arg_min(lst):
  14. m = float('inf')
  15. idx = 0
  16. for i, v in enumerate(lst):
  17. if v < m:
  18. m = v
  19. idx = i
  20. return m, idx
  21. def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
  22. save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  23. torch.save(state_dict, save_to)
  24. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  25. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  26. outdated = sorted(epochs, reverse=True)[keep:]
  27. for n in outdated:
  28. os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))
  29. def find_best_model(model_dir: str):
  30. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  31. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  32. best_epoch = max(epochs)
  33. return os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  34. def save_args(args, save_to: str):
  35. args_dict = args.__dict__
  36. with open(save_to, 'w') as f:
  37. json.dump(args_dict, f, indent=2)
  38. def read_map(map_file, keep_str = False):
  39. d = {}
  40. print(map_file)
  41. with open(map_file, 'r') as f:
  42. f.readline()
  43. for line in f:
  44. k, v = line.rstrip().split()
  45. if keep_str:
  46. d[k] = v
  47. else:
  48. d[k] = int(v)
  49. return d
  50. def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
  51. if train_rate is not None and (train_rate < 0 or train_rate > 1):
  52. raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
  53. elif test_rate is not None:
  54. if test_rate < 0 or test_rate > 1:
  55. raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
  56. train_rate = 1 - test_rate
  57. elif train_rate is None and test_rate is None:
  58. raise ValueError("Either train_rate or test_rate should be given.")
  59. evidence = list(range(n_samples))
  60. train_size = int(len(evidence) * train_rate)
  61. random.shuffle(evidence)
  62. train_indices = evidence[:train_size]
  63. test_indices = evidence[train_size:]
  64. return train_indices, test_indices
  65. # --------------------------------------------------------------------- Our Part:
  66. def get_index_by_name(drug_name):
  67. project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
  68. drugname2drugbankid_file = os.path.join(project_path, 'drug/data/drugname2drugbankid.tsv')
  69. drug_name2drugbank_id_df = pd.read_csv(drugname2drugbankid_file, sep='\s+')
  70. drug_bank_id = drug_name2drugbank_id_df[drug_name2drugbank_id_df['drug_name'] == drug_name].drug_bank_id.item()
  71. drug2id_file = os.path.join(project_path, 'drug/data/DDI/DrugBank/raw/', 'drug2id.tsv')
  72. drug2id_df = pd.read_csv(drug2id_file , sep='\t')
  73. row = drug2id_df[drug2id_df['DrugBank_id'] == drug_bank_id]
  74. if row.empty:
  75. drug_index = -1
  76. else:
  77. drug_index = row.node_index.item()
  78. return drug_index