You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

models.py 1.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch_geometric.nn import GCNConv
# A simpler single-layer GCN baseline (disabled; kept for reference):
# class GCN(torch.nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         torch.manual_seed(1234)
#         self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)
#
#     def forward(self, x, edge_index, edge_weight=None):
#         x = F.dropout(x, p=0.5, training=self.training)
#         x = self.conv(x, edge_index, edge_weight).relu()
#         return x
# Based on the model from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0
  class GCN(torch.nn.Module):
      """Two-layer graph convolutional network.

      Pipeline: GCNConv -> ReLU -> dropout -> GCNConv -> ReLU.

      By default the second layer maps back to ``num_features`` channels, so the
      output has the same width as the input node features (original behavior).

      Args:
          num_features: Number of input feature channels per node
              (e.g. ``dataset.num_features``).
          hidden_channels: Width of the hidden layer.
          gpu_id: If not ``None``, ``forward`` moves its inputs to this CUDA
              device. The module's own parameters are NOT moved here; the caller
              must place the model on the same device (e.g. ``model.cuda(gpu_id)``).
          out_channels: Output width of the second layer. ``None`` (default)
              means ``num_features``. Pass e.g. ``dataset.num_classes`` for node
              classification.
          dropout: Dropout probability applied after the first layer while the
              module is in training mode.
      """

      def __init__(self, num_features, hidden_channels, gpu_id=None,
                   out_channels=None, dropout=0.5):
          super().__init__()
          # Seeds the *global* torch RNG so the layer initialisation below is
          # reproducible. Note that this also affects every later random op
          # in the process, not just this module.
          torch.manual_seed(42)
          if out_channels is None:
              out_channels = num_features
          self.conv1 = GCNConv(num_features, hidden_channels)
          self.conv2 = GCNConv(hidden_channels, out_channels)
          self.gpu_id = gpu_id
          self.dropout_p = dropout

      def forward(self, x, edge_index, edge_weight=None):
          """Run both message-passing layers.

          Args:
              x: Node feature tensor; cast to ``float32`` before use.
              edge_index: Graph connectivity, passed unchanged to both layers.
              edge_weight: Optional per-edge weights forwarded to both
                  ``GCNConv`` layers. ``None`` keeps the original unweighted
                  behavior. GCNConv documents this as a 1-D tensor with one
                  weight per edge — confirm before passing ``data.edge_attr``.

          Returns:
              The second layer's output after ReLU.
          """
          x = x.to(torch.float32)
          if self.gpu_id is not None:
              # Only the inputs are moved; the parameters must already live
              # on the same device.
              x = x.cuda(self.gpu_id)
              edge_index = edge_index.cuda(self.gpu_id)
              if edge_weight is not None:
                  edge_weight = edge_weight.cuda(self.gpu_id)

          # First message-passing layer.
          x = self.conv1(x, edge_index, edge_weight)
          x = x.relu()
          x = F.dropout(x, p=self.dropout_p, training=self.training)

          # Second message-passing layer.
          x = self.conv2(x, edge_index, edge_weight)
          # NOTE(review): ReLU on the final output clips negative values. If
          # this output is used as classification logits (e.g. fed to
          # F.cross_entropy), that is likely unintended — confirm.
          x = x.relu()
          return x
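# Example training loop (disabled; kept for reference).
# NOTE: GCN's second positional argument is hidden_channels, not the number of
# classes, and by default the model outputs num_features channels. For
# F.cross_entropy the output width must equal the number of classes.
# NOTE: passing data.edge_attr as a third argument requires a forward() that
# accepts an edge-weight argument.
# model = GCN(dataset.num_features, dataset.num_classes)
# model.train()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# print("Training on CPU.")
# for epoch in range(1, 6):
#     optimizer.zero_grad()
#     out = model(data.x, data.edge_index, data.edge_attr)
#     loss = F.cross_entropy(out, data.y)
#     loss.backward()
#     optimizer.step()
#     print(f"Epoch: {epoch}, Loss: {loss}")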