|
|
@@ -4,12 +4,13 @@ import torch |
|
|
|
from torch_geometric.data import Dataset, Data |
|
|
|
import numpy as np |
|
|
|
import random |
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DDInteractionDataset(Dataset): |
|
|
|
def __init__(self, root = "drug/data/", transform=None, pre_transform=None, pre_filter=None): |
|
|
|
super(DDInteractionDataset, self).__init__(root, transform, pre_transform, pre_filter) |
|
|
|
def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None): |
|
|
|
super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "\\drug/data/", transform, pre_transform, pre_filter) |
|
|
|
|
|
|
|
@property |
|
|
|
def num_features(self): |