@@ -65,6 +65,7 @@ def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_
     print('Loaded')
     return graphs
 
+
 def feature_matrix(g):
     '''
     constructs the feature matrix (N x 3) for the ENZYMES dataset
@@ -75,6 +76,7 @@ def feature_matrix(g):
         piazche[i][v-1] = 1
     return piazche
 
+
 # def remove_random_node(graph, max_size=40, min_size=10):
 #     '''
 #     removes a random node from the graph
@@ -96,7 +98,7 @@ def feature_matrix(g):
 def prepare_graph_data(graph, max_size=40, min_size=10):
     '''
     gets a graph as an input
-    returns a graph with a randomly removed node adj matrix [0], its feature matrix [0], the removed node true links [2]
+    returns the graph with a randomly removed node: its adj matrix [0], its feature matrix [1], and the removed node's true links [2]
     '''
     if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
         return None
@@ -137,6 +139,7 @@ class GraphConv(nn.Module):
         y = torch.matmul(y, self.weight.double())
         return y
 
+
 class GraphAttn(nn.Module):
     def __init__(self, heads, model_dim, dropout=0.1):
         super().__init__()
@@ -170,6 +173,7 @@ class GraphAttn(nn.Module):
         return output
 
+
 class FeedForward(nn.Module):
     def __init__(self, input_size, hidden_size):
         super().__init__()
@@ -186,3 +190,19 @@ class FeedForward(nn.Module):
         output = self.fully_connected2(relu)
         output = self.sigmoid(output)
         return output
+
+
+class Hydra(nn.Module):
+    def __init__(self, gcn_input, model_dim, head):
+        super().__init__()
+        self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()
+        self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()
+        self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()
+
+    def forward(self, x, adj):
+        gcn_outputs = self.GCN(x, adj)
+        gat_output = self.GAT(gcn_outputs)
+        mlp_output = self.MLP(gat_output)
+        return mlp_output
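
Below is a minimal usage sketch (not part of the diff) showing how the new Hydra module is meant to be driven: a graph loaded by Graph_load_batch goes through prepare_graph_data, and Hydra then chains GraphConv -> GraphAttn -> FeedForward to score the removed node's links. The tensor conversions, dtypes, and the hyperparameter values (model_dim=64, head=4) are illustrative assumptions, not taken from the repository.

# Hypothetical usage sketch; assumes prepare_graph_data returns numpy-compatible
# (adj, features, target_links) and that a CUDA device is available, since Hydra
# moves its submodules to the GPU in __init__.
import torch

graphs = Graph_load_batch(name='ENZYMES')        # list of networkx graphs
sample = prepare_graph_data(graphs[0])           # None if the graph is too small or too large
if sample is not None:
    adj, features, target_links = sample
    # GraphConv multiplies by a double-precision weight, so cast inputs to float64
    adj = torch.as_tensor(adj, dtype=torch.float64).cuda()
    features = torch.as_tensor(features, dtype=torch.float64).cuda()

    # ENZYMES node features are N x 3, so gcn_input=3; model_dim and head are free choices
    model = Hydra(gcn_input=3, model_dim=64, head=4)
    predicted_links = model(features, adj)       # GCN -> GAT -> MLP, sigmoid scores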