You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import torch
  3. import json
  4. import random
  5. import pandas as pd
  6. def calc_stat(numbers):
  7. mu = sum(numbers) / len(numbers)
  8. sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
  9. return mu, sigma
  10. def conf_inv(mu, sigma, n):
  11. delta = 2.776 * sigma / (n ** 0.5) # 95%
  12. return mu - delta, mu + delta
  13. def arg_min(lst):
  14. m = float('inf')
  15. idx = 0
  16. for i, v in enumerate(lst):
  17. if v < m:
  18. m = v
  19. idx = i
  20. return m, idx
  21. def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
  22. save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  23. torch.save(state_dict, save_to)
  24. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  25. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  26. outdated = sorted(epochs, reverse=True)[keep:]
  27. for n in outdated:
  28. os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))
  29. def find_best_model(model_dir: str):
  30. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  31. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  32. best_epoch = max(epochs)
  33. return os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  34. def save_args(args, save_to: str):
  35. args_dict = args.__dict__
  36. with open(save_to, 'w') as f:
  37. json.dump(args_dict, f, indent=2)
  38. def read_map(map_file):
  39. d = {}
  40. with open(map_file, 'r') as f:
  41. f.readline()
  42. for line in f:
  43. k, v = line.rstrip().split('\t')
  44. d[k] = int(v)
  45. return d
  46. def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
  47. if train_rate is not None and (train_rate < 0 or train_rate > 1):
  48. raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
  49. elif test_rate is not None:
  50. if test_rate < 0 or test_rate > 1:
  51. raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
  52. train_rate = 1 - test_rate
  53. elif train_rate is None and test_rate is None:
  54. raise ValueError("Either train_rate or test_rate should be given.")
  55. evidence = list(range(n_samples))
  56. train_size = int(len(evidence) * train_rate)
  57. random.shuffle(evidence)
  58. train_indices = evidence[:train_size]
  59. test_indices = evidence[train_size:]
  60. return train_indices, test_indices
  61. # --------------------------------------------------------------------- Our Part:
  62. def get_index_by_name(drug_name):
  63. project_path = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
  64. drugname2drugbankid_file = os.path.join(project_path, 'drug/data/drugname2drugbankid.tsv')
  65. drug_name2drugbank_id_df = pd.read_csv(drugname2drugbankid_file , sep='\t')
  66. drug_bank_id = drug_name2drugbank_id_df[drug_name2drugbank_id_df['drug_name'] == drug_name].drug_bank_id.item()
  67. drug2id_file = os.path.join(project_path, 'drug/data/DDI/DrugBank/raw/', 'drug2id.tsv')
  68. drug2id_df = pd.read_csv(drug2id_file , sep='\t')
  69. drug_index = drug2id_df[drug2id_df['DrugBank_id'] == drug_bank_id].node_index.item()
  70. return drug_index