Browse Source

ready to write the train function :D

master
Yassaman Ommi 4 years ago
parent
commit
07555746bd
1 changed files with 6 additions and 3 deletions
  1. 6
    3
      GraphTransformer.py

+ 6
- 3
GraphTransformer.py View File

graphs.append(G_sub) graphs.append(G_sub)
if G_sub.number_of_nodes() > max_nodes: if G_sub.number_of_nodes() > max_nodes:
max_nodes = G_sub.number_of_nodes() max_nodes = G_sub.number_of_nodes()

print('Loaded') print('Loaded')

return graphs return graphs




piazche = np.zeros((len(esm), 3)) piazche = np.zeros((len(esm), 3))
for i, (k, v) in enumerate(esm.items()): for i, (k, v) in enumerate(esm.items()):
piazche[i][v-1] = 1 piazche[i][v-1] = 1

return piazche return piazche




# print(y.shape) # print(y.shape)
# print(self.weight.shape) # print(self.weight.shape)
y = torch.matmul(y, self.weight.double()) y = torch.matmul(y, self.weight.double())

return y return y




relu = self.relu(hidden) relu = self.relu(hidden)
output = self.fully_connected2(relu) output = self.fully_connected2(relu)
output = self.sigmoid(output) output = self.sigmoid(output)

return output return output




def forward(self, x, adj): def forward(self, x, adj):
gcn_outputs = self.GCN(x, adj) gcn_outputs = self.GCN(x, adj)
gat_output = self.GAT(gcn_outputs) gat_output = self.GAT(gcn_outputs)
mlp_output = self.MLP(gat_output)
mlp_output = self.MLP(gat_output).reshape(1,-1)


return output
return mlp_output

Loading…
Cancel
Save