@@ -3,3 +3,6 @@ Drug/Dataset/DDI/DrugBank/drugbank_all_full_database.xml.zip | |||
Drug/Dataset/DDI/SNAP Stanford/ChCh-Miner_durgbank-chem-chem.tsv.gz | |||
Drug/Dataset/DDI/DrugBank/raw/Drugbank_drug_interactions.tsv | |||
Drug/Dataset/DDI/SNAP Stanford/ChCh-Miner_durgbank-chem-chem.tsv | |||
Cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv | |||
Cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv.gz | |||
Drug/Dataset/Smiles/drugbank_all_structure_links.csv.zip |
@@ -15,16 +15,16 @@ class GCN(torch.nn.Module): | |||
return x | |||
model = GCN(dataset.num_features, dataset.num_classes) | |||
model.train() | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |||
# model = GCN(dataset.num_features, dataset.num_classes) | |||
# model.train() | |||
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |||
print("Training on CPU.") | |||
# print("Training on CPU.") | |||
for epoch in range(1, 6): | |||
optimizer.zero_grad() | |||
out = model(data.x, data.edge_index, data.edge_attr) | |||
loss = F.cross_entropy(out, data.y) | |||
loss.backward() | |||
optimizer.step() | |||
print(f"Epoch: {epoch}, Loss: {loss}") | |||
# for epoch in range(1, 6): | |||
# optimizer.zero_grad() | |||
# out = model(data.x, data.edge_index, data.edge_attr) | |||
# loss = F.cross_entropy(out, data.y) | |||
# loss.backward() | |||
# optimizer.step() | |||
# print(f"Epoch: {epoch}, Loss: {loss}") |
@@ -0,0 +1,23 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
from models import GCN | |||
from datasets import DDInteractionDataset | |||
if __name__ == '__main__': | |||
model = GCN(dataset.num_features, dataset.num_classes) | |||
model.train() | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |||
# training on CPU | |||
for epoch in range(1, 6): | |||
optimizer.zero_grad() | |||
out = model(data.x, data.edge_index, data.edge_attr) | |||
loss = F.cross_entropy(out, data.y) | |||
loss.backward() | |||
optimizer.step() | |||
print(f"Epoch: {epoch}, Loss: {loss}") |