You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

graphvae_t.py 6.2KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. from sklearn.metrics import mean_absolute_error
  5. import sys
  6. import torch.nn.functional as F
  7. from baselines.graphvae.graphvae_train import build_model
  8. from baselines.graphvae.graphvae_train import arg_parse
  9. from baselines.graphvae.args import GraphVAE_Args
  10. from baselines.graphvae.graphvae_model import GraphVAE
  11. from baselines.graphvae.graphvae_data import GraphAdjSampler
  12. from baselines.graphvae.graphvae_model import graph_show
  13. from torch.autograd import Variable
  14. from data import bfs_seq
  15. from model import sample_sigmoid
  16. import gmatch4py as gm
  17. def matrix_permute(adj):
  18. adj = adj.numpy()[0]
  19. # x_idx = np.random.permutation(adj.shape[0])
  20. # permuted_adj = adj[np.ix_(x_idx, x_idx)]
  21. start_idx = np.random.randint(adj.shape[0])
  22. x_idx = np.array(bfs_seq(nx.from_numpy_array(adj), start_idx))
  23. permuted_adj = adj[np.ix_(x_idx, x_idx)]
  24. permuted_adj = np.asmatrix(permuted_adj)
  25. permuted_adj = torch.from_numpy(permuted_adj)
  26. permuted_adj = permuted_adj.unsqueeze(0)
  27. return permuted_adj
  28. def get_threshold(adj_input, h_decode):
  29. h_decode = F.sigmoid(h_decode)
  30. adj_output = torch.zeros(model.max_num_nodes, model.max_num_nodes)
  31. adj_output[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = h_decode.data.cpu()
  32. adj_input_matrix = adj_input.cpu().data.squeeze().numpy()
  33. adj_input_triu_without_diag = adj_input_matrix - np.tril(adj_input_matrix)
  34. adj_output_matrix = adj_output.cpu().data.squeeze().numpy()
  35. adj_output_triu_without_diag = adj_output_matrix - np.tril(adj_output_matrix)
  36. number_of_edges = np.count_nonzero(adj_input_triu_without_diag)
  37. adj_output_triu_without_diag = adj_output_triu_without_diag.ravel()
  38. sorted_array = adj_output_triu_without_diag[np.argsort(adj_output_triu_without_diag)]
  39. threshold = sorted_array[-number_of_edges-1]
  40. return threshold
  41. def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False):
  42. ged = gm.GraphEditDistance(1, 1, 1, 1) # all edit costs are equal to 1
  43. fname = args.model_save_path + "GraphVAE" + str(100) + '.dat'
  44. model.load_state_dict(torch.load(fname))
  45. model.eval()
  46. dataset = GraphAdjSampler(graphs, max_num_nodes,args.permutation_mode, args.bfs_mode, features=prog_args.feature_type)
  47. dataset_loader = torch.utils.data.DataLoader(
  48. dataset,
  49. batch_size=prog_args.batch_size,
  50. num_workers=prog_args.num_workers)
  51. mae_list = []
  52. ged_list = []
  53. for batch_idx, data in enumerate(dataset_loader):
  54. features = data['features'].float()
  55. adj_input = data['adj'].float()
  56. if args.permutation_mode:
  57. # print("**** Before")
  58. # print(adj_input)
  59. adj_input = matrix_permute(adj_input)
  60. # print("**** After")
  61. # print(adj_input)
  62. if completion:
  63. graph_size = adj_input.size()[1]
  64. numpy_adj = adj_input.cpu().numpy()[0]
  65. index_list = []
  66. for i in range(my_args.number_of_missing_nodes):
  67. random_index = np.random.randint(graph_size)
  68. while random_index in index_list:
  69. random_index = np.random.randint(graph_size)
  70. index_list.append(random_index)
  71. numpy_adj[:, random_index] = 0
  72. numpy_adj[random_index, :] = 0
  73. adj_input = torch.tensor(numpy_adj, device='cuda:0').unsqueeze(0)
  74. features = Variable(features).cuda()
  75. adj_input = Variable(adj_input).cuda()
  76. x = model.conv1(features, adj_input)
  77. x = model.bn1(x)
  78. x = model.act(x)
  79. x = model.conv2(x, adj_input)
  80. x = model.bn2(x)
  81. x = x.view(-1, model.max_num_nodes * model.hidden_dim)
  82. h_decode, z_mu, z_lsgms = model.vae(x)
  83. x2 = h_decode[0].unsqueeze(0)
  84. x3 = x2.unsqueeze(0)
  85. thresh = get_threshold(adj_input, h_decode)
  86. sample = sample_sigmoid(x3, sample=False, thresh=thresh, sample_time=10)
  87. y = torch.zeros(model.max_num_nodes, model.max_num_nodes)
  88. y[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = sample.data.cpu()
  89. colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
  90. input_graph = nx.from_numpy_matrix(adj_input[0].data.cpu().numpy())
  91. largest_connected_component_input = max(nx.connected_component_subgraphs(input_graph), key=len)
  92. mae = mean_absolute_error(np.triu(adj_input[0].data.cpu().numpy()), np.triu(y.data.cpu().numpy()))
  93. # print("@@@@ mae: " + str(mae))
  94. mae_list.append(mae)
  95. output_graph = nx.from_numpy_matrix(y.data.cpu().numpy())
  96. largest_connected_component_output = max(nx.connected_component_subgraphs(output_graph), key=len)
  97. result = ged.compare([input_graph, output_graph], None)
  98. ged_value = result[0][1]
  99. ged_list.append(ged_value)
  100. if graph_show_mode:
  101. graph_show(input_graph, "input_graph", colors)
  102. graph_show(output_graph, "output_graph", colors)
  103. print("@@@@ ged: " + str(ged_value))
  104. # break
  105. # print(result)
  106. if graph_show_mode:
  107. print(ged_list)
  108. print("**************** GED: " + str(np.mean(ged_list)))
  109. return np.mean(ged_list)
  110. def create_graph():
  111. graphs = []
  112. # for i in range(2, 6):
  113. # for j in range(2, 6):
  114. # graphs.append(nx.grid_2d_graph(i, j))
  115. # graphs.append(nx.grid_2d_graph(2, 3))
  116. graphs.append(nx.grid_2d_graph(1, 12))
  117. graphs.append(nx.grid_2d_graph(2, 6))
  118. graphs.append(nx.grid_2d_graph(3, 4))
  119. graphs.append(nx.grid_2d_graph(4, 3))
  120. graphs.append(nx.grid_2d_graph(6, 2))
  121. graphs.append(nx.grid_2d_graph(12, 1))
  122. return graphs
  123. if __name__ == '__main__':
  124. prog_args = arg_parse()
  125. my_args = GraphVAE_Args()
  126. graphs = create_graph()
  127. max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  128. model = build_model(prog_args, max_num_nodes).cuda()
  129. number_of_repeats = 100
  130. ged_list = []
  131. for i in range(number_of_repeats):
  132. ged_list.append(test(graphs, model, my_args, max_num_nodes, True, False))
  133. print(" GED = " + str(np.mean(ged_list)) + " after " + str(number_of_repeats) + " repeats")