You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import os.path as osp
  2. import pandas as pd
  3. import torch
  4. from torch_geometric.data import Dataset, Data
  5. import numpy as np
  6. import random
  7. import os
  8. class DDInteractionDataset(Dataset):
  9. def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None, gpu_id=None):
  10. self.gpu_id = gpu_id
  11. super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "/drug/data/", transform, pre_transform, pre_filter)
  12. @property
  13. def num_features(self):
  14. return self._num_features
  15. @num_features.setter
  16. def num_features(self, value):
  17. self._num_features = value
  18. @property
  19. def raw_file_names(self):
  20. return ['drug_interactions.tsv']
  21. @property
  22. def processed_file_names(self):
  23. return ['ddi_processed.pt']
  24. @property
  25. def raw_dir(self):
  26. dir = osp.join(self.root, 'DDI/DrugBank/raw')
  27. return dir
  28. @property
  29. def processed_dir(self):
  30. name = 'processed'
  31. return osp.join(self.root, 'DDI/DrugBank/' + name)
  32. def download(self):
  33. pass
  34. def find_drugBank_id(self, index):
  35. path = osp.join(self.root, 'DDI/DrugBank/raw/' + 'drug2id.tsv')
  36. drug2id_df = pd.read_csv(path, sep='\t')
  37. drugBankID = drug2id_df['DrugBank_id'][index]
  38. return drugBankID
  39. def generate_rand_fp(self):
  40. number = random.getrandbits(256)
  41. # Convert the number to binary
  42. binary_string = '{0:0256b}'.format(number)
  43. random_fp = [x for x in binary_string]
  44. random_fp = list(map(int, random_fp))
  45. return random_fp
  46. def read_node_features(self, num_nodes):
  47. drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
  48. drug_fp_df = pd.read_csv(drug_fp_path)
  49. node_features = list()
  50. node_ids = list()
  51. for i in range(num_nodes):
  52. drugbankid = self.find_drugBank_id(i)
  53. fp = drug_fp_df.loc[drug_fp_df['DrugBank_id'] == drugbankid]
  54. if fp.empty:
  55. fp = self.generate_rand_fp()
  56. else:
  57. fp = list(fp.to_numpy()[0,1:])
  58. node_features.append(fp)
  59. node_ids.append(drugbankid)
  60. self.num_features = len(node_features[0])
  61. return node_ids, node_features
  62. def process(self):
  63. path = osp.join(self.raw_dir, self.raw_file_names[0])
  64. ddi = pd.read_csv(path , sep='\t')
  65. edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long)
  66. num_nodes = ddi['drug1_idx'].max() + 1
  67. node_ids, node_features = self.read_node_features(num_nodes)
  68. node_features = torch.tensor(node_features, dtype=torch.int)
  69. print("node features nrow and ncol: ",len(node_features),len(node_features[0]))
  70. # ---------------------------------------------------------------
  71. data = Data(x = node_features, edge_index = edge_index)
  72. if self.gpu_id is not None:
  73. data = data.cuda(self.gpu_id)
  74. if self.pre_filter is not None and not self.pre_filter(data):
  75. pass
  76. if self.pre_transform is not None:
  77. data = self.pre_transform(data)
  78. torch.save(data, osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  79. def len(self):
  80. return len(self.processed_file_names)
  81. def get(self):
  82. data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  83. return data
  84. # Manual smoke check: build the dataset, then print its edges, node features and feature width.
  85. # ddiDataset = DDInteractionDataset(root = "drug/data/")
  86. # print(ddiDataset.get().edge_index.t())
  87. # print(ddiDataset.get().x)
  88. # print(ddiDataset.num_features)