|
|
@@ -61,8 +61,8 @@ def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_ |
|
|
|
graphs.append(G_sub) |
|
|
|
if G_sub.number_of_nodes() > max_nodes: |
|
|
|
max_nodes = G_sub.number_of_nodes() |
|
|
|
|
|
|
|
print('Loaded') |
|
|
|
|
|
|
|
return graphs |
|
|
|
|
|
|
|
|
|
|
@@ -74,6 +74,7 @@ def feature_matrix(g): |
|
|
|
piazche = np.zeros((len(esm), 3)) |
|
|
|
for i, (k, v) in enumerate(esm.items()): |
|
|
|
piazche[i][v-1] = 1 |
|
|
|
|
|
|
|
return piazche |
|
|
|
|
|
|
|
|
|
|
@@ -137,6 +138,7 @@ class GraphConv(nn.Module): |
|
|
|
# print(y.shape) |
|
|
|
# print(self.weight.shape) |
|
|
|
y = torch.matmul(y, self.weight.double()) |
|
|
|
|
|
|
|
return y |
|
|
|
|
|
|
|
|
|
|
@@ -189,6 +191,7 @@ class FeedForward(nn.Module): |
|
|
|
relu = self.relu(hidden) |
|
|
|
output = self.fully_connected2(relu) |
|
|
|
output = self.sigmoid(output) |
|
|
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
@@ -203,6 +206,6 @@ class Hydra(nn.Module): |
|
|
|
def forward(self, x, adj): |
|
|
|
gcn_outputs = self.GCN(x, adj) |
|
|
|
gat_output = self.GAT(gcn_outputs) |
|
|
|
mlp_output = self.MLP(gat_output) |
|
|
|
mlp_output = self.MLP(gat_output).reshape(1,-1) |
|
|
|
|
|
|
|
return output |
|
|
|
return mlp_output |