123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import networkx as nx
- import numpy as np
- import torch
- from sklearn.metrics import mean_absolute_error
- import sys
- import torch.nn.functional as F
-
- from baselines.graphvae.graphvae_train import build_model
- from baselines.graphvae.graphvae_train import arg_parse
- from baselines.graphvae.args import GraphVAE_Args
- from baselines.graphvae.graphvae_model import GraphVAE
- from baselines.graphvae.graphvae_data import GraphAdjSampler
- from baselines.graphvae.graphvae_model import graph_show
- from torch.autograd import Variable
-
- from data import bfs_seq
- from model import sample_sigmoid
- import gmatch4py as gm
-
-
def matrix_permute(adj):
    """Reorder the first adjacency matrix of a batch by a ``bfs_seq`` ordering.

    A start node is drawn uniformly at random, ``bfs_seq`` (imported from
    ``data``) is asked for a node ordering from that start node, and the rows
    and columns of the matrix are rearranged to follow that ordering.

    Args:
        adj: Tensor convertible with ``.numpy()`` whose first element
            (``adj[0]``) is a square adjacency matrix. Any further batch
            elements are ignored.

    Returns:
        A tensor with a leading dimension of size 1 that holds the reordered
        adjacency matrix.
    """
    dense = adj.numpy()[0]
    root = np.random.randint(dense.shape[0])
    # NOTE(review): the result has one row/column per node returned by
    # bfs_seq -- presumably every node when the graph is connected; confirm.
    order = np.array(bfs_seq(nx.from_numpy_array(dense), root))
    reordered = np.asmatrix(dense[np.ix_(order, order)])
    return torch.from_numpy(reordered).unsqueeze(0)
-
-
def get_threshold(adj_input, h_decode, max_num_nodes=None):
    """Return the probability cutoff that keeps as many edges as the input has.

    The decoder logits are squashed with a sigmoid and written into the upper
    triangle (diagonal included) of a ``max_num_nodes x max_num_nodes``
    matrix. The returned value is the ``(E + 1)``-th largest entry of the
    strict upper triangle of that matrix, where ``E`` is the number of
    non-zero entries in the strict upper triangle of ``adj_input``. Keeping
    only probabilities strictly greater than the result therefore keeps at
    most ``E`` edges (exactly ``E`` when there are no ties).

    Args:
        adj_input: Tensor that squeezes to a square adjacency matrix.
        h_decode: Decoder logits, one per upper-triangular entry (diagonal
            included) in row-major order. Leading singleton dimensions are
            allowed.
        max_num_nodes: Side length of the reconstructed matrix. When omitted,
            falls back to ``model.max_num_nodes`` of the module-level
            ``model`` that this script creates in its ``__main__`` section.

    Returns:
        The threshold as a NumPy scalar.
    """
    if max_num_nodes is None:
        # Backward-compatible default: the script keeps the network in a
        # module-level `model` variable.
        max_num_nodes = model.max_num_nodes

    probs = torch.sigmoid(h_decode).detach().cpu().reshape(-1)
    upper_mask = torch.triu(torch.ones(max_num_nodes, max_num_nodes)) == 1
    adj_output = torch.zeros(max_num_nodes, max_num_nodes)
    adj_output[upper_mask] = probs

    # Count each undirected edge once: strict upper triangle only.
    target = adj_input.detach().cpu().squeeze().numpy()
    number_of_edges = np.count_nonzero(np.triu(target, k=1))

    # Flattened ascending sort. The zeroed diagonal and lower-triangle entries
    # sort first because every sigmoid output is positive.
    scores = np.sort(np.triu(adj_output.numpy(), k=1), axis=None)
    return scores[-number_of_edges - 1]
-
-
def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False,
         checkpoint_epoch=100):
    """Evaluate a trained GraphVAE by graph edit distance between input and output.

    Loads the checkpoint ``<args.model_save_path>GraphVAE<checkpoint_epoch>.dat``
    into ``model``, runs every graph through the encoder/decoder layers,
    thresholds the decoded edge scores with ``get_threshold`` and compares the
    reconstructed graph with the (possibly masked) input graph using a
    unit-cost graph edit distance.

    Only the first element of each batch is evaluated (``adj_input[0]``,
    ``h_decode[0]``), so the loader is expected to yield batches of size 1.

    NOTE: the dataset feature type, batch size and worker count are read from
    the module-level ``prog_args`` created in the ``__main__`` section, not
    from ``args``.

    Args:
        graphs: Graphs handed to ``GraphAdjSampler``.
        model: Network exposing ``conv1``, ``bn1``, ``act``, ``conv2``,
            ``bn2``, ``vae``, ``max_num_nodes`` and ``hidden_dim``.
        args: Settings object providing ``model_save_path``,
            ``permutation_mode``, ``bfs_mode`` and ``number_of_missing_nodes``.
        max_num_nodes: Maximum node count passed to ``GraphAdjSampler``.
        completion: When true, ``args.number_of_missing_nodes`` distinct
            random nodes have their rows and columns zeroed before encoding.
        graph_show_mode: When true, draw the input/output graphs and print the
            per-graph and mean edit distances.
        checkpoint_epoch: Epoch number embedded in the checkpoint file name.

    Returns:
        Mean graph edit distance over all evaluated graphs.

    Raises:
        ValueError: If ``completion`` is set and more missing nodes are
            requested than the graph has (raised by ``np.random.choice``).
    """
    ged = gm.GraphEditDistance(1, 1, 1, 1)  # all edit costs are equal to 1
    fname = args.model_save_path + "GraphVAE" + str(checkpoint_epoch) + '.dat'
    model.load_state_dict(torch.load(fname))
    model.eval()

    dataset = GraphAdjSampler(graphs, max_num_nodes, args.permutation_mode, args.bfs_mode,
                              features=prog_args.feature_type)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers)

    # Loop invariants: mask of the upper triangle (diagonal included) and the
    # single colour used when drawing graphs.
    upper_mask = torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1
    colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]

    mae_list = []  # collected for inspection only; not part of the return value
    ged_list = []

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data in dataset_loader:
            features = data['features'].float()
            adj_input = data['adj'].float()

            if args.permutation_mode:
                adj_input = matrix_permute(adj_input)

            if completion:
                graph_size = adj_input.size()[1]
                masked_adj = adj_input.cpu().numpy()[0].copy()
                # Distinct node indices; raises instead of looping forever
                # when more nodes are requested than exist.
                missing = np.random.choice(
                    graph_size, size=args.number_of_missing_nodes, replace=False)
                masked_adj[:, missing] = 0
                masked_adj[missing, :] = 0
                adj_input = torch.from_numpy(masked_adj).unsqueeze(0)

            features = features.cuda()
            adj_input = adj_input.cuda()

            # Encoder / decoder layers, called one by one.
            x = model.conv1(features, adj_input)
            x = model.bn1(x)
            x = model.act(x)
            x = model.conv2(x, adj_input)
            x = model.bn2(x)
            x = x.view(-1, model.max_num_nodes * model.hidden_dim)
            h_decode, z_mu, z_lsgms = model.vae(x)

            # sample_sigmoid receives the first batch element with two leading
            # singleton dimensions added.
            decoded = h_decode[0].unsqueeze(0).unsqueeze(0)
            thresh = get_threshold(adj_input, h_decode)
            sample = sample_sigmoid(decoded, sample=False, thresh=thresh, sample_time=10)

            # Scatter the thresholded scores back into an upper-triangular
            # adjacency matrix.
            y = torch.zeros(model.max_num_nodes, model.max_num_nodes)
            y[upper_mask] = sample.detach().cpu()

            input_adj = adj_input[0].detach().cpu().numpy()
            output_adj = y.numpy()

            mae_list.append(mean_absolute_error(np.triu(input_adj), np.triu(output_adj)))

            input_graph = nx.from_numpy_array(input_adj)
            output_graph = nx.from_numpy_array(output_adj)

            result = ged.compare([input_graph, output_graph], None)
            ged_value = result[0][1]
            ged_list.append(ged_value)

            if graph_show_mode:
                graph_show(input_graph, "input_graph", colors)
                graph_show(output_graph, "output_graph", colors)
                print("@@@@ ged: " + str(ged_value))

    mean_ged = np.mean(ged_list)
    if graph_show_mode:
        print(ged_list)
        print("**************** GED: " + str(mean_ged))
    return mean_ged
-
-
def create_graph():
    """Build the evaluation graphs: one 2-D grid graph per factorisation of 12.

    Returns:
        A list of six ``networkx`` grid graphs with 12 nodes each, in the
        order 1x12, 2x6, 3x4, 4x3, 6x2, 12x1.
    """
    grid_shapes = [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)]
    return [nx.grid_2d_graph(rows, cols) for rows, cols in grid_shapes]
-
-
if __name__ == '__main__':
    # Script entry point: build the model for the grid graphs, then repeat the
    # completion evaluation and report the mean graph edit distance.
    # `prog_args`, `my_args` and `model` must keep these names: functions in
    # this file read them as module-level globals.
    prog_args = arg_parse()
    my_args = GraphVAE_Args()
    graphs = create_graph()
    max_num_nodes = max(g.number_of_nodes() for g in graphs)
    model = build_model(prog_args, max_num_nodes).cuda()
    number_of_repeats = 100
    # Each run evaluates with completion enabled and graph drawing disabled.
    run_means = [
        test(graphs, model, my_args, max_num_nodes, True, False)
        for _ in range(number_of_repeats)
    ]
    print(" GED = " + str(np.mean(run_means)) + " after " + str(number_of_repeats) + " repeats")
|