| @@ -61,8 +61,8 @@ def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_ | |||
| graphs.append(G_sub) | |||
| if G_sub.number_of_nodes() > max_nodes: | |||
| max_nodes = G_sub.number_of_nodes() | |||
| print('Loaded') | |||
| return graphs | |||
| @@ -74,6 +74,7 @@ def feature_matrix(g): | |||
| piazche = np.zeros((len(esm), 3)) | |||
| for i, (k, v) in enumerate(esm.items()): | |||
| piazche[i][v-1] = 1 | |||
| return piazche | |||
| @@ -137,6 +138,7 @@ class GraphConv(nn.Module): | |||
| # print(y.shape) | |||
| # print(self.weight.shape) | |||
| y = torch.matmul(y, self.weight.double()) | |||
| return y | |||
| @@ -189,6 +191,7 @@ class FeedForward(nn.Module): | |||
| relu = self.relu(hidden) | |||
| output = self.fully_connected2(relu) | |||
| output = self.sigmoid(output) | |||
| return output | |||
| @@ -203,6 +206,6 @@ class Hydra(nn.Module): | |||
| def forward(self, x, adj): | |||
| gcn_outputs = self.GCN(x, adj) | |||
| gat_output = self.GAT(gcn_outputs) | |||
| mlp_output = self.MLP(gat_output) | |||
| mlp_output = self.MLP(gat_output).reshape(1,-1) | |||
| return output | |||
| return mlp_output | |||