You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

models.py 879B

123456789101112131415161718192021222324252627282930
  1. import torch
  2. import torch.nn.functional as F
  3. from torch_geometric.nn import GCNConv
  4. class GCN(torch.nn.Module):
  5. def __init__(self, in_channels, out_channels):
  6. super().__init__()
  7. torch.manual_seed(1234)
  8. self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)
  9. def forward(self, x, edge_index, edge_weight=None):
  10. x = F.dropout(x, p=0.5, training=self.training)
  11. x = self.conv(x, edge_index, edge_weight).relu()
  12. return x
  13. model = GCN(dataset.num_features, dataset.num_classes)
  14. model.train()
  15. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  16. print("Training on CPU.")
  17. for epoch in range(1, 6):
  18. optimizer.zero_grad()
  19. out = model(data.x, data.edge_index, data.edge_attr)
  20. loss = F.cross_entropy(out, data.y)
  21. loss.backward()
  22. optimizer.step()
  23. print(f"Epoch: {epoch}, Loss: {loss}")