|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- import networkx as nx
- import numpy as np
- import torch
- from sklearn.metrics import mean_absolute_error
- import sys
- import torch.nn.functional as F
-
- from baselines.graphvae.graphvae_train import build_model
- from baselines.graphvae.graphvae_train import arg_parse
- from baselines.graphvae.args import GraphVAE_Args
- from baselines.graphvae.graphvae_model import GraphVAE
- from baselines.graphvae.graphvae_data import GraphAdjSampler
- from baselines.graphvae.graphvae_model import graph_show
- from torch.autograd import Variable
-
- from data import bfs_seq
- from model import sample_sigmoid
- import gmatch4py as gm
-
-
def _full_bfs_order(G, start_idx, num_nodes):
    """Return bfs_seq(G, start_idx) extended to a full permutation of range(num_nodes).

    If bfs_seq does not reach every node (e.g. G is disconnected), the missing
    nodes are appended afterwards in index order, so indexing an n x n matrix
    with the result keeps it n x n.
    """
    order = [int(node) for node in bfs_seq(G, start_idx)]
    visited = set(order)
    order.extend(node for node in range(num_nodes) if node not in visited)
    return np.array(order, dtype=np.intp)


def matrix_permute(adj, args):
    """Reorder the first adjacency matrix of a batch into BFS node order.

    Only ``adj.numpy()[0]`` (the first graph of the batch) is used.

    Args:
        adj: torch tensor whose ``.numpy()[0]`` is a square adjacency matrix.
        args: options object; if ``args.bfs_mode_with_arbitrary_node_deleted``
            is truthy, the nodes are randomly relabelled, one random node is
            removed, the remaining graph is BFS-ordered starting from its
            largest column-sum (degree) node, and the removed node is
            re-attached as the last row/column (its diagonal entry set to 0).
            Otherwise the graph is BFS-ordered from a random start node.

    Returns:
        torch tensor of shape (1, n, n) holding the permuted matrix.
    """
    adj = adj.numpy()[0]
    num_nodes = adj.shape[0]
    if args.bfs_mode_with_arbitrary_node_deleted:
        # Random relabelling first, so the removed node is an arbitrary one.
        x_idx = np.random.permutation(num_nodes)
        adj = adj[np.ix_(x_idx, x_idx)]
        deleted_idx = np.random.randint(num_nodes)
        # Connections of the removed node to the remaining nodes (in their
        # current order); the final slot is its own (zeroed) diagonal entry.
        deleted_column = adj[:, deleted_idx]
        deleted_node = np.zeros_like(deleted_column)
        deleted_node[:-1] = np.delete(deleted_column, deleted_idx)
        # Drop the removed node's row and column.
        adj = np.delete(np.delete(adj, deleted_idx, axis=0), deleted_idx, axis=1)
        G = nx.from_numpy_array(adj)
        # BFS of the remaining graph, starting at the highest-degree node.
        start_idx = int(np.argmax(adj.sum(axis=0)))
        x_idx = _full_bfs_order(G, start_idx, num_nodes - 1)
        adj = adj[np.ix_(x_idx, x_idx)]
        # Apply the same order to the removed node's connections; its own
        # entry (index num_nodes - 1) stays last.
        deleted_node = deleted_node[np.append(x_idx, num_nodes - 1)]
        # Re-attach the removed node as the last column and last row.
        adj = np.vstack([np.hstack([adj, deleted_node[:-1, None]]),
                         deleted_node[None, :]])
    else:
        start_idx = np.random.randint(num_nodes)
        x_idx = _full_bfs_order(nx.from_numpy_array(adj), start_idx, num_nodes)
        adj = adj[np.ix_(x_idx, x_idx)]
    return torch.from_numpy(adj).unsqueeze(0)
-
-
def get_threshold(adj_input, h_decode, max_num_nodes=None):
    """Choose a probability cutoff that keeps as many decoded edges as the input has.

    ``h_decode`` holds logits for the upper triangle (diagonal included) of an
    n x n adjacency matrix, in row-major order. After a sigmoid they are
    scattered into that triangle. The number of off-diagonal upper-triangle
    non-zeros of ``adj_input`` (its edge count, E) is counted. The returned
    value is the (E + 1)-th largest entry of the decoded strict upper
    triangle, with all other positions counted as 0. The decoded scores
    strictly greater than it are therefore the top E scores, assuming no ties.

    Args:
        adj_input: torch tensor of shape (batch, n, n); squeezed before counting.
        h_decode: torch tensor of logits with n * (n + 1) / 2 entries per batch row.
        max_num_nodes: n. If omitted it is derived from the length of
            ``h_decode`` (n * (n + 1) / 2 == length), instead of reading a
            global model object.

    Returns:
        The cutoff as a numpy scalar.
    """
    probs = torch.sigmoid(h_decode).data.cpu()
    if max_num_nodes is None:
        tri_len = probs.size(-1)
        max_num_nodes = int(round((np.sqrt(8 * tri_len + 1) - 1) / 2))
    triu_mask = torch.triu(torch.ones(max_num_nodes, max_num_nodes)) == 1
    adj_output = torch.zeros(adj_input.size()[0], max_num_nodes, max_num_nodes)
    adj_output[:, triu_mask] = probs
    # Strict upper triangles: each undirected edge counted once, no diagonal.
    input_upper = np.triu(adj_input.cpu().data.squeeze().numpy(), k=1)
    output_upper = np.triu(adj_output.squeeze().numpy(), k=1)
    number_of_edges = np.count_nonzero(input_upper)
    sorted_scores = np.sort(output_upper.ravel())
    return sorted_scores[-number_of_edges - 1]
-
-
def _adj_to_graph(adj_matrix):
    """Build an undirected networkx graph from an adjacency tensor or array.

    A 3-D input is treated as batched and only its first matrix is used.
    """
    if torch.is_tensor(adj_matrix):
        adj_matrix = adj_matrix.data.cpu().numpy()
    adj_matrix = np.asarray(adj_matrix)
    if adj_matrix.ndim == 3:
        adj_matrix = adj_matrix[0]
    return nx.from_numpy_array(adj_matrix)


def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False):
    """Evaluate a trained GraphVAE on graph completion by graph edit distance (GED).

    Loads the checkpoint ``args.model_save_path + "GraphVAE100.dat"`` into
    ``model``, encodes and decodes every graph served by ``GraphAdjSampler``,
    and compares the decoded graph with the encoder's input graph using
    gmatch4py's GraphEditDistance with all edit costs equal to 1.

    NOTE(review): also reads the module-level globals ``prog_args`` (feature
    type, batch size, workers) and ``my_args`` (GRAN / completion switches),
    which are only defined when this file is run as a script.

    Args:
        graphs: graphs passed to ``GraphAdjSampler``.
        model: module providing ``conv1``, ``bn1``, ``act``, ``conv2``,
            ``bn2``, ``vae`` (and ``gran`` in GRAN mode), plus
            ``max_num_nodes`` and ``hidden_dim``.
        args: options object with ``model_save_path``, ``permutation_mode``,
            ``bfs_mode`` and ``bfs_mode_with_arbitrary_node_deleted``.
        max_num_nodes: size each adjacency matrix is padded to by the sampler.
        completion: zero the rows/columns of
            ``my_args.number_of_missing_nodes`` distinct random nodes before
            encoding.
        graph_show_mode: draw the encoder input and decoded graphs and print
            the per-graph GED list.

    Returns:
        Mean GED over all evaluated graphs. Returns None if ``my_args.GRAN``
        is set (that experimental branch stops after the first batch).
    """
    ged = gm.GraphEditDistance(1, 1, 1, 1)  # all edit costs are equal to 1
    fname = args.model_save_path + "GraphVAE" + str(100) + '.dat'
    model.load_state_dict(torch.load(fname))
    model.eval()
    dataset = GraphAdjSampler(graphs, max_num_nodes, args.permutation_mode, args.bfs_mode,
                              args.bfs_mode_with_arbitrary_node_deleted,
                              features=prog_args.feature_type)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers)
    colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
    # Upper triangle (diagonal included) of a max_num_nodes x max_num_nodes matrix.
    triu_mask = torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1
    ged_list = []
    for data in dataset_loader:
        input_features = data['features'].float()
        adj_input = data['adj'].float()
        # Adjacency actually fed to the encoder when it differs from adj_input;
        # drawn as "input_graph" in graph_show_mode.
        incomplete_adj = None
        if args.permutation_mode:
            adj_input = matrix_permute(adj_input, args)
        if my_args.GRAN:
            # NOTE(review): experimental/debug path - prints intermediate
            # shapes and returns None after the first batch without a GED.
            numpy_adj = adj_input.cpu().numpy()
            numpy_adj[:, :, max_num_nodes - 1] = 1
            numpy_adj[:, max_num_nodes - 1, :] = 1
            incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
            input_features = input_features.to('cuda:0')
            x = model.conv1(input_features, incomplete_adj)
            x = model.bn1(x)
            x = model.act(x)
            x = model.conv2(x, incomplete_adj)
            x = model.bn2(x)
            x = model.act(x)
            print("*************************")
            network_result = model.gran(x).squeeze().cpu()
            print(network_result.size())
            x2 = network_result.unsqueeze(0)
            print(x2.size())
            x3 = x2.unsqueeze(0)
            print(x3.size())
            print(torch.sigmoid(network_result))
            sample = sample_sigmoid(x3.cpu(), sample=True, thresh=0.5, sample_time=10)
            sample = sample.squeeze(0)
            return

        if my_args.completion_mode_small_parameter_size:
            # Encode only the leading (max_num_nodes - number_of_missing_nodes) nodes.
            numpy_adj = adj_input.cpu().numpy()
            number_of_incomplete_nodes = model.max_num_nodes - my_args.number_of_missing_nodes
            incomplete_adj = torch.tensor(
                numpy_adj[:, :number_of_incomplete_nodes, :number_of_incomplete_nodes],
                device='cuda:0')
            # NOTE(review): the features are also cut on their last axis; this
            # assumes one feature column per node (e.g. identity features) - confirm.
            input_features = input_features[:, :number_of_incomplete_nodes,
                                            :number_of_incomplete_nodes].to('cuda:0')
            x = model.conv1(input_features, incomplete_adj)
            x = model.bn1(x)
            x = model.act(x)
            x = model.conv2(x, incomplete_adj)
            x = model.bn2(x)
            x = x.view(-1, number_of_incomplete_nodes * model.hidden_dim)
        else:
            if completion:
                # Remove distinct random nodes by zeroing their rows and columns.
                # NOTE(review): adj_input is replaced by this node-removed matrix,
                # so the GED below is measured against it rather than against the
                # original graph - confirm that is intended.
                graph_size = adj_input.size()[1]
                numpy_adj = adj_input.cpu().numpy()
                index_list = []
                for _ in range(my_args.number_of_missing_nodes):
                    random_index = np.random.randint(graph_size)
                    while random_index in index_list:
                        random_index = np.random.randint(graph_size)
                    index_list.append(random_index)
                    numpy_adj[:, :, random_index] = 0
                    numpy_adj[:, random_index, :] = 0
                adj_input = torch.tensor(numpy_adj, device='cuda:0')
                incomplete_adj = adj_input
            input_features = Variable(input_features).cuda()
            adj_input = Variable(adj_input).cuda()
            x = model.conv1(input_features, adj_input)
            x = model.bn1(x)
            x = model.act(x)
            x = model.conv2(x, adj_input)
            x = model.bn2(x)
            x = x.view(-1, model.max_num_nodes * model.hidden_dim)

        h_decode, _, _ = model.vae(x)
        # Only the first graph of the batch is decoded.
        x3 = h_decode[0].unsqueeze(0).unsqueeze(0)
        if my_args.completion_mode_small_parameter_size:  # temporary
            sample = sample_sigmoid(x3, sample=True, thresh=0.5, sample_time=10)
            sample = sample.squeeze(0)
            # Start from the known upper triangle, then overwrite the last
            # number_of_missing_nodes columns: column (n - 1 - i) is filled
            # from its diagonal upward to row 0 with consecutive decoded values.
            y = torch.zeros(adj_input.size()[0], model.max_num_nodes, model.max_num_nodes)
            y[:, triu_mask] = adj_input[:, triu_mask]
            adj_vectorized_index = 0
            for i in range(my_args.number_of_missing_nodes):  # iterates over columns
                for j in range(model.max_num_nodes - i):
                    y[:, model.max_num_nodes - j - i - 1,
                      model.max_num_nodes - i - 1] = sample[:, adj_vectorized_index]
                    adj_vectorized_index += 1
        else:
            # Keep as many decoded edges as the encoder input has.
            thresh = get_threshold(adj_input, h_decode)
            sample = sample_sigmoid(x3, sample=False, thresh=thresh, sample_time=10)
            y = torch.zeros(model.max_num_nodes, model.max_num_nodes)
            y[triu_mask] = sample.data.cpu()

        input_graph = _adj_to_graph(adj_input)
        # y is 3-D in small-parameter mode and 2-D otherwise; the helper handles both.
        output_graph = _adj_to_graph(y)
        result = ged.compare([input_graph, output_graph], None)
        ged_value = result[0][1]
        ged_list.append(ged_value)
        if graph_show_mode:
            shown_adj = incomplete_adj if incomplete_adj is not None else adj_input
            graph_show(_adj_to_graph(shown_adj), "input_graph", colors)
            graph_show(output_graph, "output_graph", colors)
        print("@@@@ ged: " + str(ged_value))
    if graph_show_mode:
        print(ged_list)
    mean_ged = np.mean(ged_list)
    print("**************** GED: " + str(mean_ged))
    return mean_ged
-
-
def create_graph():
    """Return every 12-node 2-D grid graph, ordered from 1 x 12 to 12 x 1.

    The row counts are the divisors of 12; each grid has 12 // rows columns.
    """
    row_counts = (1, 2, 3, 4, 6, 12)
    return [nx.grid_2d_graph(rows, 12 // rows) for rows in row_counts]
-
-
if __name__ == '__main__':
    # prog_args, my_args and model stay module-level: other functions in this
    # file read them as globals.
    prog_args = arg_parse()
    my_args = GraphVAE_Args()
    graphs = create_graph()
    # Every adjacency matrix is padded to the largest graph's node count.
    max_num_nodes = max(g.number_of_nodes() for g in graphs)
    model = build_model(prog_args, max_num_nodes).cuda()
    number_of_repeats = 1
    repeat_geds = []
    for repeat in range(number_of_repeats):
        print(repeat)
        mean_ged = test(graphs, model, my_args, max_num_nodes,
                        completion=True, graph_show_mode=True)
        repeat_geds.append(mean_ged)
    print(" GED = " + str(np.mean(repeat_geds)) + " after " + str(number_of_repeats) + " repeats")
|