You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import os.path as osp
  2. import pandas as pd
  3. import torch
  4. from torch_geometric.data import Dataset, Data
  5. from torch_geometric.loader import DataLoader
  6. from torch_geometric.loader import NeighborLoader
  7. import numpy as np
  8. import random
  9. import os
  10. class DDInteractionDataset(Dataset):
  11. def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None, gpu_id=None):
  12. self.gpu_id = gpu_id
  13. super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "/drug/data/", transform, pre_transform, pre_filter)
  14. @property
  15. def num_features(self):
  16. return self._num_features
  17. @num_features.setter
  18. def num_features(self, value):
  19. self._num_features = value
  20. @property
  21. def raw_file_names(self):
  22. return ['drug_interactions.tsv']
  23. @property
  24. def processed_file_names(self):
  25. return ['ddi_graph_dataset.pt']
  26. @property
  27. def raw_dir(self):
  28. dir = osp.join(self.root, 'DDI/DrugBank/raw')
  29. return dir
  30. @property
  31. def processed_dir(self):
  32. name = 'processed'
  33. return osp.join(self.root, 'DDI/DrugBank/' + name)
  34. def download(self):
  35. pass
  36. def find_drugBank_id(self, index):
  37. path = osp.join(self.root, 'DDI/DrugBank/raw/' + 'drug2id.tsv')
  38. drug2id_df = pd.read_csv(path, sep='\t')
  39. drugBankID = drug2id_df['DrugBank_id'][index]
  40. return drugBankID
  41. def generate_rand_fp(self):
  42. number = random.getrandbits(256)
  43. # Convert the number to binary
  44. binary_string = '{0:0256b}'.format(number)
  45. random_fp = [x for x in binary_string]
  46. random_fp = list(map(int, random_fp))
  47. return random_fp
  48. def read_node_features(self, num_nodes):
  49. drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
  50. drug_fp_df = pd.read_csv(drug_fp_path)
  51. node_features = list()
  52. node_ids = list()
  53. for i in range(num_nodes):
  54. drugbankid = self.find_drugBank_id(i)
  55. fp = drug_fp_df.loc[drug_fp_df['DrugBank_id'] == drugbankid]
  56. if fp.empty:
  57. fp = self.generate_rand_fp()
  58. else:
  59. fp = list(fp.to_numpy()[0,1:])
  60. node_features.append(fp)
  61. node_ids.append(drugbankid)
  62. self.num_features = len(node_features[0])
  63. return node_ids, node_features
  64. def process(self):
  65. path = osp.join(self.raw_dir, self.raw_file_names[0])
  66. ddi = pd.read_csv(path , sep='\t')
  67. edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long)
  68. num_nodes = ddi['drug1_idx'].max() + 1
  69. node_ids, node_features = self.read_node_features(num_nodes)
  70. node_features = torch.tensor(node_features, dtype=torch.int)
  71. print("node features nrow and ncol: ",len(node_features),len(node_features[0]))
  72. # ---------------------------------------------------------------
  73. data = Data(x = node_features, edge_index = edge_index)
  74. if self.gpu_id is not None:
  75. data = data.cuda(self.gpu_id)
  76. if self.pre_filter is not None and not self.pre_filter(data):
  77. pass
  78. if self.pre_transform is not None:
  79. data = self.pre_transform(data)
  80. torch.save(data, osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  81. def len(self):
  82. return len(self.processed_file_names)
  83. def get(self):
  84. data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  85. return data
  86. # run for checking
  87. ddiDataset = DDInteractionDataset(root = "drug/data/")
  88. print(ddiDataset.get()) # type of ddiDataset.get() is torch_geometric.data.data.Data
  89. # print(ddiDataset.get().edge_index.t())
  90. # print(ddiDataset.get().x)
  91. # print(ddiDataset.num_features)
  92. # test for data batch loading
  93. # dataloader = DataLoader(ddiDataset, batch_size=10)
  94. # for data in dataloader:
  95. # print(data)
  96. # # true batching way for the knowledge graph
  97. # data = ddiDataset.get()
  98. # print(len(data))
  99. # neighbor_loader = NeighborLoader(data,
  100. # num_neighbors=[3,2], batch_size=100,
  101. # directed=False, shuffle=True)
  102. # for batch in neighbor_loader:
  103. # print(batch.x, batch.edge_index)