123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import os.path as osp
- import pandas as pd
- import torch
- from torch_geometric.data import Dataset, Data
- import numpy as np
-
-
-
class DDInteractionDataset(Dataset):
    """Drug-drug interaction (DDI) graph stored as a single PyG ``Data`` object.

    Reads ``<root>/raw/drug_interactions.tsv``, a tab-separated table with
    columns ``drug1_idx`` and ``drug2_idx`` (one interaction per row). These are
    presumably integer node ids; they are cast to ``torch.long``.

    ``process()`` builds ``edge_index`` of shape ``(2, num_rows)``. Row 0 comes
    from ``drug1_idx`` and row 1 from ``drug2_idx``, in file order. The graph is
    saved to ``<root>/processed/ddi_graph_dataset.pt``.

    The dataset holds exactly one graph: ``len(dataset) == 1``, and
    ``dataset[0]`` / ``dataset.get()`` return it.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['drug_interactions.tsv']

    @property
    def processed_file_names(self):
        # Must name the file that process() actually writes. PyG checks these
        # paths to decide whether process() needs to run. Previously this said
        # 'ddi_processed.pt', which was never written, so processing re-ran on
        # every construction.
        return ['ddi_graph_dataset.pt']

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    def download(self):
        # The raw TSV is never fetched automatically; it must be put into
        # raw_dir by hand. Kept as a no-op (not an error) so a dataset that
        # ships only the processed file still loads.
        pass

    def process(self):
        """Convert the raw interaction TSV into one graph and save it.

        Raises:
            ValueError: if ``pre_filter`` rejects the graph. It is the only
                graph, so nothing would be left to save.
        """
        ddi = pd.read_csv(self.raw_paths[0], sep='\t')

        # One vectorized conversion instead of element-wise iteration over two
        # Series. The result has shape (2, num_rows): drug1 row, then drug2 row.
        edge_index = torch.tensor(
            ddi[['drug1_idx', 'drug2_idx']].to_numpy().T, dtype=torch.long
        )

        data = Data(edge_index=edge_index)

        if self.pre_filter is not None and not self.pre_filter(data):
            raise ValueError(
                "pre_filter rejected the only graph of DDInteractionDataset; "
                "nothing would be left to save."
            )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, self.processed_paths[0])

    def len(self):
        # One graph per processed file, i.e. 1.
        return len(self.processed_file_names)

    def get(self, idx=0):
        """Load the processed graph from disk.

        Args:
            idx: Graph index. Only ``0`` exists. The default keeps the
                zero-argument ``get()`` call working, and the parameter lets
                PyG's ``dataset[idx]`` call ``get(idx)``.

        Raises:
            IndexError: if ``idx`` is not ``0``.
        """
        if idx != 0:
            raise IndexError(
                f"DDInteractionDataset holds a single graph; got index {idx}"
            )
        # The file is a pickled Data object written by process() above. torch
        # >= 2.6 defaults to weights_only=True, which refuses such objects.
        return torch.load(self.processed_paths[0], weights_only=False)
-
-
# Build the dataset rooted at the DrugBank DDI directory, then print its edges
# transposed to one (drug1_idx, drug2_idx) pair per row.
ddiDataset = DDInteractionDataset(
    root="Drug/Dataset/DDI/DrugBank/",
)
print(torch.t(ddiDataset.get().edge_index))
|