import os.path as osp
import random

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset


class DDInteractionDataset(Dataset):
    """Single-graph PyG dataset of drug-drug interactions built from DrugBank files.

    Files read under ``root``:

    * ``DDI/DrugBank/raw/drug_interactions.tsv`` -- edge list whose integer node
      indices are in the columns ``drug1_idx`` and ``drug2_idx``.
    * ``DDI/DrugBank/raw/drug2id.tsv`` -- has a ``DrugBank_id`` column, looked up
      by node index (row label).
    * ``RDkit extracted/drug2FP.csv`` -- has a ``DrugBank_id`` column; every column
      after the first is treated as the fingerprint (assumes ``DrugBank_id`` is
      the first column -- TODO confirm).

    Drugs with no fingerprint row get a random bit vector of the same width.
    The processed graph is cached at ``processed_paths[0]``
    (``DDI/DrugBank/processed/ddi_processed.pt``).
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def num_features(self):
        """Width of each node feature vector.

        Set by :meth:`read_node_features` during processing. When the graph was
        loaded from the on-disk cache instead, it is derived from the stored ``x``.
        """
        if getattr(self, '_num_features', None) is None:
            self._num_features = self.get().x.size(-1)
        return self._num_features

    @num_features.setter
    def num_features(self, value):
        self._num_features = value

    @property
    def raw_file_names(self):
        return ['drug_interactions.tsv']

    @property
    def processed_file_names(self):
        return ['ddi_processed.pt']

    @property
    def raw_dir(self):
        return osp.join(self.root, 'DDI', 'DrugBank', 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'DDI', 'DrugBank', 'processed')

    def download(self):
        # Raw files are expected to already be present under raw_dir.
        pass

    def _drug2id_table(self):
        """Return the ``drug2id.tsv`` table, reading it from disk only once."""
        table = getattr(self, '_drug2id_df', None)
        if table is None:
            table = pd.read_csv(osp.join(self.raw_dir, 'drug2id.tsv'), sep='\t')
            self._drug2id_df = table
        return table

    def find_drugBank_id(self, index):
        """Return the ``DrugBank_id`` stored under row label ``index`` in drug2id.tsv."""
        return self._drug2id_table()['DrugBank_id'][index]

    def generate_rand_fp(self, n_bits=256):
        """Return a random fingerprint as a list of exactly ``n_bits`` 0/1 ints.

        The bits come from the unseeded ``random`` module, so they differ between
        runs. The binary string is zero-padded so that every vector has the same
        length (``format(x, '0b')`` alone drops leading zeros).
        """
        if n_bits <= 0:
            return []
        bits = format(random.getrandbits(n_bits), '0{}b'.format(n_bits))
        return [int(bit) for bit in bits]

    def read_node_features(self, num_nodes):
        """Build one feature vector per node index ``0 .. num_nodes - 1``.

        Each node's DrugBank id is looked up in drug2id.tsv. The matching row of
        drug2FP.csv (first match, all columns after the first) is used when
        present; otherwise a random fingerprint of the same width is generated.
        Also sets :attr:`num_features`.

        Returns:
            list of lists, ``num_nodes`` rows of equal length.
        """
        drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
        drug_fp_df = pd.read_csv(drug_fp_path)
        fp_width = drug_fp_df.shape[1] - 1

        # One dict lookup per node instead of a full DataFrame scan per node.
        # keep='first' matches the original "first matching row" semantics.
        unique_fps = drug_fp_df.drop_duplicates(subset='DrugBank_id', keep='first')
        fp_by_id = dict(zip(unique_fps['DrugBank_id'],
                            unique_fps.iloc[:, 1:].to_numpy().tolist()))

        node_features = []
        for node_idx in range(num_nodes):
            fp = fp_by_id.get(self.find_drugBank_id(node_idx))
            if fp is None:
                # Random fallback uses the CSV width so all rows stack into one matrix.
                fp = self.generate_rand_fp(fp_width)
            node_features.append(fp)

        self.num_features = fp_width
        return node_features

    def process(self):
        """Build the interaction graph from the raw files and save it to processed_paths[0]."""
        path = osp.join(self.raw_dir, self.raw_file_names[0])
        ddi = pd.read_csv(path, sep='\t')
        src = ddi['drug1_idx'].to_numpy(dtype=np.int64)
        dst = ddi['drug2_idx'].to_numpy(dtype=np.int64)
        edge_index = torch.from_numpy(np.stack([src, dst]))

        # A drug may occur only as drug2, so the node count covers both endpoints.
        num_nodes = int(max(src.max(), dst.max())) + 1 if len(ddi) else 0

        node_features = self.read_node_features(num_nodes)
        # reshape keeps x 2-D (0 x width) even when there are no nodes.
        x = torch.from_numpy(
            np.asarray(node_features, dtype=np.float32).reshape(num_nodes, self.num_features)
        )
        print("node features nrow and ncol: ", x.size(0), x.size(1))

        data = Data(x=x, edge_index=edge_index)
        # pre_filter is not applied: this dataset holds a single graph, so there
        # is nothing to drop (the original check was a no-op as well).
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        # Save under the declared processed file name so PyG detects the cache.
        torch.save(data, self.processed_paths[0])

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx=0):
        """Load the single cached graph. ``idx`` is accepted for PyG's ``dataset[i]`` API."""
        path = self.processed_paths[0]
        try:
            # Data objects are not plain tensors; newer torch defaults to weights_only=True.
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch versions without the weights_only keyword.
            return torch.load(path)


# NOTE(review): this runs (and may trigger processing) at import time; consider
# moving it under ``if __name__ == "__main__":`` if nothing imports ddiDataset.
ddiDataset = DDInteractionDataset(root="Drug/Dataset/")
print(ddiDataset.get().edge_index.t())
print(ddiDataset.num_features)