| graphs.append(G_sub) | graphs.append(G_sub) | ||||
| if G_sub.number_of_nodes() > max_nodes: | if G_sub.number_of_nodes() > max_nodes: | ||||
| max_nodes = G_sub.number_of_nodes() | max_nodes = G_sub.number_of_nodes() | ||||
| print('Loaded') | print('Loaded') | ||||
| return graphs | return graphs | ||||
| piazche = np.zeros((len(esm), 3)) | piazche = np.zeros((len(esm), 3)) | ||||
| for i, (k, v) in enumerate(esm.items()): | for i, (k, v) in enumerate(esm.items()): | ||||
| piazche[i][v-1] = 1 | piazche[i][v-1] = 1 | ||||
| return piazche | return piazche | ||||
| # print(y.shape) | # print(y.shape) | ||||
| # print(self.weight.shape) | # print(self.weight.shape) | ||||
| y = torch.matmul(y, self.weight.double()) | y = torch.matmul(y, self.weight.double()) | ||||
| return y | return y | ||||
| relu = self.relu(hidden) | relu = self.relu(hidden) | ||||
| output = self.fully_connected2(relu) | output = self.fully_connected2(relu) | ||||
| output = self.sigmoid(output) | output = self.sigmoid(output) | ||||
| return output | return output | ||||
| def forward(self, x, adj): | def forward(self, x, adj): | ||||
| gcn_outputs = self.GCN(x, adj) | gcn_outputs = self.GCN(x, adj) | ||||
| gat_output = self.GAT(gcn_outputs) | gat_output = self.GAT(gcn_outputs) | ||||
| mlp_output = self.MLP(gat_output) | |||||
| mlp_output = self.MLP(gat_output).reshape(1,-1) | |||||
| return output | |||||
| return mlp_output |