"""Evaluate a trained GraphVAE checkpoint on graph completion using graph edit distance (GED)."""
import networkx as nx
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error
import sys
import torch.nn.functional as F
from baselines.graphvae.graphvae_train import build_model
from baselines.graphvae.graphvae_train import arg_parse
from baselines.graphvae.args import GraphVAE_Args
from baselines.graphvae.graphvae_model import GraphVAE
from baselines.graphvae.graphvae_data import GraphAdjSampler
from baselines.graphvae.graphvae_model import graph_show
from torch.autograd import Variable
from data import bfs_seq
from model import sample_sigmoid
import gmatch4py as gm


def matrix_permute(adj, args):
    """Return the first adjacency matrix of a batch reordered by a BFS traversal.

    Args:
        adj: tensor whose element ``[0]`` is a square adjacency matrix
            (presumably shaped ``(1, N, N)`` — only ``[0]`` is used).
        args: reads ``bfs_mode_with_arbitrary_node_deleted``. When it is truthy, the
            nodes are shuffled, one random node is removed, the rest are BFS-ordered
            starting from the highest-degree node, and the removed node is appended
            back as the last row/column (with a 0 diagonal entry). Otherwise the
            matrix is BFS-ordered from a uniformly random start node.

    Returns:
        A tensor of shape ``(1, M, M)``.

    NOTE(review): if ``bfs_seq`` only visits nodes reachable from the start node,
    M can be smaller than N for disconnected graphs — confirm callers tolerate that.
    """
    adj = adj.numpy()[0]
    if args.bfs_mode_with_arbitrary_node_deleted:
        permuted_adj = _bfs_order_with_deleted_node(adj)
    else:
        start_idx = np.random.randint(adj.shape[0])
        x_idx = np.array(bfs_seq(nx.from_numpy_array(adj), start_idx))
        permuted_adj = adj[np.ix_(x_idx, x_idx)]
    return torch.from_numpy(np.ascontiguousarray(permuted_adj)).unsqueeze(0)


def _bfs_order_with_deleted_node(adj):
    """Shuffle ``adj``, drop one random node, BFS-order the rest, re-append the dropped node last."""
    num_nodes = adj.shape[0]
    shuffle = np.random.permutation(num_nodes)
    adj = adj[np.ix_(shuffle, shuffle)]

    deleted_idx = np.random.randint(num_nodes)
    # Edges from the removed node to every remaining node (its own slot dropped).
    deleted_edges = np.delete(adj[:, deleted_idx], deleted_idx)
    adj = np.delete(np.delete(adj, deleted_idx, axis=0), deleted_idx, axis=1)

    # BFS over the reduced graph, starting from its highest-degree node.
    start_idx = int(np.argmax(adj.sum(axis=0)))
    x_idx = np.array(bfs_seq(nx.from_numpy_array(adj), start_idx))
    adj = adj[np.ix_(x_idx, x_idx)]

    # Removed node's edges in the same BFS order, followed by 0 for its own diagonal
    # slot. Indexing by x_idx (rather than appending x_idx.size) keeps the diagonal
    # slot correct even when the BFS does not visit every node.
    deleted_edges = np.append(deleted_edges[x_idx], 0).astype(adj.dtype)
    size = adj.shape[0]
    result = np.zeros((size + 1, size + 1), dtype=adj.dtype)
    result[:size, :size] = adj
    result[:size, size] = deleted_edges[:-1]
    result[size, :] = deleted_edges
    return result


def get_threshold(adj_input, h_decode, max_num_nodes=None):
    """Return a probability cut-off that keeps as many predicted edges as ``adj_input`` has.

    ``h_decode`` (decoder logits) is passed through a sigmoid and scattered into the
    upper triangle (diagonal included) of a ``max_num_nodes`` x ``max_num_nodes``
    matrix. With E the number of nonzero strictly-upper-triangular entries of
    ``adj_input``, the returned value is the (E + 1)-th largest strictly-upper
    probability, so exactly the E largest probabilities lie strictly above it
    (ties aside).

    Args:
        adj_input: ground-truth adjacency tensor (squeezed before counting edges).
        h_decode: decoder logits, one per upper-triangular entry per batch element.
        max_num_nodes: matrix size. Defaults to the module-level ``model.max_num_nodes``,
            which only exists when this file runs as a script.
    """
    if max_num_nodes is None:
        max_num_nodes = model.max_num_nodes
    triu_mask = torch.triu(torch.ones(max_num_nodes, max_num_nodes)) == 1
    adj_output = torch.zeros(adj_input.size()[0], max_num_nodes, max_num_nodes)
    adj_output[:, triu_mask] = torch.sigmoid(h_decode).data.cpu()

    input_matrix = adj_input.cpu().data.squeeze().numpy()
    output_matrix = adj_output.data.squeeze().numpy()
    # k=1 keeps only the strictly upper triangle (self-loops excluded).
    number_of_edges = np.count_nonzero(np.triu(input_matrix, k=1))
    sorted_probs = np.sort(np.triu(output_matrix, k=1).ravel())
    return sorted_probs[-number_of_edges - 1]


def _gcn_encode(model, features, adj, final_activation=False):
    """Run the model's encoder: conv1 -> bn1 -> act -> conv2 -> bn2 (-> act if requested)."""
    x = model.act(model.bn1(model.conv1(features, adj)))
    x = model.bn2(model.conv2(x, adj))
    if final_activation:
        x = model.act(x)
    return x


def _drop_random_nodes(adj, number_of_missing_nodes):
    """Return a CPU copy of ``adj`` with that many distinct random nodes' rows and columns zeroed."""
    numpy_adj = adj.cpu().numpy().copy()
    graph_size = numpy_adj.shape[1]
    if number_of_missing_nodes > graph_size:
        raise ValueError("cannot remove %d nodes from a graph with %d nodes"
                         % (number_of_missing_nodes, graph_size))
    removed = set()
    for _ in range(number_of_missing_nodes):
        node = np.random.randint(graph_size)
        while node in removed:
            node = np.random.randint(graph_size)
        removed.add(node)
        numpy_adj[:, :, node] = 0
        numpy_adj[:, node, :] = 0
    return torch.from_numpy(numpy_adj)


def _fill_missing_columns(adj_input, sample, max_num_nodes, number_of_missing_nodes):
    """Combine the observed upper triangle of ``adj_input`` with sampled missing-node columns.

    The last ``number_of_missing_nodes`` columns are filled from ``sample``. Each
    column is filled from the diagonal upwards, consuming ``sample`` entries in order.
    Returns a CPU tensor of shape ``(batch, max_num_nodes, max_num_nodes)``.
    """
    triu_mask = torch.triu(torch.ones(max_num_nodes, max_num_nodes)) == 1
    y = torch.zeros(adj_input.size()[0], max_num_nodes, max_num_nodes)
    y[:, triu_mask] = adj_input.cpu()[:, triu_mask]
    sample = sample.cpu()
    position = 0
    for i in range(number_of_missing_nodes):
        column = max_num_nodes - i - 1
        for j in range(max_num_nodes - i):
            y[:, column - j, column] = sample[:, position]
            position += 1
    return y


def _debug_gran_forward(model, adj_input, input_features, max_num_nodes):
    """Debug-only GRAN pass: connect the last node to all nodes, encode, print ``model.gran`` output."""
    numpy_adj = adj_input.cpu().numpy().copy()
    numpy_adj[:, :, max_num_nodes - 1] = 1
    numpy_adj[:, max_num_nodes - 1, :] = 1
    incomplete_adj = torch.from_numpy(numpy_adj).cuda()
    x = _gcn_encode(model, input_features.cuda(), incomplete_adj, final_activation=True)
    print("*************************")
    network_result = model.gran(x).squeeze().cpu()
    print(network_result.size())
    x2 = network_result.unsqueeze(0)
    print(x2.size())
    x3 = x2.unsqueeze(0)
    print(x3.size())
    print(torch.sigmoid(network_result))


def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False,
         loader_args=None, checkpoint_epoch=100):
    """Load a GraphVAE checkpoint and return the mean GED between input and reconstructed graphs.

    Args:
        graphs: networkx graphs handed to ``GraphAdjSampler``.
        model: the model to evaluate; its ``max_num_nodes``, ``hidden_dim``, encoder
            layers, ``vae`` (and ``gran`` in GRAN mode) are used.
        args: reads ``model_save_path``, ``permutation_mode``, ``bfs_mode``,
            ``bfs_mode_with_arbitrary_node_deleted``, ``GRAN``,
            ``completion_mode_small_parameter_size`` and ``number_of_missing_nodes``.
        max_num_nodes: node count passed to ``GraphAdjSampler``.
        completion: outside small-parameter mode, zero out
            ``args.number_of_missing_nodes`` random nodes before encoding.
        graph_show_mode: draw the input/output graphs and print per-batch GEDs.
        loader_args: provides ``feature_type``, ``batch_size`` and ``num_workers``.
            Defaults to the module-level ``prog_args`` (only defined when run as a script).
        checkpoint_epoch: number embedded in the checkpoint file name.

    Returns:
        Mean GED over all batches, or ``None`` in the GRAN debug mode, which stops
        after the first batch.
    """
    if loader_args is None:
        loader_args = prog_args
    ged = gm.GraphEditDistance(1, 1, 1, 1)  # all edit costs are equal to 1
    fname = args.model_save_path + "GraphVAE" + str(checkpoint_epoch) + '.dat'
    model.load_state_dict(torch.load(fname))
    model.eval()

    dataset = GraphAdjSampler(graphs, max_num_nodes, args.permutation_mode, args.bfs_mode,
                              args.bfs_mode_with_arbitrary_node_deleted,
                              features=loader_args.feature_type)
    dataset_loader = torch.utils.data.DataLoader(
        dataset, batch_size=loader_args.batch_size, num_workers=loader_args.num_workers)

    num_nodes = model.max_num_nodes
    triu_mask = torch.triu(torch.ones(num_nodes, num_nodes)) == 1
    colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
    ged_list = []
    with torch.no_grad():  # evaluation only; no autograd graph is needed
        for data in dataset_loader:
            input_features = data['features'].float()
            adj_input = data['adj'].float()
            if args.permutation_mode:
                adj_input = matrix_permute(adj_input, args)

            if args.GRAN:
                # Debug-only path: inspects the first batch and exits without a GED.
                _debug_gran_forward(model, adj_input, input_features, max_num_nodes)
                return

            if args.completion_mode_small_parameter_size:
                # Encode only the first k nodes; the decoder predicts the columns of the
                # last ``number_of_missing_nodes`` nodes.
                k = num_nodes - args.number_of_missing_nodes
                incomplete_adj = adj_input[:, :k, :k].cuda()
                features = input_features[:, :k, :k].cuda()
                x = _gcn_encode(model, features, incomplete_adj)
                x = x.view(-1, k * model.hidden_dim)
            else:
                if completion:
                    adj_input = _drop_random_nodes(adj_input, args.number_of_missing_nodes)
                adj_input = adj_input.cuda()
                incomplete_adj = adj_input
                x = _gcn_encode(model, input_features.cuda(), adj_input)
                x = x.view(-1, num_nodes * model.hidden_dim)

            h_decode, z_mu, z_lsgms = model.vae(x)
            logits = h_decode[0].unsqueeze(0).unsqueeze(0)
            if args.completion_mode_small_parameter_size:
                sample = sample_sigmoid(logits, sample=True, thresh=0.5, sample_time=10).squeeze(0)
                output_matrix = _fill_missing_columns(
                    adj_input, sample, num_nodes, args.number_of_missing_nodes)[0].numpy()
            else:
                thresh = get_threshold(adj_input, h_decode, num_nodes)
                sample = sample_sigmoid(logits, sample=False, thresh=thresh, sample_time=10)
                y = torch.zeros(num_nodes, num_nodes)
                y[triu_mask] = sample.data.cpu().view(-1)
                output_matrix = y.numpy()

            # NOTE(review): in completion mode ``adj_input`` is the node-masked graph here,
            # so GED is measured against it rather than the unmasked input — confirm intent.
            input_graph = nx.from_numpy_array(adj_input[0].cpu().numpy())
            output_graph = nx.from_numpy_array(output_matrix)
            result = ged.compare([input_graph, output_graph], None)
            ged_value = result[0][1]  # presumably the input -> output distance
            ged_list.append(ged_value)
            if graph_show_mode:
                graph_show(nx.from_numpy_array(incomplete_adj[0].cpu().numpy()),
                           "input_graph", colors)
                graph_show(output_graph, "output_graph", colors)
                print("@@@@ ged: " + str(ged_value))

    if graph_show_mode:
        print(ged_list)
    print("**************** GED: " + str(np.mean(ged_list)))
    return np.mean(ged_list)


def create_graph():
    """Return every 12-node 2-D grid graph: 1x12, 2x6, 3x4, 4x3, 6x2 and 12x1."""
    grid_shapes = [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)]
    return [nx.grid_2d_graph(rows, cols) for rows, cols in grid_shapes]


if __name__ == '__main__':
    prog_args = arg_parse()
    my_args = GraphVAE_Args()
    graphs = create_graph()
    max_num_nodes = max(graph.number_of_nodes() for graph in graphs)
    model = build_model(prog_args, max_num_nodes).cuda()
    number_of_repeats = 1
    ged_list = []
    for i in range(number_of_repeats):
        print(i)
        ged_list.append(test(graphs, model, my_args, max_num_nodes, True, True,
                             loader_args=prog_args))
    print(" GED = " + str(np.mean(ged_list)) + " after " + str(number_of_repeats) + " repeats")