Browse Source

ready to write the train function :D

master
Yassaman Ommi 4 years ago
parent
commit
07555746bd
1 changed files with 6 additions and 3 deletions
  1. 6
    3
      GraphTransformer.py

+ 6
- 3
GraphTransformer.py View File

@@ -61,8 +61,8 @@ def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_
graphs.append(G_sub)
if G_sub.number_of_nodes() > max_nodes:
max_nodes = G_sub.number_of_nodes()

print('Loaded')

return graphs


@@ -74,6 +74,7 @@ def feature_matrix(g):
piazche = np.zeros((len(esm), 3))
for i, (k, v) in enumerate(esm.items()):
piazche[i][v-1] = 1

return piazche


@@ -137,6 +138,7 @@ class GraphConv(nn.Module):
# print(y.shape)
# print(self.weight.shape)
y = torch.matmul(y, self.weight.double())

return y


@@ -189,6 +191,7 @@ class FeedForward(nn.Module):
relu = self.relu(hidden)
output = self.fully_connected2(relu)
output = self.sigmoid(output)

return output


@@ -203,6 +206,6 @@ class Hydra(nn.Module):
def forward(self, x, adj):
gcn_outputs = self.GCN(x, adj)
gat_output = self.GAT(gcn_outputs)
mlp_output = self.MLP(gat_output)
mlp_output = self.MLP(gat_output).reshape(1,-1)

return output
return mlp_output

Loading…
Cancel
Save