123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import os.path as osp
- import pandas as pd
- import torch
- from torch_geometric.data import Dataset, Data
- import numpy as np
- import random
-
-
-
class DDInteractionDataset(Dataset):
    """Single-graph drug-drug interaction dataset built from DrugBank files.

    Nodes are drugs and edges are the pairs listed in
    ``drug_interactions.tsv``.  Node features come from
    ``RDkit extracted/drug2FP.csv``; a drug without a row in that table
    receives a random bit vector of the same width.

    Files read and written below ``root``::

        DDI/DrugBank/raw/drug_interactions.tsv    columns drug1_idx, drug2_idx
        DDI/DrugBank/raw/drug2id.tsv              column DrugBank_id
        RDkit extracted/drug2FP.csv               column DrugBank_id + features
        DDI/DrugBank/processed/ddi_graph_dataset.pt   written by process()
    """

    def __init__(self, root="drug/data/", transform=None, pre_transform=None, pre_filter=None):
        """Forward all arguments unchanged to the torch_geometric base class."""
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def num_features(self):
        """Width of the node feature matrix.

        ``process()`` records the value while building the graph.  When
        processing did not run in this session (the processed file was
        already on disk) the value is read from the stored graph instead of
        raising ``AttributeError``.
        """
        width = getattr(self, '_num_features', None)
        if width is None:
            width = int(self.get().x.size(1))
            self._num_features = width
        return width

    @num_features.setter
    def num_features(self, value):
        self._num_features = value

    @property
    def raw_file_names(self):
        """Raw files the base class expects to find in ``raw_dir``."""
        return ['drug_interactions.tsv']

    @property
    def processed_file_names(self):
        """Name of the file written by ``process()`` and read by ``get()``.

        This must be the file that is really written: the base class checks
        for it to decide whether ``process()`` has to run.  The previous
        value (``ddi_processed.pt``) was never created, so the graph was
        rebuilt on every instantiation.
        """
        return ['ddi_graph_dataset.pt']

    @property
    def raw_dir(self):
        """Directory holding the raw DrugBank files."""
        return osp.join(self.root, 'DDI/DrugBank/raw')

    @property
    def processed_dir(self):
        """Directory holding the processed graph file."""
        return osp.join(self.root, 'DDI/DrugBank/' + 'processed')

    def download(self):
        """Do nothing: the raw files must already be present on disk."""
        pass

    def _drug_id_table(self):
        """Return the ``drug2id.tsv`` table, reading it from disk only once."""
        table = getattr(self, '_drug2id_df', None)
        if table is None:
            table = pd.read_csv(osp.join(self.raw_dir, 'drug2id.tsv'), sep='\t')
            self._drug2id_df = table
        return table

    def _graph_path(self):
        """Return the full path of the processed graph file."""
        return osp.join(self.processed_dir, self.processed_file_names[0])

    def find_drugBank_id(self, index):
        """Return the ``DrugBank_id`` stored at row label ``index`` of ``drug2id.tsv``.

        The table is cached on the instance, so calling this once per node
        no longer re-reads the file each time.
        """
        return self._drug_id_table()['DrugBank_id'][index]

    def generate_rand_fp(self, n_bits=256):
        """Return a random fingerprint as a list of ``n_bits`` ints (0 or 1).

        Args:
            n_bits: Length of the fingerprint.  Defaults to 256, the width
                that used to be hard-coded.
        """
        if n_bits <= 0:
            return []
        number = random.getrandbits(n_bits)
        # Zero-pad the binary form so that exactly n_bits digits come out.
        return [int(bit) for bit in format(number, f'0{n_bits}b')]

    def read_node_features(self, num_nodes):
        """Collect the DrugBank id and feature vector of nodes ``0..num_nodes-1``.

        Also stores the feature width in ``self.num_features``.

        Args:
            num_nodes: Number of nodes; node ``i`` is looked up as row ``i``
                of ``drug2id.tsv``.

        Returns:
            ``(node_ids, node_features)``: two lists of length ``num_nodes``.
        """
        drug_fp_df = pd.read_csv(osp.join(self.root, 'RDkit extracted/drug2FP.csv'))
        fp_values = drug_fp_df.to_numpy()
        # Every column after the first is treated as a feature (the first
        # column is presumably the id column -- confirm against the CSV).
        n_bits = fp_values.shape[1] - 1

        # Map each id to the position of its first row so that the per-node
        # lookup is a dict access instead of a scan over the whole table.
        first_row = {}
        for pos, drug_id in enumerate(drug_fp_df['DrugBank_id']):
            if pd.notna(drug_id):
                first_row.setdefault(drug_id, pos)

        node_ids = []
        node_features = []
        for i in range(num_nodes):
            drug_id = self.find_drugBank_id(i)
            pos = first_row.get(drug_id)
            if pos is None:
                # No fingerprint available: use random bits of the same
                # width as the real rows so the feature matrix stays
                # rectangular.
                features = self.generate_rand_fp(n_bits)
            else:
                features = list(fp_values[pos, 1:])
            node_features.append(features)
            node_ids.append(drug_id)

        self.num_features = n_bits

        return node_ids, node_features

    def process(self):
        """Build the interaction graph and save it to the processed file.

        Raises:
            ValueError: If the interaction table contains no rows.
        """
        path = osp.join(self.raw_dir, self.raw_file_names[0])
        ddi = pd.read_csv(path, sep='\t')
        if ddi.empty:
            raise ValueError(f"no interactions found in {path}")

        endpoints = ddi[['drug1_idx', 'drug2_idx']].to_numpy()
        # One column per edge: row 0 holds drug1_idx, row 1 holds drug2_idx.
        edge_index = torch.tensor(endpoints.T, dtype=torch.long)
        # Take the largest index from BOTH columns; looking only at
        # drug1_idx leaves edge_index pointing at nodes without features
        # whenever the highest index occurs solely in drug2_idx.
        num_nodes = int(endpoints.max()) + 1

        node_ids, node_features = self.read_node_features(num_nodes)
        node_features = torch.tensor(node_features, dtype=torch.int)
        print("node features nrow and ncol: ", node_features.size(0), node_features.size(1))

        data = Data(x=node_features, edge_index=edge_index)

        # NOTE(review): pre_filter is evaluated but its verdict is ignored --
        # the single graph is saved either way.  Kept as is to preserve the
        # existing behaviour; decide whether a rejected graph should raise.
        if self.pre_filter is not None and not self.pre_filter(data):
            pass

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, self._graph_path())

    def len(self):
        """Return the number of graphs in the dataset (one per processed file)."""
        return len(self.processed_file_names)

    def get(self, idx=0):
        """Load and return the stored graph.

        Args:
            idx: Accepted so that the base class can call ``get(idx)`` (as
                it does for ``dataset[0]`` and iteration); ignored because
                only one graph is stored.  The default keeps ``get()``
                without arguments working.
        """
        return torch.load(self._graph_path())
-
-
# Build (or load) the dataset when the module runs, then print the graph's
# edges one per row (edge_index transposed) followed by the feature width.
ddiDataset = DDInteractionDataset(root="drug/data/")
ddi_graph = ddiDataset.get()
print(ddi_graph.edge_index.t())
# print(ddi_graph.x)  # uncomment to dump the node feature matrix
print(ddiDataset.num_features)
|