You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 3.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import os.path as osp
  2. import pandas as pd
  3. import torch
  4. from torch_geometric.data import Dataset, Data
  5. import numpy as np
  6. import random
  7. import os
  8. class DDInteractionDataset(Dataset):
  9. def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None):
  10. super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "\\drug/data/", transform, pre_transform, pre_filter)
  11. @property
  12. def num_features(self):
  13. return self._num_features
  14. @num_features.setter
  15. def num_features(self, value):
  16. self._num_features = value
  17. @property
  18. def raw_file_names(self):
  19. return ['drug_interactions.tsv']
  20. @property
  21. def processed_file_names(self):
  22. return ['ddi_processed.pt']
  23. @property
  24. def raw_dir(self):
  25. dir = osp.join(self.root, 'DDI/DrugBank/raw')
  26. return dir
  27. @property
  28. def processed_dir(self):
  29. name = 'processed'
  30. return osp.join(self.root, 'DDI/DrugBank/' + name)
  31. def download(self):
  32. pass
  33. def find_drugBank_id(self, index):
  34. path = osp.join(self.root, 'DDI/DrugBank/raw/' + 'drug2id.tsv')
  35. drug2id_df = pd.read_csv(path, sep='\t')
  36. drugBankID = drug2id_df['DrugBank_id'][index]
  37. return drugBankID
  38. def generate_rand_fp(self):
  39. number = random.getrandbits(256)
  40. # Convert the number to binary
  41. binary_string = '{0:0256b}'.format(number)
  42. random_fp = [x for x in binary_string]
  43. random_fp = list(map(int, random_fp))
  44. return random_fp
  45. def read_node_features(self, num_nodes):
  46. drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
  47. drug_fp_df = pd.read_csv(drug_fp_path)
  48. node_features = list()
  49. node_ids = list()
  50. for i in range(num_nodes):
  51. drugbankid = self.find_drugBank_id(i)
  52. fp = drug_fp_df.loc[drug_fp_df['DrugBank_id'] == drugbankid]
  53. if fp.empty:
  54. fp = self.generate_rand_fp()
  55. else:
  56. fp = list(fp.to_numpy()[0,1:])
  57. node_features.append(fp)
  58. node_ids.append(drugbankid)
  59. self.num_features = len(node_features[0])
  60. return node_ids, node_features
  61. def process(self):
  62. path = osp.join(self.raw_dir, self.raw_file_names[0])
  63. ddi = pd.read_csv(path , sep='\t')
  64. edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long)
  65. num_nodes = ddi['drug1_idx'].max() + 1
  66. node_ids, node_features = self.read_node_features(num_nodes)
  67. node_features = torch.tensor(node_features, dtype=torch.int)
  68. print("node features nrow and ncol: ",len(node_features),len(node_features[0]))
  69. # ---------------------------------------------------------------
  70. data = Data(x = node_features, edge_index = edge_index)
  71. if self.pre_filter is not None and not self.pre_filter(data):
  72. pass
  73. if self.pre_transform is not None:
  74. data = self.pre_transform(data)
  75. torch.save(data, osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  76. def len(self):
  77. return len(self.processed_file_names)
  78. def get(self):
  79. data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  80. return data
  81. ddiDataset = DDInteractionDataset(root = "drug/data/")
  82. print(ddiDataset.get().edge_index.t())
  83. # print(ddiDataset.get().x)
  84. print(ddiDataset.num_features)