You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 742B

123456789101112131415161718192021222324252627
  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. import torch.nn as nn
  5. from models import GCN
  6. from datasets import DDInteractionDataset
  7. if __name__ == '__main__':
  8. ddiDataset = DDInteractionDataset
  9. model = GCN(ddiDataset.num_features, ddiDataset.num_features // 2)
  10. model.train()
  11. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  12. # training on CPU
  13. n_epochs = 6
  14. for epoch in range(1, n_epochs):
  15. optimizer.zero_grad()
  16. out = model(ddiDataset.get().x, ddiDataset.get().edge_index)
  17. # TODO: MSELoss of the synergy scores
  18. loss = F.cross_entropy(out, data.y)
  19. loss.backward()
  20. optimizer.step()
  21. print(f"Epoch: {epoch}, Loss: {loss}")