import sys

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import mean_absolute_error
from torch.autograd import Variable

import gmatch4py as gm

from baselines.graphvae.args import GraphVAE_Args
from baselines.graphvae.graphvae_data import GraphAdjSampler
from baselines.graphvae.graphvae_model import GraphVAE, graph_show
from baselines.graphvae.graphvae_train import arg_parse, build_model
from data import bfs_seq
from model import sample_sigmoid


def matrix_permute(adj):
    """Relabel the adjacency matrix in BFS order starting from a random node."""
    adj = adj.numpy()[0]
    # x_idx = np.random.permutation(adj.shape[0])
    # permuted_adj = adj[np.ix_(x_idx, x_idx)]
    start_idx = np.random.randint(adj.shape[0])
    x_idx = np.array(bfs_seq(nx.from_numpy_array(adj), start_idx))
    permuted_adj = adj[np.ix_(x_idx, x_idx)]
    permuted_adj = np.asmatrix(permuted_adj)
    permuted_adj = torch.from_numpy(permuted_adj)
    permuted_adj = permuted_adj.unsqueeze(0)
    return permuted_adj


def get_threshold(adj_input, h_decode):
    """Pick a sigmoid threshold so the thresholded reconstruction keeps roughly
    as many edges as the input graph. Uses the module-level `model` for sizes."""
    h_decode = F.sigmoid(h_decode)
    # Scatter the decoded upper-triangular entries back into a full matrix.
    adj_output = torch.zeros(model.max_num_nodes, model.max_num_nodes)
    adj_output[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = h_decode.data.cpu()
    adj_input_matrix = adj_input.cpu().data.squeeze().numpy()
    adj_input_triu_without_diag = adj_input_matrix - np.tril(adj_input_matrix)
    adj_output_matrix = adj_output.cpu().data.squeeze().numpy()
    adj_output_triu_without_diag = adj_output_matrix - np.tril(adj_output_matrix)
    number_of_edges = np.count_nonzero(adj_input_triu_without_diag)
    # Threshold just below the `number_of_edges` largest predicted entries.
    adj_output_triu_without_diag = adj_output_triu_without_diag.ravel()
    sorted_array = adj_output_triu_without_diag[np.argsort(adj_output_triu_without_diag)]
    threshold = sorted_array[-number_of_edges - 1]
    return threshold


def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False):
    ged = gm.GraphEditDistance(1, 1, 1, 1)  # all edit costs are equal to 1
    fname = args.model_save_path + "GraphVAE" + str(100) + '.dat'
    model.load_state_dict(torch.load(fname))
    model.eval()

    dataset = GraphAdjSampler(graphs, max_num_nodes, args.permutation_mode, args.bfs_mode,
                              features=prog_args.feature_type)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers)

    mae_list = []
    ged_list = []
    for batch_idx, data in enumerate(dataset_loader):
        features = data['features'].float()
        adj_input = data['adj'].float()

        if args.permutation_mode:
            # print("**** Before")
            # print(adj_input)
            adj_input = matrix_permute(adj_input)
            # print("**** After")
            # print(adj_input)

        if completion:
            # Graph completion: zero out the rows/columns of randomly chosen nodes.
            graph_size = adj_input.size()[1]
            numpy_adj = adj_input.cpu().numpy()[0]
            index_list = []
            for i in range(my_args.number_of_missing_nodes):
                random_index = np.random.randint(graph_size)
                while random_index in index_list:
                    random_index = np.random.randint(graph_size)
                index_list.append(random_index)
                numpy_adj[:, random_index] = 0
                numpy_adj[random_index, :] = 0
            adj_input = torch.tensor(numpy_adj, device='cuda:0').unsqueeze(0)

        features = Variable(features).cuda()
        adj_input = Variable(adj_input).cuda()

        # Encode with the two graph-conv layers, then decode through the VAE.
        x = model.conv1(features, adj_input)
        x = model.bn1(x)
        x = model.act(x)
        x = model.conv2(x, adj_input)
        x = model.bn2(x)
        x = x.view(-1, model.max_num_nodes * model.hidden_dim)
        h_decode, z_mu, z_lsgms = model.vae(x)

        x2 = h_decode[0].unsqueeze(0)
        x3 = x2.unsqueeze(0)
        thresh = get_threshold(adj_input, h_decode)
        sample = sample_sigmoid(x3, sample=False, thresh=thresh, sample_time=10)

        # Rebuild the full adjacency matrix from the sampled upper-triangular entries.
        y = torch.zeros(model.max_num_nodes, model.max_num_nodes)
        y[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = sample.data.cpu()

        colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
        input_graph = nx.from_numpy_matrix(adj_input[0].data.cpu().numpy())
        largest_connected_component_input = max(nx.connected_component_subgraphs(input_graph), key=len)

        mae = mean_absolute_error(np.triu(adj_input[0].data.cpu().numpy()),
                                  np.triu(y.data.cpu().numpy()))
        # print("@@@@ mae: " + str(mae))
        mae_list.append(mae)

        output_graph = nx.from_numpy_matrix(y.data.cpu().numpy())
        largest_connected_component_output = max(nx.connected_component_subgraphs(output_graph), key=len)

        # Graph edit distance between the input and the reconstructed graph.
        result = ged.compare([input_graph, output_graph], None)
        ged_value = result[0][1]
        ged_list.append(ged_value)

        if graph_show_mode:
            graph_show(input_graph, "input_graph", colors)
            graph_show(output_graph, "output_graph", colors)
            print("@@@@ ged: " + str(ged_value))
        # break
        # print(result)

    if graph_show_mode:
        print(ged_list)
        print("**************** GED: " + str(np.mean(ged_list)))
    return np.mean(ged_list)


def create_graph():
    """Build a small set of grid graphs with 12 nodes each."""
    graphs = []
    # for i in range(2, 6):
    #     for j in range(2, 6):
    #         graphs.append(nx.grid_2d_graph(i, j))
    # graphs.append(nx.grid_2d_graph(2, 3))
    graphs.append(nx.grid_2d_graph(1, 12))
    graphs.append(nx.grid_2d_graph(2, 6))
    graphs.append(nx.grid_2d_graph(3, 4))
    graphs.append(nx.grid_2d_graph(4, 3))
    graphs.append(nx.grid_2d_graph(6, 2))
    graphs.append(nx.grid_2d_graph(12, 1))
    return graphs


if __name__ == '__main__':
    prog_args = arg_parse()
    my_args = GraphVAE_Args()

    graphs = create_graph()
    max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    model = build_model(prog_args, max_num_nodes).cuda()

    number_of_repeats = 100
    ged_list = []
    for i in range(number_of_repeats):
        ged_list.append(test(graphs, model, my_args, max_num_nodes, True, False))
    print(" GED = " + str(np.mean(ged_list)) + " after " + str(number_of_repeats) + " repeats")