You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 3.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import os.path as osp
  2. import pandas as pd
  3. import torch
  4. from torch_geometric.data import Dataset, Data
  5. import numpy as np
  6. import random
  7. class DDInteractionDataset(Dataset):
  8. def __init__(self, root = "drug/data/", transform=None, pre_transform=None, pre_filter=None):
  9. super(DDInteractionDataset, self).__init__(root, transform, pre_transform, pre_filter)
  10. @property
  11. def num_features(self):
  12. return self._num_features
  13. @num_features.setter
  14. def num_features(self, value):
  15. self._num_features = value
  16. @property
  17. def raw_file_names(self):
  18. return ['drug_interactions.tsv']
  19. @property
  20. def processed_file_names(self):
  21. return ['ddi_processed.pt']
  22. @property
  23. def raw_dir(self):
  24. dir = osp.join(self.root, 'DDI/DrugBank/raw')
  25. return dir
  26. @property
  27. def processed_dir(self):
  28. name = 'processed'
  29. return osp.join(self.root, 'DDI/DrugBank/' + name)
  30. def download(self):
  31. pass
  32. def find_drugBank_id(self, index):
  33. path = osp.join(self.root, 'DDI/DrugBank/raw/' + 'drug2id.tsv')
  34. drug2id_df = pd.read_csv(path, sep='\t')
  35. drugBankID = drug2id_df['DrugBank_id'][index]
  36. return drugBankID
  37. def generate_rand_fp(self):
  38. number = random.getrandbits(256)
  39. # Convert the number to binary
  40. binary_string = '{0:0256b}'.format(number)
  41. random_fp = [x for x in binary_string]
  42. random_fp = list(map(int, random_fp))
  43. return random_fp
  44. def read_node_features(self, num_nodes):
  45. drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
  46. drug_fp_df = pd.read_csv(drug_fp_path)
  47. node_features = list()
  48. node_ids = list()
  49. for i in range(num_nodes):
  50. drugbankid = self.find_drugBank_id(i)
  51. fp = drug_fp_df.loc[drug_fp_df['DrugBank_id'] == drugbankid]
  52. if fp.empty:
  53. fp = self.generate_rand_fp()
  54. else:
  55. fp = list(fp.to_numpy()[0,1:])
  56. node_features.append(fp)
  57. node_ids.append(drugbankid)
  58. self.num_features = len(node_features[0])
  59. return node_ids, node_features
  60. def process(self):
  61. path = osp.join(self.raw_dir, self.raw_file_names[0])
  62. ddi = pd.read_csv(path , sep='\t')
  63. edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long)
  64. num_nodes = ddi['drug1_idx'].max() + 1
  65. node_ids, node_features = self.read_node_features(num_nodes)
  66. node_features = torch.tensor(node_features, dtype=torch.int)
  67. print("node features nrow and ncol: ",len(node_features),len(node_features[0]))
  68. # ---------------------------------------------------------------
  69. data = Data(x = node_features, edge_index = edge_index)
  70. if self.pre_filter is not None and not self.pre_filter(data):
  71. pass
  72. if self.pre_transform is not None:
  73. data = self.pre_transform(data)
  74. torch.save(data, osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  75. def len(self):
  76. return len(self.processed_file_names)
  77. def get(self):
  78. data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  79. return data
  80. ddiDataset = DDInteractionDataset(root = "drug/data/")
  81. print(ddiDataset.get().edge_index.t())
  82. # print(ddiDataset.get().x)
  83. print(ddiDataset.num_features)