Browse Source

my vasvas will kill me

master
Yassaman Ommi 4 years ago
parent
commit
bd6b5651ad
1 changed files with 33 additions and 17 deletions
  1. 33
    17
      GraphTransformer.py

+ 33
- 17
GraphTransformer.py View File

@@ -75,23 +75,38 @@ def feature_matrix(g):
piazche[i][v-1] = 1
return piazche

def remove_random_node(graph, max_size=40, min_size=10):
'''
removes a random node from the gragh
returns the remaining graph matrix and the removed node links
'''
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
return None
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
choice = np.random.choice(list(relabeled_graph.nodes()))
remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
graph_length = len(remaining_graph)
# source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
# target_graph = np.copy(source_graph)
removed_node_row = np.asarray(removed_node)[0]
# target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
return remaining_graph, removed_node_row
# def remove_random_node(graph, max_size=40, min_size=10):
# '''
# removes a random node from the gragh
# returns the remaining graph matrix and the removed node links
# '''
# if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
# return None
# relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
# choice = np.random.choice(list(relabeled_graph.nodes()))
# remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
# removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
# graph_length = len(remaining_graph)
# # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
# # target_graph = np.copy(source_graph)
# removed_node_row = np.asarray(removed_node)[0]
# # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
# return remaining_graph, removed_node_row

def prepare_graph_data(graph, max_size=40, min_size=10):
'''
gets a graph as an input
returns a graph with a randomly removed node adj matrix [0], its feature matrix [0], the removed node true links [2]
'''
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
return None
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
choice = np.random.choice(list(relabeled_graph.nodes()))
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
removed_node_row = np.asarray(removed_node)[0]
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row

""""
Layers:
@@ -146,5 +161,6 @@ class GraphAttn(nn.Module):
scores = attention(query, key, value, self.key_dim)
concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)
output = self.out(concat)
output = output.view(bs, self.model_dim)

return output

Loading…
Cancel
Save