"""PyTorch Geometric dataset holding a single drug-drug interaction graph."""

import os.path as osp
import warnings

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset


class DDInteractionDataset(Dataset):
    """Single-graph dataset built from a tab-separated drug-interaction table.

    The raw file ``<root>/raw/drug_interactions.tsv`` is read with pandas and
    its ``drug1_idx`` / ``drug2_idx`` columns become the rows of the graph's
    ``edge_index`` (one edge ``drug1_idx -> drug2_idx`` per table row).
    The processed graph is cached at ``<root>/processed/ddi_processed.pt``.

    Args:
        root: Dataset root directory containing ``raw/`` and ``processed/``.
        transform: Optional callable applied by PyG on item access.
        pre_transform: Optional callable applied once before saving.
        pre_filter: Optional predicate evaluated once before saving; see
            ``process`` for how a rejection is handled.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['drug_interactions.tsv']

    @property
    def processed_file_names(self):
        # Must match the path written by process() and read by get(); PyG
        # skips process() only when every file listed here already exists.
        return ['ddi_processed.pt']

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    def download(self):
        # No-op: the raw TSV must already be present under raw_dir, otherwise
        # process() fails when reading it.
        pass

    def process(self):
        """Read the raw TSV, build the interaction graph and save it to disk."""
        ddi = pd.read_csv(self.raw_paths[0], sep='\t')
        # One vectorised conversion instead of letting torch.tensor iterate
        # over two pandas Series element by element.
        pairs = ddi[['drug1_idx', 'drug2_idx']].to_numpy(dtype=np.int64)
        edge_index = torch.tensor(pairs.T, dtype=torch.long)  # shape (2, num_rows)

        data = Data(edge_index=edge_index)

        if self.pre_filter is not None and not self.pre_filter(data):
            # The dataset holds exactly one graph, so there is nothing to drop
            # it in favour of; keep saving it (as before) but say so loudly.
            warnings.warn(
                "pre_filter rejected the only graph in DDInteractionDataset; "
                "it is saved anyway.",
                stacklevel=2,
            )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, self.processed_paths[0])

    def len(self):
        # One processed file holds one graph.
        return len(self.processed_file_names)

    def get(self, idx=0):
        """Load the single processed graph.

        Args:
            idx: Graph index; only ``0`` is valid. Defaults to ``0`` so the
                argument-less ``get()`` call keeps working.

        Raises:
            IndexError: If ``idx`` is not ``0``.
        """
        if idx != 0:
            raise IndexError(f"DDInteractionDataset holds a single graph; got idx={idx}")
        # The file is a pickled Data object written by process() above.
        # torch >= 2.6 defaults to weights_only=True, which rejects it.
        # NOTE(review): the weights_only kwarg needs torch >= 1.13.
        return torch.load(self.processed_paths[0], weights_only=False)


# NOTE(review): this runs on import (builds/loads the dataset and prints).
# Consider moving it under `if __name__ == "__main__":` once no caller
# imports `ddiDataset` from this module.
ddiDataset = DDInteractionDataset(root="Drug/Dataset/DDI/DrugBank/")
print(ddiDataset.get().edge_index.t())