You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 2.1KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import torch
  2. import torch.nn.functional as F
  3. from torch_geometric.nn import GCNConv
  4. class GCN(torch.nn.Module):
  5. def __init__(self, in_channels, out_channels):
  6. super().__init__()
  7. torch.manual_seed(1234)
  8. self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)
  9. def forward(self, x, edge_index, edge_weight=None):
  10. x = F.dropout(x, p=0.5, training=self.training)
  11. x = self.conv(x, edge_index, edge_weight).relu()
  12. return x
  13. class GCN(torch.nn.Module):
  14. def __init__(self, hidden_channels):
  15. super(GCN, self).__init__()
  16. torch.manual_seed(42)
  17. # Initialize the layers
  18. self.conv1 = GCNConv(dataset.num_features, hidden_channels)
  19. self.conv2 = GCNConv(hidden_channels, hidden_channels)
  20. self.out = Linear(hidden_channels, dataset.num_classes)
  21. def forward(self, x, edge_index):
  22. # First Message Passing Layer (Transformation)
  23. x = self.conv1(x, edge_index)
  24. x = x.relu()
  25. x = F.dropout(x, p=0.5, training=self.training)
  26. # Second Message Passing Layer
  27. x = self.conv2(x, edge_index)
  28. x = x.relu()
  29. x = F.dropout(x, p=0.5, training=self.training)
  30. # Output layer
  31. x = F.softmax(self.out(x), dim=1)
  32. return x
  33. # address: https://www.kaggle.com/code/pinocookie/pytorch-simple-mlp/notebook
  34. class MLP(nn.Module):
  35. def __init__(self):
  36. super(MLP, self).__init__()
  37. self.layers = nn.Sequential(
  38. nn.Linear(784, 100),
  39. nn.ReLU(),
  40. nn.Linear(100, 10)
  41. )
  42. def forward(self, x):
  43. # convert tensor (128, 1, 28, 28) --> (128, 1*28*28)
  44. x = x.view(x.size(0), -1)
  45. x = self.layers(x)
  46. return x
  47. # model = GCN(dataset.num_features, dataset.num_classes)
  48. # model.train()
  49. # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  50. # print("Training on CPU.")
  51. # for epoch in range(1, 6):
  52. # optimizer.zero_grad()
  53. # out = model(data.x, data.edge_index, data.edge_attr)
  54. # loss = F.cross_entropy(out, data.y)
  55. # loss.backward()
  56. # optimizer.step()
  57. # print(f"Epoch: {epoch}, Loss: {loss}")