Browse Source

padding added

master
Yassaman Ommi 4 years ago
parent
commit
6703f97347
1 changed files with 26 additions and 4 deletions
  1. 26
    4
      GraphTransformer.py

+ 26
- 4
GraphTransformer.py View File

choice = np.random.choice(list(relabeled_graph.nodes())) choice = np.random.choice(list(relabeled_graph.nodes()))
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))) remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph) remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
graph_length = len(remaining_graph)
remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice] removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
removed_node_row = np.asarray(removed_node)[0] removed_node_row = np.asarray(removed_node)[0]
removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row


"""" """"
Layers: Layers:
Graph Convolution Graph Convolution
Graph Multihead Attention Graph Multihead Attention
Feed-Forward (as a MLP)
""" """


class GraphConv(nn.Module): class GraphConv(nn.Module):


def forward(self, query, key, value): def forward(self, query, key, value):
# print(q, k, v) # print(q, k, v)
bs = query.size(0) # size of the graph
key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim)
query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim)
value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim)
bs = query.size(0)

key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)
query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)
value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)


key = key.transpose(1,2) key = key.transpose(1,2)
query = query.transpose(1,2) query = query.transpose(1,2)
output = output.view(bs, self.model_dim) output = output.view(bs, self.model_dim)


return output return output

class FeedForward(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.fully_connected1 = nn.Linear(self.input_size, self.hidden_size).cuda()
self.relu = nn.ReLU()
self.fully_connected2 = nn.Linear(self.hidden_size, 1).cuda()
self.sigmoid = nn.Sigmoid()

def forward(self, x):
hidden = self.fully_connected1(x.float())
relu = self.relu(hidden)
output = self.fully_connected2(relu)
output = self.sigmoid(output)
return output

Loading…
Cancel
Save