You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 605B

1234567891011121314151617181920212223
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. import torch.nn as nn
  5. from models import GCN
  6. from datasets import DDInteractionDataset
  7. if __name__ == '__main__':
  8. model = GCN(dataset.num_features, dataset.num_classes)
  9. model.train()
  10. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  11. # training on CPU
  12. for epoch in range(1, 6):
  13. optimizer.zero_grad()
  14. out = model(data.x, data.edge_index, data.edge_attr)
  15. loss = F.cross_entropy(out, data.y)
  16. loss.backward()
  17. optimizer.step()
  18. print(f"Epoch: {epoch}, Loss: {loss}")