| print('Loaded') | print('Loaded') | ||||
| return graphs | return graphs | ||||
| def feature_matrix(g): | def feature_matrix(g): | ||||
| ''' | ''' | ||||
| constructs the feautre matrix (N x 3) for the enzymes datasets | constructs the feautre matrix (N x 3) for the enzymes datasets | ||||
| piazche[i][v-1] = 1 | piazche[i][v-1] = 1 | ||||
| return piazche | return piazche | ||||
| # def remove_random_node(graph, max_size=40, min_size=10): | # def remove_random_node(graph, max_size=40, min_size=10): | ||||
| # ''' | # ''' | ||||
| # removes a random node from the gragh | # removes a random node from the gragh | ||||
| def prepare_graph_data(graph, max_size=40, min_size=10): | def prepare_graph_data(graph, max_size=40, min_size=10): | ||||
| ''' | ''' | ||||
| gets a graph as an input | gets a graph as an input | ||||
| returns a graph with a randomly removed node adj matrix [0], its feature matrix [0], the removed node true links [2] | |||||
| returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] | |||||
| ''' | ''' | ||||
| if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: | if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: | ||||
| return None | return None | ||||
| y = torch.matmul(y, self.weight.double()) | y = torch.matmul(y, self.weight.double()) | ||||
| return y | return y | ||||
| class GraphAttn(nn.Module): | class GraphAttn(nn.Module): | ||||
| def __init__(self, heads, model_dim, dropout=0.1): | def __init__(self, heads, model_dim, dropout=0.1): | ||||
| super().__init__() | super().__init__() | ||||
| return output | return output | ||||
| class FeedForward(nn.Module): | class FeedForward(nn.Module): | ||||
| def __init__(self, input_size, hidden_size): | def __init__(self, input_size, hidden_size): | ||||
| super().__init__() | super().__init__() | ||||
| output = self.fully_connected2(relu) | output = self.fully_connected2(relu) | ||||
| output = self.sigmoid(output) | output = self.sigmoid(output) | ||||
| return output | return output | ||||
| class Hydra(nn.Module): | |||||
| def __init__(self, gcn_input, model_dim, head): | |||||
| super().__init__() | |||||
| self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda() | |||||
| self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda() | |||||
| self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda() | |||||
| def forward(self, x, adj): | |||||
| gcn_outputs = self.GCN(x, adj) | |||||
| gat_output = self.GAT(gcn_outputs) | |||||
| mlp_output = self.MLP(gat_output) | |||||
| return output |