You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_t.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. from sklearn.metrics import mean_absolute_error
  5. import sys
  6. import torch.nn.functional as F
  7. from baselines.graphvae.graphvae_train import build_model
  8. from baselines.graphvae.graphvae_train import arg_parse
  9. from baselines.graphvae.args import GraphVAE_Args
  10. from baselines.graphvae.graphvae_model import GraphVAE
  11. from baselines.graphvae.graphvae_data import GraphAdjSampler
  12. from baselines.graphvae.graphvae_model import graph_show
  13. from torch.autograd import Variable
  14. from data import bfs_seq
  15. from model import sample_sigmoid
  16. import gmatch4py as gm
  17. def matrix_permute(adj, args):
  18. adj = adj.numpy()[0]
  19. if args.bfs_mode_with_arbitrary_node_deleted:
  20. x_idx = np.random.permutation(adj.shape[0])
  21. adj = adj[np.ix_(x_idx, x_idx)]
  22. adj = np.asmatrix(adj)
  23. random_idx_for_delete = np.random.randint(adj.shape[0])
  24. deleted_node = adj[:, random_idx_for_delete].copy()
  25. for i in range(deleted_node.__len__()):
  26. if i >= random_idx_for_delete and i < deleted_node.__len__() - 1:
  27. deleted_node[i] = deleted_node[i + 1]
  28. elif i == deleted_node.__len__() - 1:
  29. deleted_node[i] = 0
  30. adj[:, random_idx_for_delete:adj.shape[0] - 1] = adj[:, random_idx_for_delete + 1:adj.shape[0]]
  31. adj[random_idx_for_delete:adj.shape[0] - 1, :] = adj[random_idx_for_delete + 1:adj.shape[0], :]
  32. adj = np.delete(adj, -1, axis=1)
  33. adj = np.delete(adj, -1, axis=0)
  34. G = nx.from_numpy_matrix(adj)
  35. # then do bfs in the permuted G
  36. degree_arr = np.sum(adj, axis=0)
  37. start_idx = np.argmax(degree_arr)
  38. # start_idx = np.random.randint(adj.shape[0])
  39. x_idx = np.array(bfs_seq(G, start_idx))
  40. adj = adj[np.ix_(x_idx, x_idx)]
  41. x_idx = np.insert(x_idx, x_idx.size, x_idx.size)
  42. deleted_node = deleted_node[np.ix_(x_idx)]
  43. adj = np.append(adj, deleted_node[:-1], axis=1)
  44. deleted_node = deleted_node.reshape(1, -1)
  45. adj = np.vstack([adj, deleted_node])
  46. #
  47. permuted_adj = adj
  48. permuted_adj = np.asmatrix(permuted_adj)
  49. permuted_adj = torch.from_numpy(permuted_adj)
  50. permuted_adj = permuted_adj.unsqueeze(0)
  51. else:
  52. start_idx = np.random.randint(adj.shape[0])
  53. x_idx = np.array(bfs_seq(nx.from_numpy_array(adj), start_idx))
  54. permuted_adj = adj[np.ix_(x_idx, x_idx)]
  55. permuted_adj = np.asmatrix(permuted_adj)
  56. permuted_adj = torch.from_numpy(permuted_adj)
  57. permuted_adj = permuted_adj.unsqueeze(0)
  58. return permuted_adj
  59. def get_threshold(adj_input, h_decode):
  60. h_decode = F.sigmoid(h_decode)
  61. adj_output = torch.zeros(adj_input.size()[0], model.max_num_nodes, model.max_num_nodes)
  62. adj_output[:, torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = h_decode.data.cpu()
  63. adj_input_matrix = adj_input.cpu().data.squeeze().numpy()
  64. adj_input_triu_without_diag = adj_input_matrix - np.tril(adj_input_matrix)
  65. adj_output_matrix = adj_output.cpu().data.squeeze().numpy()
  66. adj_output_triu_without_diag = adj_output_matrix - np.tril(adj_output_matrix)
  67. number_of_edges = np.count_nonzero(adj_input_triu_without_diag)
  68. adj_output_triu_without_diag = adj_output_triu_without_diag.ravel()
  69. sorted_array = adj_output_triu_without_diag[np.argsort(adj_output_triu_without_diag)]
  70. threshold = sorted_array[-number_of_edges - 1]
  71. return threshold
  72. def test(graphs, model, args, max_num_nodes, completion=False, graph_show_mode=False):
  73. ged = gm.GraphEditDistance(1, 1, 1, 1) # all edit costs are equal to 1
  74. fname = args.model_save_path + "GraphVAE" + str(100) + '.dat'
  75. model.load_state_dict(torch.load(fname))
  76. model.eval()
  77. dataset = GraphAdjSampler(graphs, max_num_nodes, args.permutation_mode, args.bfs_mode,
  78. args.bfs_mode_with_arbitrary_node_deleted,
  79. features=prog_args.feature_type)
  80. dataset_loader = torch.utils.data.DataLoader(
  81. dataset,
  82. batch_size=prog_args.batch_size,
  83. num_workers=prog_args.num_workers)
  84. mae_list = []
  85. ged_list = []
  86. for batch_idx, data in enumerate(dataset_loader):
  87. input_features = data['features'].float()
  88. adj_input = data['adj'].float()
  89. if args.permutation_mode:
  90. adj_input = matrix_permute(adj_input, args)
  91. if my_args.GRAN:
  92. numpy_adj = adj_input.cpu().numpy()
  93. numpy_adj[:, :, max_num_nodes - 1] = 1
  94. numpy_adj[:, max_num_nodes - 1, :] = 1
  95. incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
  96. input_features = torch.tensor(input_features, device='cuda:0')
  97. x = model.conv1(input_features, incomplete_adj)
  98. x = model.bn1(x)
  99. x = model.act(x)
  100. x = model.conv2(x, incomplete_adj)
  101. x = model.bn2(x)
  102. x = model.act(x)
  103. print("*************************")
  104. network_result = model.gran(x).squeeze().cpu()
  105. print(network_result.size())
  106. x2 = network_result.unsqueeze(0)
  107. print(x2.size())
  108. x3 = x2.unsqueeze(0)
  109. print(x3.size())
  110. print(F.sigmoid(network_result))
  111. sample = sample_sigmoid(x3.cpu(), sample=True, thresh=0.5, sample_time=10)
  112. sample = sample.squeeze(0)
  113. return
  114. if my_args.completion_mode_small_parameter_size:
  115. numpy_adj = adj_input.cpu().numpy()
  116. number_of_incomplete_nodes = model.max_num_nodes - my_args.number_of_missing_nodes
  117. incomplete_adj = torch.tensor(
  118. numpy_adj[:, :number_of_incomplete_nodes, :number_of_incomplete_nodes],
  119. device='cuda:0')
  120. input_features = torch.tensor(
  121. input_features[:, :number_of_incomplete_nodes, :number_of_incomplete_nodes],
  122. device='cuda:0')
  123. x = model.conv1(input_features, incomplete_adj)
  124. x = model.bn1(x)
  125. x = model.act(x)
  126. x = model.conv2(x, incomplete_adj)
  127. x = model.bn2(x)
  128. x = x.view(-1, number_of_incomplete_nodes * model.hidden_dim)
  129. elif completion:
  130. graph_size = adj_input.size()[1]
  131. numpy_adj = adj_input.cpu().numpy()
  132. index_list = []
  133. for i in range(my_args.number_of_missing_nodes):
  134. random_index = np.random.randint(graph_size)
  135. while random_index in index_list:
  136. random_index = np.random.randint(graph_size)
  137. index_list.append(random_index)
  138. numpy_adj[:, :, random_index] = 0
  139. numpy_adj[:, random_index, :] = 0
  140. adj_input = torch.tensor(numpy_adj, device='cuda:0')
  141. input_features = Variable(input_features).cuda()
  142. adj_input = Variable(adj_input).cuda()
  143. x = model.conv1(input_features, adj_input)
  144. x = model.bn1(x)
  145. x = model.act(x)
  146. x = model.conv2(x, adj_input)
  147. x = model.bn2(x)
  148. x = x.view(-1, model.max_num_nodes * model.hidden_dim)
  149. h_decode, z_mu, z_lsgms = model.vae(x)
  150. x2 = h_decode[0].unsqueeze(0)
  151. x3 = x2.unsqueeze(0)
  152. if my_args.completion_mode_small_parameter_size: # temporary
  153. sample = sample_sigmoid(x3, sample=True, thresh=0.5, sample_time=10)
  154. sample = sample.squeeze(0)
  155. else:
  156. thresh = get_threshold(adj_input, h_decode)
  157. sample = sample_sigmoid(x3, sample=False, thresh=thresh, sample_time=10)
  158. if my_args.completion_mode_small_parameter_size:
  159. y = torch.zeros(adj_input.size()[0], model.max_num_nodes, model.max_num_nodes)
  160. y[:, torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = \
  161. adj_input[:, torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1]
  162. adj_vectorized_index = 0
  163. for i in range(my_args.number_of_missing_nodes): # iterates over columns
  164. for j in range(model.max_num_nodes - i):
  165. # print("********************************")
  166. # print(y[:, model.max_num_nodes - j - i - 1,
  167. # model.max_num_nodes - i - 1].size())
  168. # print(sample[:, i + j].size())
  169. # print(sample.size())
  170. y[:, model.max_num_nodes - j - i - 1,
  171. model.max_num_nodes - i - 1] = sample[:, adj_vectorized_index]
  172. adj_vectorized_index += 1
  173. else:
  174. y = torch.zeros(model.max_num_nodes, model.max_num_nodes)
  175. y[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1] = sample.data.cpu()
  176. colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
  177. input_graph = nx.from_numpy_matrix(adj_input[0].data.cpu().numpy())
  178. largest_connected_component_input = max(nx.connected_component_subgraphs(input_graph), key=len)
  179. # mae = mean_absolute_error(np.triu(adj_input[0].data.cpu().numpy()), np.triu(y.data.cpu().numpy()))
  180. # print("@@@@ mae: " + str(mae))
  181. # mae_list.append(mae)
  182. output_graph = nx.from_numpy_matrix(y.data.cpu().numpy()[0])
  183. largest_connected_component_output = max(nx.connected_component_subgraphs(output_graph), key=len)
  184. result = ged.compare([input_graph, output_graph], None)
  185. ged_value = result[0][1]
  186. ged_list.append(ged_value)
  187. if graph_show_mode:
  188. # graph_show(input_graph, "input_graph", colors)
  189. graph_show(nx.from_numpy_matrix(incomplete_adj[0].data.cpu().numpy()), "input_graph", colors)
  190. graph_show(output_graph, "output_graph", colors)
  191. print("@@@@ ged: " + str(ged_value))
  192. if graph_show_mode:
  193. print(ged_list)
  194. print("**************** GED: " + str(np.mean(ged_list)))
  195. return np.mean(ged_list)
  196. def create_graph():
  197. graphs = []
  198. # for i in range(2, 6):
  199. # for j in range(2, 6):
  200. # graphs.append(nx.grid_2d_graph(i, j))
  201. # graphs.append(nx.grid_2d_graph(2, 3))
  202. graphs.append(nx.grid_2d_graph(1, 12))
  203. graphs.append(nx.grid_2d_graph(2, 6))
  204. graphs.append(nx.grid_2d_graph(3, 4))
  205. graphs.append(nx.grid_2d_graph(4, 3))
  206. graphs.append(nx.grid_2d_graph(6, 2))
  207. graphs.append(nx.grid_2d_graph(12, 1))
  208. return graphs
  209. if __name__ == '__main__':
  210. prog_args = arg_parse()
  211. my_args = GraphVAE_Args()
  212. graphs = create_graph()
  213. max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  214. model = build_model(prog_args, max_num_nodes).cuda()
  215. number_of_repeats = 1
  216. ged_list = []
  217. for i in range(number_of_repeats):
  218. print(i)
  219. ged_list.append(test(graphs, model, my_args, max_num_nodes, True, True))
  220. print(" GED = " + str(np.mean(ged_list)) + " after " + str(number_of_repeats) + " repeats")