You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

VAE_evaluate.py 12KB


  1. from statistics import mean
  2. from main_baselines.graphvae.Main_VAE_Args import Main_VAE_Args
  3. from train import *
  4. from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
  5. from baselines.graphvae.graphvae_train import graph_statistics
  6. from baselines.graphvae.util import *
  7. from main_baselines.graphvae.graphvae_train import build_model
  8. from main_baselines.graphvae.graphvae_train import arg_parse
  9. def upper_triangular_to_symmetric_converter(A):
  10. for i in range(A.shape[0]):
  11. for j in range(A.shape[1]):
  12. A[j, i] = A[i, j]
  13. return A
  14. def vector_to_matrix_converter(input_vector, max_num_nodes):
  15. # first it converts input vector to an upper triangular matrix, then the result will be casted to a symmetric matrix
  16. tri = np.zeros((max_num_nodes, max_num_nodes))
  17. tri[np.triu_indices(max_num_nodes)] = input_vector
  18. return upper_triangular_to_symmetric_converter(tri)
  19. def getitem(graph, max_num_nodes):
  20. adj = nx.to_numpy_matrix(graph)
  21. num_nodes = adj.shape[0]
  22. adj_padded = np.zeros((max_num_nodes, max_num_nodes))
  23. adj_padded[:num_nodes, :num_nodes] = adj
  24. return {'adj': adj_padded,
  25. 'num_nodes': num_nodes,
  26. 'features': np.identity(max_num_nodes)}
  27. def test(model, input_features, adj):
  28. # cpu_numpy = adj.squeeze(0).cpu().numpy()
  29. # print("********** 0)")
  30. # print(cpu_numpy)
  31. #
  32. # graph_show(nx.from_numpy_matrix(cpu_numpy), "1")
  33. x = model.conv1(input_features, adj)
  34. x = model.act(x)
  35. x = model.bn1(x)
  36. x = model.conv2(x, adj)
  37. # x = self.bn2(x)
  38. # pool over all nodes
  39. # graph_h = self.pool_graph(x)
  40. graph_h = x.view(-1, model.max_num_nodes * model.hidden_dim)
  41. # vaemax_num_nodes
  42. h_decode, z_mu, z_lsgms = model.vae(graph_h)
  43. out = F.sigmoid(h_decode)
  44. out_tensor = out.cpu().data
  45. print("*** h_decode")
  46. print(h_decode)
  47. print("*** out_tensor")
  48. print(out_tensor)
  49. recon_adj_lower = model.recover_adj_lower(out_tensor)
  50. recon_adj_tensor = model.recover_full_adj_from_lower(recon_adj_lower)
  51. # set matching features be degree
  52. out_features = torch.sum(recon_adj_tensor, 1)
  53. adj_data = adj.cpu().data
  54. adj_features = torch.sum(adj_data, 1)
  55. S = model.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
  56. model.deg_feature_similarity)
  57. # initialization strategies
  58. init_corr = 1 / model.max_num_nodes
  59. init_assignment = torch.ones(model.max_num_nodes, model.max_num_nodes) * init_corr
  60. assignment = model.mpm(init_assignment, S)
  61. # matching
  62. # use negative of the assignment score since the alg finds min cost flow
  63. row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
  64. # order row index according to col index
  65. adj_permuted = model.permute_adj(adj_data, row_ind, col_ind)
  66. adj_vectorized = adj_permuted[torch.triu(torch.ones(model.max_num_nodes, model.max_num_nodes)) == 1].squeeze_()
  67. adj_vectorized_var = Variable(adj_vectorized).cuda()
  68. adj_recon_loss = model.adj_recon_loss(adj_vectorized_var, out[0])
  69. loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
  70. loss_kl /= model.max_num_nodes * model.max_num_nodes # normalize
  71. # print('kl: ', loss_kl)
  72. loss = adj_recon_loss + loss_kl
  73. return loss
  74. def VAE_test(model, input_features, adj_with_missing_node, adj_input):
  75. # 2 layer GCN
  76. model.eval()
  77. x = model.conv1(input_features, adj_with_missing_node)
  78. x = model.act(x)
  79. x = model.bn1(x)
  80. x = model.conv2(x, adj_with_missing_node)
  81. graph_h = x.view(-1, model.max_num_nodes * model.hidden_dim)
  82. # vae
  83. h_decode, z_mu, z_lsgms = model.vae(graph_h)
  84. predicted_vector = sample_sigmoid(h_decode.unsqueeze(0).unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
  85. predicted_matrix = vector_to_matrix_converter(predicted_vector.cpu(), model.max_num_nodes)
  86. out = torch.sigmoid(h_decode)
  87. out_tensor = out.cpu().data
  88. recon_adj_lower = model.recover_adj_lower(out_tensor)
  89. recon_adj_tensor = model.recover_full_adj_from_lower(recon_adj_lower)
  90. # print("*** h_decode")
  91. # print(h_decode)
  92. # print("*** out")
  93. # print(out)
  94. # print("*** out_tensor")
  95. # print(out_tensor)
  96. # print("*** recon_adj_tensor")
  97. # print(recon_adj_tensor)
  98. # print("*******************************************")
  99. # set matching features be degree
  100. out_features = torch.sum(recon_adj_tensor, 1)
  101. # adj_data = adj_with_missing_node.cpu().data
  102. adj_data = torch.from_numpy(adj_input)
  103. # print(adj_with_missing_node.size())
  104. # print(adj_data.size())
  105. adj_features = torch.sum(adj_data, 1)
  106. # print("*** adj_data")
  107. # print(adj_data)
  108. # print("*** recon_adj_tensor")
  109. # print(recon_adj_tensor)
  110. # print("*** adj_features")
  111. # print(adj_features)
  112. # print("*** out_features")
  113. # print(out_features)
  114. # print("*******************************************")
  115. S = model.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
  116. model.deg_feature_similarity)
  117. # print(S)
  118. # initialization strategies
  119. init_corr = 1 / model.max_num_nodes
  120. init_assignment = torch.ones(model.max_num_nodes, model.max_num_nodes) * init_corr
  121. assignment = model.mpm(init_assignment, S)
  122. # matching
  123. # use negative of the assignment score since the alg finds min cost flow
  124. # print(assignment)
  125. try:
  126. row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
  127. except ValueError:
  128. print("Oops! That was no valid number. Try again...")
  129. return predicted_matrix[2, :], predicted_matrix[2, :]
  130. # order row index according to col index
  131. adj_permuted = model.permute_adj(adj_data, row_ind, col_ind)
  132. main_adj_permuted = model.permute_adj(torch.from_numpy(adj_input).cuda().float(), row_ind, col_ind)
  133. # print("%%%%%%%%%%%%%%%")
  134. # print(adj_permuted)
  135. missing_node_index = np.where(~adj_permuted.numpy().any(axis=1))[0][0]
  136. # print(missing_node_index)
  137. # print(main_adj_permuted)
  138. return main_adj_permuted[missing_node_index, :].numpy(), predicted_matrix[missing_node_index, :]
  139. def evaluate(model, test_graph_list):
  140. model.eval()
  141. mae_list = []
  142. roc_score_list = []
  143. ap_score_list = []
  144. precision_list = []
  145. recall_list = []
  146. for graph in test_graph_list:
  147. data = getitem(graph, model.max_num_nodes)
  148. input_features = data['features']
  149. input_features = torch.tensor(input_features, device='cuda:0').unsqueeze(0).float()
  150. adj_input = data['adj']
  151. num_nodes = data['num_nodes']
  152. # print("#######################")
  153. # print(adj_input)
  154. # remove random node from adj_input
  155. adj_with_missing_node = adj_input.copy()
  156. random_index = np.random.randint(num_nodes)
  157. adj_with_missing_node[:, random_index] = 0
  158. adj_with_missing_node[random_index, :] = 0
  159. test(model, input_features, torch.from_numpy(adj_input).cuda().float())
  160. # label, result = VAE_test(model, input_features, torch.from_numpy(adj_with_missing_node).cuda().float(),
  161. # adj_input)
  162. # sample = sample_sigmoid(preds.unsqueeze(0).unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
  163. # sample = sample.squeeze()
  164. mae = mean_absolute_error(label, result)
  165. mae_list.append(mae)
  166. print("**** label")
  167. print(label)
  168. print("**** result")
  169. print(result)
  170. # label = gran_lable.numpy()
  171. # result = sample.cpu().numpy()
  172. # print("*** label")
  173. # print(label)
  174. # print("*** result")
  175. # print(result)
  176. # print(sum(result))
  177. # print(num_nodes)
  178. part1 = label[result == 1]
  179. part2 = part1[part1 == 1]
  180. part3 = part1[part1 == 0]
  181. part4 = label[result == 0]
  182. part5 = part4[part4 == 1]
  183. tp = len(part2)
  184. fp = len(part3)
  185. fn = part5.sum()
  186. if tp + fp > 0:
  187. precision = tp / (tp + fp)
  188. else:
  189. precision = 0
  190. recall = tp / (tp + fn)
  191. # F_Measure = 2 * precision * recall / (precision + recall)
  192. precision_list.append(precision)
  193. recall_list.append(recall)
  194. # F_Measure_list.append(F_Measure)
  195. # tp_div_pos = part2.sum() / label.sum()
  196. # tp_div_pos_list.append(tp_div_pos)
  197. # result = preds.cpu().detach().numpy() ??????????????????
  198. positive = result[label == 1]
  199. if len(positive) <= len(list(result[label == 0])):
  200. negative = random.sample(list(result[label == 0]), len(positive))
  201. else:
  202. negative = result[label == 0]
  203. positive = random.sample(list(result[label == 1]), len(negative))
  204. preds_all = np.hstack([positive, negative])
  205. labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
  206. if len(labels_all) > 0:
  207. roc_score = roc_auc_score(labels_all, preds_all)
  208. ap_score = average_precision_score(labels_all, preds_all)
  209. roc_score_list.append(roc_score)
  210. ap_score_list.append(ap_score)
  211. # print("*********************************")
  212. # print(precision_list)
  213. # print(
  214. # "*** MAE - roc_score - ap_score - precision - recall - F_Measure : " + str(mean(mae_list)) + " _ "
  215. # + str(mean(roc_score_list)) + " _ " + str(mean(ap_score_list)) + " _ "
  216. # + str(mean(precision_list)) + " _ " + str(mean(recall_list)) + " _ "
  217. # + str(mean(F_Measure_list)) + " _ ")
  218. return mean(mae_list), mean(roc_score_list), mean(precision_list), mean(recall_list)
  219. if __name__ == '__main__':
  220. random.seed(123)
  221. torch.manual_seed(1234)
  222. prog_args = arg_parse()
  223. args = Main_VAE_Args()
  224. graphs = create_graphs.create(args)
  225. training_on_KronEM_data = False
  226. if training_on_KronEM_data:
  227. small_graphs = []
  228. for i in range(len(graphs)):
  229. if graphs[i].number_of_nodes() == 8 or graphs[i].number_of_nodes() == 16 or graphs[
  230. i].number_of_nodes() == 32 or \
  231. graphs[i].number_of_nodes() == 64 or graphs[i].number_of_nodes() == 128 or graphs[
  232. i].number_of_nodes() == 256:
  233. small_graphs.append(graphs[i])
  234. graphs = small_graphs
  235. else:
  236. if args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
  237. small_graphs = []
  238. for i in range(len(graphs)):
  239. if graphs[i].number_of_nodes() < 13:
  240. small_graphs.append(graphs[i])
  241. graphs = small_graphs
  242. elif args.graph_type == 'COLLAB':
  243. small_graphs = []
  244. for i in range(len(graphs)):
  245. if graphs[i].number_of_nodes() < 52 and graphs[i].number_of_nodes() > 41:
  246. small_graphs.append(graphs[i])
  247. graphs = small_graphs
  248. max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  249. fname = args.model_save_path + "GraphVAE" + str(args.load_epoch) + '.dat'
  250. model = build_model(prog_args, max_num_nodes).cuda()
  251. # model.load_state_dict(torch.load(fname))
  252. graph_statistics(graphs)
  253. # split datasets
  254. random.seed(123)
  255. shuffle(graphs)
  256. graphs_len = len(graphs)
  257. graphs_test = graphs[int(0.8 * graphs_len):]
  258. graphs_train = graphs[0:int(0.8 * graphs_len)]
  259. print("**** test graph size : " + str(len(graphs_test)))
  260. graphs_train = graphs[0:int(0.8 * graphs_len)]
  261. graph_statistics(graphs_test)
  262. iteration = 1
  263. mae_list = []
  264. roc_score_list = []
  265. precision_list = []
  266. recall_list = []
  267. F_Measure_list = []
  268. arbitrary_node_deletion_flag = True
  269. for i in range(iteration):
  270. print("########################################################################## " + str(i))
  271. mae, roc_score, precision, recall = evaluate(model, graphs_train)
  272. F_Measure = 2 * precision * recall / (precision + recall)
  273. mae_list.append(mae)
  274. roc_score_list.append(roc_score)
  275. precision_list.append(precision)
  276. recall_list.append(recall)
  277. F_Measure_list.append(F_Measure)
  278. print(
  279. "In Test: *** MAE - roc_score - precision - recall - F_Measure : " + str(
  280. mean(mae_list)) + " _ "
  281. + str(mean(roc_score_list)) + " _ "
  282. + str(mean(precision_list)) + " _ " + str(mean(recall_list)) + " _ "
  283. + str(mean(F_Measure_list)))