You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 643B

123456789101112131415161718192021222324
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. import torch.nn as nn
  5. from models import GCN
  6. from datasets import DDInteractionDataset
  7. if __name__ == '__main__':
  8. ddiDataset = DDInteractionDataset
  9. model = GCN(dataset.num_features, dataset.num_classes)
  10. model.train()
  11. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  12. # training on CPU
  13. for epoch in range(1, 6):
  14. optimizer.zero_grad()
  15. out = model(data.x, data.edge_index, data.edge_attr)
  16. loss = F.cross_entropy(out, data.y)
  17. loss.backward()
  18. optimizer.step()
  19. print(f"Epoch: {epoch}, Loss: {loss}")