Browse Source

transfer DDIInteractionDataset's data to gpu if available

main
MahsaYazdani 1 year ago
parent
commit
90125371ad
2 changed files with 6 additions and 7 deletions
  1. 5
    5
      drug/datasets.py
  2. 1
    2
      predictor/model/models.py

+ 5
- 5
drug/datasets.py View File

@@ -90,7 +90,7 @@ class DDInteractionDataset(Dataset):

# ---------------------------------------------------------------
data = Data(x = node_features, edge_index = edge_index)
if self.gpu_id is not None:
data = data.cuda(self.gpu_id)

@@ -110,8 +110,8 @@ class DDInteractionDataset(Dataset):
data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
return data

ddiDataset = DDInteractionDataset(root = "drug/data/")
print(ddiDataset.get().edge_index.t())
# run for checking
# ddiDataset = DDInteractionDataset(root = "drug/data/")
# print(ddiDataset.get().edge_index.t())
# print(ddiDataset.get().x)
print(ddiDataset.num_features)
# print(ddiDataset.num_features)

+ 1
- 2
predictor/model/models.py View File

@@ -16,8 +16,7 @@ from model.utils import get_FP_by_negative_index
class Connector(nn.Module):
def __init__(self, gpu_id=None):
super(Connector, self).__init__()

self.ddiDataset = DDInteractionDataset(gpu_id)
self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2)
#Cell line features

Loading…
Cancel
Save