You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasets.py 1.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os.path as osp
  2. import pandas as pd
  3. import torch
  4. from torch_geometric.data import Dataset, Data
  5. import numpy as np
  6. class DDInteractionDataset(Dataset):
  7. def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
  8. super(DDInteractionDataset, self).__init__(root, transform, pre_transform, pre_filter)
  9. @property
  10. def raw_file_names(self):
  11. return ['drug_interactions.tsv']
  12. @property
  13. def processed_file_names(self):
  14. return ['ddi_processed.pt']
  15. @property
  16. def raw_dir(self):
  17. dir = osp.join(self.root, 'raw')
  18. return dir
  19. @property
  20. def processed_dir(self):
  21. name = 'processed'
  22. return osp.join(self.root, name)
  23. def download(self):
  24. pass
  25. def process(self):
  26. path = osp.join(self.raw_dir, self.raw_file_names[0])
  27. ddi = pd.read_csv(path , sep='\t')
  28. edge_index = torch.tensor([ddi['drug1_idx'],ddi['drug2_idx']], dtype=torch.long)
  29. # ---------------------------------------------------------------
  30. data = Data(edge_index = edge_index)
  31. if self.pre_filter is not None and not self.pre_filter(data):
  32. pass
  33. if self.pre_transform is not None:
  34. data = self.pre_transform(data)
  35. torch.save(data, osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  36. def len(self):
  37. return len(self.processed_file_names)
  38. def get(self):
  39. data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
  40. return data
  41. ddiDataset = DDInteractionDataset(root = "Drug/Dataset/DDI/DrugBank/")
  42. print(ddiDataset.get().edge_index.t())