
models.py

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import GCNConv, GATConv, TransformerConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

torch.manual_seed(42)

# a simple base GCN model
# class GCN(torch.nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         torch.manual_seed(1234)
#         self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)
#
#     def forward(self, x, edge_index, edge_weight=None):
#         x = F.dropout(x, p=0.5, training=self.training)
#         x = self.conv(x, edge_index, edge_weight).relu()
#         return x

# base from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels):  # num_features = dataset.num_features
        super(GCN, self).__init__()
        torch.manual_seed(42)
        # Initialize the layers
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        # First Message Passing Layer (Transformation)
        x = x.to(torch.float32)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        # Second Message Passing Layer
        x = self.conv2(x, edge_index)
        x = x.relu()
        return x
# Example training loop (GCN.forward takes only x and edge_index):
# model = GCN(dataset.num_features, dataset.num_classes)
# model.train()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# print("Training on CPU.")
# for epoch in range(1, 6):
#     optimizer.zero_grad()
#     out = model(data.x, data.edge_index)
#     loss = F.cross_entropy(out, data.y)
#     loss.backward()
#     optimizer.step()
#     print(f"Epoch: {epoch}, Loss: {loss}")
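
# A minimal, self-contained sketch of the `dataset`/`data` objects the commented
# example above assumes, built from random tensors. All shapes here are
# illustrative assumptions, not values from the original pipeline:
# from torch_geometric.data import Data
# num_nodes, num_features, num_classes = 10, 8, 3
# data = Data(
#     x=torch.randn(num_nodes, num_features),
#     edge_index=torch.randint(0, num_nodes, (2, 40)),
#     y=torch.randint(0, num_classes, (num_nodes,)),
# )
# model = GCN(num_features, hidden_channels=num_classes)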
class TransformerGNN(torch.nn.Module):
    def __init__(self, feature_size, model_params):
        super(TransformerGNN, self).__init__()
        embedding_size = model_params["model_embedding_size"]  # default: 1024
        n_heads = model_params["model_attention_heads"]        # default: 3
        self.n_layers = model_params["model_layers"]           # default: 3
        dropout_rate = model_params["model_dropout_rate"]      # default: 0.3
        top_k_ratio = model_params["model_top_k_ratio"]
        self.top_k_every_n = model_params["model_top_k_every_n"]
        dense_neurons = model_params["model_dense_neurons"]
        edge_dim = model_params["model_edge_dim"]

        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])
        self.leakyrelu = nn.LeakyReLU(0.1)

        # Transformation layer
        self.conv1 = GATConv(feature_size,
                             embedding_size,
                             heads=n_heads,
                             dropout=dropout_rate,
                             edge_dim=edge_dim)
        self.transf1 = Linear(embedding_size * n_heads, embedding_size)
        self.bn1 = BatchNorm1d(embedding_size)

        # Other layers
        for i in range(self.n_layers):
            self.conv_layers.append(GATConv(embedding_size,
                                            embedding_size,
                                            heads=n_heads,
                                            dropout=dropout_rate,
                                            edge_dim=edge_dim))
            self.transf_layers.append(Linear(embedding_size * n_heads, embedding_size))
            self.bn_layers.append(BatchNorm1d(embedding_size))
            if i % self.top_k_every_n == 0:
                self.pooling_layers.append(TopKPooling(embedding_size, ratio=top_k_ratio))

        # test extra layer ---------------------------
        # self.conv2 = TransformerConv(embedding_size * n_heads,
        #                              embedding_size,
        #                              dropout=dropout_rate,
        #                              edge_dim=edge_dim,
        #                              beta=True)
        # self.transf2 = Linear(embedding_size, embedding_size)
        # self.bn2 = BatchNorm1d(embedding_size)
        # ---------------------------------------------

        # Linear layers
        # TODO: only the linear layers should be changed, either by removing them or by changing the last linear layer
        self.linear1 = Linear(embedding_size * 2, dense_neurons)
        self.linear2 = Linear(dense_neurons, dense_neurons // 2)  # default: 128, 64
        self.linear3 = Linear(dense_neurons // 2, dense_neurons // 2)
    def forward(self, x, edge_attr, edge_index, batch_index):
        # Note the argument order: edge_attr comes before edge_index.
        # torch.autograd.set_detect_anomaly(True)  # debugging aid only; slows every forward pass considerably

        # Initial transformation
        x = self.conv1(x, edge_index, edge_attr)
        # x = torch.relu(self.transf1(x))
        x = self.leakyrelu(self.transf1(x))
        x = self.bn1(x)

        # Holds the intermediate graph representations
        global_representation = []
        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index, edge_attr)
            # x = torch.relu(self.transf_layers[i](x))
            x = self.leakyrelu(self.transf_layers[i](x))
            x = self.bn_layers[i](x)
            # Pool every top_k_every_n layers; always aggregate the last layer
            # (i runs from 0 to n_layers - 1)
            if i % self.top_k_every_n == 0 or i == self.n_layers - 1:
                x, edge_index, edge_attr, batch_index, _, _ = self.pooling_layers[i // self.top_k_every_n](
                    x, edge_index, edge_attr, batch_index
                )
                # Add current representation
                global_representation.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        # # test ------------------------------------
        # x = self.conv1(x, edge_index, edge_attr)
        # # x = torch.relu(self.transf1(x))
        # x = F.elu(x)
        # # x = self.bn1(x)
        # x = F.dropout(x, p=0.2, training=self.training)
        # x = self.conv2(x, edge_index, edge_attr)
        # x = F.elu(x)
        # # x = self.bn1(x)
        # x = F.dropout(x, p=0.2, training=self.training)
        # x = torch.relu(self.transf2(x))
        # # x = self.bn2(x)
        # x = gmp(x, batch_index)
        # # -----------------------------------------

        x = sum(global_representation)

        # Output block
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.linear3(x)
        return x
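

if __name__ == "__main__":
    # Minimal smoke test for TransformerGNN on random data. The sizes below are
    # illustrative assumptions (the comments above only give defaults for the
    # embedding size, attention heads, layers, and dropout rate); they are not
    # the values used in the original training pipeline.
    params = {
        "model_embedding_size": 64,  # original default: 1024 (reduced for a quick test)
        "model_attention_heads": 3,  # default: 3
        "model_layers": 3,           # default: 3
        "model_dropout_rate": 0.3,   # default: 0.3
        "model_top_k_ratio": 0.5,    # assumed
        "model_top_k_every_n": 1,    # assumed
        "model_dense_neurons": 128,  # assumed (matches the "128, 64" comment above)
        "model_edge_dim": 4,         # assumed
    }
    model = TransformerGNN(feature_size=16, model_params=params)

    num_nodes, num_edges = 20, 60
    x = torch.randn(num_nodes, 16)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.randn(num_edges, 4)
    batch_index = torch.zeros(num_nodes, dtype=torch.long)  # all nodes in one graph

    # Note the argument order: edge_attr before edge_index
    out = model(x, edge_attr, edge_index, batch_index)
    print(out.shape)  # expected: torch.Size([1, 64]) -> one graph, dense_neurons // 2 outputs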