|
|
@@ -90,7 +90,7 @@ class DDInteractionDataset(Dataset): |
|
|
|
|
|
|
|
# --------------------------------------------------------------- |
|
|
|
data = Data(x = node_features, edge_index = edge_index) |
|
|
|
|
|
|
|
|
|
|
|
if self.gpu_id is not None: |
|
|
|
data = data.cuda(self.gpu_id) |
|
|
|
|
|
|
@@ -110,8 +110,8 @@ class DDInteractionDataset(Dataset): |
|
|
|
data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt')) |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
ddiDataset = DDInteractionDataset(root = "drug/data/") |
|
|
|
print(ddiDataset.get().edge_index.t()) |
|
|
|
# run for checking |
|
|
|
# ddiDataset = DDInteractionDataset(root = "drug/data/") |
|
|
|
# print(ddiDataset.get().edge_index.t()) |
|
|
|
# print(ddiDataset.get().x) |
|
|
|
print(ddiDataset.num_features) |
|
|
|
# print(ddiDataset.num_features) |