|
|
@@ -3,6 +3,7 @@ import pandas as pd |
|
|
|
import torch |
|
|
|
from torch_geometric.data import Dataset, Data |
|
|
|
import numpy as np |
|
|
|
import random |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -20,21 +21,56 @@ class DDInteractionDataset(Dataset): |
|
|
|
|
|
|
|
@property
def raw_dir(self):
    """Return the directory that holds the raw DrugBank DDI files.

    Returns:
        The path ``<root>/DDI/DrugBank/raw``, built from ``self.root``.
    """
    # Only the 'DDI/DrugBank/raw' location is used; the earlier
    # '<root>/raw' assignment was dead code and has been dropped.
    return osp.join(self.root, 'DDI/DrugBank/raw')
|
|
|
|
|
|
|
@property
def processed_dir(self):
    """Return the directory where processed DrugBank DDI files live.

    Returns:
        The path ``<root>/DDI/DrugBank/processed``, i.e. the same
        ``DDI/DrugBank`` sub-folder that ``raw_dir`` uses.
    """
    name = 'processed'
    # A single return: the stale '<root>/processed' return used to come
    # first and made this one unreachable.
    return osp.join(self.root, 'DDI/DrugBank/' + name)
|
|
|
|
|
|
|
def download(self):
    """Intentionally a no-op: nothing is fetched by this dataset."""
    return None
|
|
|
|
|
|
|
def find_drugBank_id(self, index):
    """Return the DrugBank id stored at ``index`` in ``drug2id.tsv``.

    The mapping file ``<root>/DDI/DrugBank/raw/drug2id.tsv`` is read as a
    tab-separated table and the ``DrugBank_id`` column is indexed with
    ``index``.

    Args:
        index: Row label to look up in the ``DrugBank_id`` column.

    Returns:
        The value of ``DrugBank_id`` at ``index``.
    """
    # This method is called once per node, so parse the TSV only on the
    # first call and keep the table on the instance afterwards.
    drug2id_df = getattr(self, '_drug2id_df', None)
    if drug2id_df is None:
        path = osp.join(self.root, 'DDI/DrugBank/raw/' + 'drug2id.tsv')
        drug2id_df = pd.read_csv(path, sep='\t')
        self._drug2id_df = drug2id_df
    return drug2id_df['DrugBank_id'][index]
|
|
|
|
|
|
|
def generate_rand_fp(self, n_bits=256):
    """Return a random binary fingerprint as a list of 0/1 integers.

    Used as a stand-in when no stored fingerprint exists for a drug.

    Args:
        n_bits: Length of the fingerprint; must be positive.
            Defaults to 256.

    Returns:
        A list of exactly ``n_bits`` ints, each 0 or 1, most significant
        bit first.

    Raises:
        ValueError: If ``n_bits`` is not positive.
    """
    if n_bits <= 0:
        raise ValueError("n_bits must be a positive integer")
    number = random.getrandbits(n_bits)
    # Zero-pad to the full width: a plain 'b' format drops leading zeros,
    # which made the fingerprint shorter than n_bits about half the time.
    binary_string = format(number, '0{}b'.format(n_bits))
    return [int(bit) for bit in binary_string]
|
|
|
|
|
|
|
def read_node_features(self, num_nodes):
    """Build one feature vector per node from the stored fingerprints.

    Reads ``<root>/RDkit extracted/drug2FP.csv``. For each node index
    ``i`` in ``range(num_nodes)`` the DrugBank id is resolved with
    ``find_drugBank_id(i)`` and the first CSV row whose ``DrugBank_id``
    equals it is used. The feature vector is every column of that row
    except the first (presumably the id column -- confirm against the
    CSV layout). Nodes with no matching row get ``generate_rand_fp()``.

    Args:
        num_nodes: Number of nodes; indices ``0 .. num_nodes - 1`` are
            looked up.

    Returns:
        A list with ``num_nodes`` entries, each a list of feature values.
        Entries for drugs missing from the CSV are random, so the result
        is not deterministic unless ``random`` is seeded.
    """
    drug_fp_path = osp.join(self.root, 'RDkit extracted/drug2FP.csv')
    drug_fp_df = pd.read_csv(drug_fp_path)

    # Convert the table once and index ids by row position, instead of
    # running a full boolean-mask scan of the frame for every node.
    fp_table = drug_fp_df.to_numpy()
    first_row = {}
    for pos, drug_id in enumerate(drug_fp_df['DrugBank_id']):
        if pd.isna(drug_id):
            # A missing id never equals anything, so it can never match.
            continue
        # Keep the first occurrence, like taking row 0 of a filtered frame.
        first_row.setdefault(drug_id, pos)

    node_features = []
    for i in range(num_nodes):
        pos = first_row.get(self.find_drugBank_id(i))
        if pos is None:
            fp = self.generate_rand_fp()
        else:
            fp = list(fp_table[pos, 1:])
        node_features.append(fp)

    return node_features
|
|
|
|
|
|
|
def process(self): |
|
|
|
path = osp.join(self.raw_dir, self.raw_file_names[0]) |
|
|
|
ddi = pd.read_csv(path , sep='\t') |
|
|
|
edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long) |
|
|
|
num_nodes = ddi['drug1_idx'].max() + 1 |
|
|
|
node_features = self.read_node_features(num_nodes) |
|
|
|
# TODO: check the output of ncol (it is 254 but it must be 256) |
|
|
|
print("node features nrow and ncol: ",len(node_features),len(node_features[0])) |
|
|
|
|
|
|
|
# --------------------------------------------------------------- |
|
|
|
data = Data(edge_index = edge_index) |
|
|
@@ -56,5 +92,5 @@ class DDInteractionDataset(Dataset): |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
# The root is the top-level dataset folder: raw_dir and processed_dir append
# the 'DDI/DrugBank/...' sub-path themselves, so it must not be repeated here.
ddiDataset = DDInteractionDataset(root="Drug/Dataset/")

# Print the edge list with one (drug1, drug2) pair per row.
print(ddiDataset.get().edge_index.t())