import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    """Single GCNConv layer with dropout applied to the input features."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        torch.manual_seed(1234)
        self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)

    def forward(self, x, edge_index, edge_weight=None):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv(x, edge_index, edge_weight).relu()
        return x
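

# A minimal smoke test for the single-layer GCN above. This is an illustrative
# sketch, not part of the original notebook: the 4-node toy graph, the feature
# sizes, and the function name are all assumptions.
def demo_single_layer_gcn():
    x = torch.rand(4, 3)                       # 4 nodes, 3 features each
    edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                               [1, 0, 3, 2]])  # target nodes
    model = GCN(in_channels=3, out_channels=2)
    out = model(x, edge_index)
    print(out.shape)  # torch.Size([4, 2]) -- one 2-dim embedding per node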


class GCNClassifier(torch.nn.Module):
    """Two GCNConv layers followed by a linear classification head."""

    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(42)

        # Initialize the layers (assumes a global `dataset` that exposes
        # num_features and num_classes, e.g. a Planetoid dataset)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.out = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        # First message passing layer (transformation)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)

        # Second message passing layer
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)

        # Output layer: return raw logits. F.cross_entropy applies log-softmax
        # internally, so applying softmax here would squash the loss gradients.
        x = self.out(x)
        return x
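

# Sketch of wiring the two-layer classifier above to a real dataset
# (illustrative assumption: a Planetoid benchmark such as Cora; this file
# never loads `dataset`, so the lines stay commented out):
# from torch_geometric.datasets import Planetoid
# dataset = Planetoid(root="data/Planetoid", name="Cora")
# data = dataset[0]
# model = GCNClassifier(hidden_channels=16)
# logits = model(data.x, data.edge_index)  # shape [num_nodes, num_classes]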


# Source: https://www.kaggle.com/code/pinocookie/pytorch-simple-mlp/notebook
class MLP(nn.Module):
    """Two-layer fully connected network for 28x28 inputs (e.g. MNIST)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        # flatten, e.g. (128, 1, 28, 28) -> (128, 784)
        x = x.view(x.size(0), -1)
        x = self.layers(x)
        return x
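

# Quick shape check for the MLP above (a sketch; the batch size of 128 and the
# MNIST-style input shape come from the comment in forward()):
def demo_mlp():
    dummy = torch.rand(128, 1, 28, 28)  # fake batch of MNIST-sized images
    logits = MLP()(dummy)               # flattened to (128, 784) internally
    print(logits.shape)                 # torch.Size([128, 10])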


# model = GCN(dataset.num_features, dataset.num_classes)
# model.train()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# print("Training on CPU.")

# for epoch in range(1, 6):
#     optimizer.zero_grad()
#     out = model(data.x, data.edge_index, data.edge_attr)
#     loss = F.cross_entropy(out, data.y)
#     loss.backward()
#     optimizer.step()
#     print(f"Epoch: {epoch}, Loss: {loss.item():.4f}")
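
# A matching evaluation step would look roughly like this (sketch only; it
# reuses the undefined `data` object assumed by the loop above):
# model.eval()
# with torch.no_grad():
#     pred = model(data.x, data.edge_index, data.edge_attr).argmax(dim=1)
#     acc = (pred == data.y).float().mean()
#     print(f"Accuracy: {acc.item():.4f}")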