Browse Source

padding added

master
Yassaman Ommi 4 years ago
parent
commit
cbaad92456
1 changed files with 21 additions and 1 deletions
  1. 21
    1
      GraphTransformer.py

+ 21
- 1
GraphTransformer.py View File

@@ -65,6 +65,7 @@ def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_
print('Loaded')
return graphs


def feature_matrix(g):
'''
constructs the feautre matrix (N x 3) for the enzymes datasets
@@ -75,6 +76,7 @@ def feature_matrix(g):
piazche[i][v-1] = 1
return piazche


# def remove_random_node(graph, max_size=40, min_size=10):
# '''
# removes a random node from the gragh
@@ -96,7 +98,7 @@ def feature_matrix(g):
def prepare_graph_data(graph, max_size=40, min_size=10):
'''
gets a graph as an input
returns a graph with a randomly removed node adj matrix [0], its feature matrix [0], the removed node true links [2]
returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2]
'''
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
return None
@@ -137,6 +139,7 @@ class GraphConv(nn.Module):
y = torch.matmul(y, self.weight.double())
return y


class GraphAttn(nn.Module):
def __init__(self, heads, model_dim, dropout=0.1):
super().__init__()
@@ -170,6 +173,7 @@ class GraphAttn(nn.Module):

return output


class FeedForward(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
@@ -186,3 +190,19 @@ class FeedForward(nn.Module):
output = self.fully_connected2(relu)
output = self.sigmoid(output)
return output


class Hydra(nn.Module):
def __init__(self, gcn_input, model_dim, head):
super().__init__()

self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()
self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()
self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()

def forward(self, x, adj):
gcn_outputs = self.GCN(x, adj)
gat_output = self.GAT(gcn_outputs)
mlp_output = self.MLP(gat_output)

return output

Loading…
Cancel
Save