You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

evaluate2.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from statistics import mean
  2. from train import *
  3. from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
  4. from baselines.graphvae.graphvae_train import graph_statistics
  5. from baselines.graphvae.util import *
  6. def getitem(graph, max_prev_node, arbitrary_node_deletion_flag):
  7. # print("*** input graph size = " + str(graph.number_of_nodes()))
  8. input_graph_size = graph.number_of_nodes()
  9. adj = nx.to_numpy_matrix(graph)
  10. if arbitrary_node_deletion_flag:
  11. adj = move_random_node_to_the_last_index(adj)
  12. adj_copy = adj.copy()
  13. len = adj_copy.shape[0]
  14. x = np.zeros((graph.number_of_nodes() - 1, max_prev_node)) # here zeros are padded for small graph
  15. x[0, :] = 1 # the first input token is all ones
  16. y = np.zeros((graph.number_of_nodes() - 1, max_prev_node)) # here zeros are padded for small graph
  17. # print("%%%%%%%%% in getitem")
  18. # print(adj_copy)
  19. column_vector = adj_copy[:, adj_copy.shape[0] - 1]
  20. # print(column_vector)
  21. # print("%%% end")
  22. incomplete_adj = adj_copy.copy()
  23. incomplete_adj = incomplete_adj[:, :incomplete_adj.shape[0] - 1]
  24. incomplete_adj = incomplete_adj[:incomplete_adj.shape[0] - 1, :]
  25. x_idx = np.random.permutation(incomplete_adj.shape[0])
  26. x_idx_prime = np.concatenate((x_idx, [adj.shape[0] - 1]), axis=0)
  27. column_vector = column_vector[np.ix_(x_idx_prime)]
  28. incomplete_adj = incomplete_adj[np.ix_(x_idx, x_idx)]
  29. #
  30. incomplete_matrix = np.asmatrix(incomplete_adj)
  31. G = nx.from_numpy_matrix(incomplete_matrix)
  32. # then do bfs in the permuted G
  33. start_idx = np.random.randint(incomplete_adj.shape[0])
  34. x_idx = np.array(bfs_seq(G, start_idx))
  35. incomplete_adj = incomplete_adj[np.ix_(x_idx, x_idx)]
  36. # print("*** graph size after BFS = " + str(incomplete_adj.shape[0]))
  37. # print("##########################################################################")
  38. graph_size_after_BFS = incomplete_adj.shape[0]
  39. if graph_size_after_BFS != input_graph_size - 1:
  40. return
  41. adj_encoded = encode_adj(incomplete_adj.copy(), max_prev_node=max_prev_node)
  42. #
  43. x_idx_prime = np.concatenate((x_idx, [adj.shape[0] - 1]), axis=0)
  44. column_vector = column_vector[np.ix_(x_idx_prime)]
  45. y[0:adj_encoded.shape[0], :] = adj_encoded
  46. x[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  47. # print("*** incomplete")
  48. # print(incomplete_adj)
  49. return {'x': x, 'y': y, 'len': len - 1, 'label': column_vector}
  50. def evaluate(test_graph_list, args, arbitrary_node_deletion_flag):
  51. if 'GraphRNN_RNN' in args.note:
  52. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  53. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  54. has_output=True, output_size=args.hidden_size_rnn_output).cuda()
  55. output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
  56. hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True,
  57. has_output=True, output_size=1).cuda()
  58. elif 'GraphRNN_MLP' in args.note:
  59. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  60. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  61. has_output=False).cuda()
  62. output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
  63. y_size=args.max_prev_node).cuda()
  64. fname = args.model_save_path + args.fname_GraphRNN + 'lstm_' + str(args.load_epoch) + '.dat'
  65. print("*** fname : " + fname)
  66. rnn.load_state_dict(torch.load(fname))
  67. fname = args.model_save_path + args.fname_GraphRNN + 'output_' + str(args.load_epoch) + '.dat'
  68. output.load_state_dict(torch.load(fname))
  69. rnn.hidden = rnn.init_hidden(1)
  70. rnn.eval()
  71. output.eval()
  72. mae_list = []
  73. roc_score_list = []
  74. ap_score_list = []
  75. precision_list = []
  76. recall_list = []
  77. number_of_removed_data = 0
  78. for graph in test_graph_list:
  79. data = getitem(graph, args.max_prev_node, arbitrary_node_deletion_flag)
  80. if data == None:
  81. number_of_removed_data += 1
  82. continue
  83. x_unsorted = data['x']
  84. y_unsorted = data['y']
  85. y_len_unsorted = data['len']
  86. x_step = Variable(torch.ones(1, 1, args.max_prev_node)).cuda()
  87. for i in range(y_len_unsorted):
  88. h = rnn(x_step)
  89. if (i < y_len_unsorted - 1):
  90. x_step = (torch.Tensor(x_unsorted[i + 1:i + 2, :]).unsqueeze(0)).cuda()
  91. else: # the last step prediction
  92. if 'GraphRNN_RNN' in args.note:
  93. hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
  94. output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
  95. dim=0) # num_layers, batch_size, hidden_size
  96. x_step = Variable(torch.zeros(1, 1, args.max_prev_node)).cuda()
  97. output_x_step = Variable(torch.ones(1, 1, 1)).cuda()
  98. y_pred = x_step
  99. for j in range(min(args.max_prev_node, i + 1)):
  100. output_y_pred_step = output(output_x_step)
  101. output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
  102. y_pred[:, :, j:j + 1] = F.sigmoid(output_y_pred_step)
  103. x_step[:, :, j:j + 1] = output_x_step
  104. output.hidden = Variable(output.hidden.data).cuda()
  105. result = x_step.detach()
  106. elif 'GraphRNN_MLP' in args.note:
  107. y_pred_step = output(h)
  108. y_pred = F.sigmoid(y_pred_step)
  109. result = sample_sigmoid(y_pred_step, sample=True, sample_time=1)
  110. # result = sample_sigmoid(y_pred_step, sample=False, thresh=0.5, sample_time=10)
  111. rnn.hidden = Variable(rnn.hidden.data).cuda()
  112. # print("*** before")
  113. # print(y_unsorted)
  114. result = result.squeeze(0).cpu().numpy()
  115. y_pred = y_pred.squeeze(0).cpu().detach().numpy()
  116. y_unsorted = np.concatenate((y_unsorted[:y_unsorted.shape[0] - 1, :], result), axis=0)
  117. probabilistic_y_unsorted = np.concatenate((y_unsorted[:y_unsorted.shape[0] - 1, :], y_pred), axis=0)
  118. # print("*** result")
  119. # print(result)
  120. # print("*** after")
  121. # print(y_unsorted)
  122. # print(probabilistic_y_unsorted)
  123. adj_decoded = decode_adj(y_unsorted)
  124. adj_decoded_prob = decode_adj(probabilistic_y_unsorted)
  125. # print(adj_decoded.shape)
  126. # print("*** decoded")
  127. # print(adj_decoded)
  128. # print(adj_decoded_prob)
  129. result = adj_decoded[:, adj_decoded.shape[0] - 1]
  130. y_pred = adj_decoded_prob[:, adj_decoded_prob.shape[0] - 1]
  131. label = np.transpose(data['label'])
  132. label = np.asarray(label)[0, :]
  133. # print("*** result")
  134. #
  135. # print(np.shape(result))
  136. # print(result)
  137. # print("*** label")
  138. # print(label)
  139. mae = mean_absolute_error(label, result)
  140. mae_list.append(mae)
  141. # label = gran_lable.numpy()
  142. # result = sample.cpu().numpy()
  143. part1 = label[result == 1]
  144. part2 = part1[part1 == 1]
  145. part3 = part1[part1 == 0]
  146. part4 = label[result == 0]
  147. part5 = part4[part4 == 1]
  148. tp = len(part2)
  149. fp = len(part3)
  150. fn = part5.sum()
  151. if tp + fp > 0:
  152. precision = tp / (tp + fp)
  153. else:
  154. precision = 0
  155. recall = tp / (tp + fn)
  156. precision_list.append(precision)
  157. recall_list.append(recall)
  158. # result = y_pred.cpu().detach().numpy()
  159. # print(y_pred)
  160. positive = result[label == 1]
  161. if len(positive) <= len(list(result[label == 0])):
  162. negative = random.sample(list(result[label == 0]), len(positive))
  163. else:
  164. negative = result[label == 0]
  165. positive = random.sample(list(result[label == 1]), len(negative))
  166. preds_all = np.hstack([positive, negative])
  167. labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
  168. if len(labels_all) > 0:
  169. roc_score = roc_auc_score(labels_all, preds_all)
  170. ap_score = average_precision_score(labels_all, preds_all)
  171. roc_score_list.append(roc_score)
  172. print("************ list length : " + str(len(precision_list)))
  173. return mean(mae_list), mean(roc_score_list), mean(precision_list), mean(recall_list), number_of_removed_data
  174. if __name__ == '__main__':
  175. # graphs_test = []
  176. # graph = nx.grid_2d_graph(2, 3)
  177. # graphs_test.append(graph)
  178. # # graph = nx.grid_2d_graph(2, 4)
  179. # graphs_test.append(graph)
  180. print("salam")
  181. args = Args()
  182. args.max_prev_node = 40
  183. graphs = create_graphs.create(args)
  184. training_on_KronEM_data = True
  185. if training_on_KronEM_data:
  186. small_graphs = []
  187. for i in range(len(graphs)):
  188. if graphs[i].number_of_nodes() == 8 or graphs[i].number_of_nodes() == 16 or graphs[
  189. i].number_of_nodes() == 32 or \
  190. graphs[i].number_of_nodes() == 64 or graphs[i].number_of_nodes() == 128 or graphs[
  191. i].number_of_nodes() == 256:
  192. small_graphs.append(graphs[i])
  193. graphs = small_graphs
  194. else:
  195. if args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
  196. small_graphs = []
  197. for i in range(len(graphs)):
  198. if graphs[i].number_of_nodes() < 41:
  199. small_graphs.append(graphs[i])
  200. graphs = small_graphs
  201. elif args.graph_type == 'COLLAB':
  202. small_graphs = []
  203. for i in range(len(graphs)):
  204. if graphs[i].number_of_nodes() < 52 and graphs[i].number_of_nodes() > 41:
  205. small_graphs.append(graphs[i])
  206. graphs = small_graphs
  207. graph_statistics(graphs)
  208. # split datasets
  209. random.seed(123)
  210. shuffle(graphs)
  211. graphs_len = len(graphs)
  212. graphs_test = graphs[int(0.8 * graphs_len):]
  213. print("**** test graph size : " + str(len(graphs_test)))
  214. graphs_train = graphs[0:int(0.8 * graphs_len)]
  215. graph_statistics(graphs_test)
  216. iteration = 20
  217. mae_list = []
  218. roc_score_list = []
  219. precision_list = []
  220. recall_list = []
  221. F_Measure_list = []
  222. number_of_removed_data_list = []
  223. arbitrary_node_deletion_flag = True
  224. for i in range(iteration):
  225. print("########################################################################## " + str(i))
  226. mae, roc_score, precision, recall, number_of_removed_data = evaluate(graphs_test, args,
  227. arbitrary_node_deletion_flag)
  228. F_Measure = 2 * precision * recall / (precision + recall)
  229. mae_list.append(mae)
  230. roc_score_list.append(roc_score)
  231. precision_list.append(precision)
  232. recall_list.append(recall)
  233. F_Measure_list.append(F_Measure)
  234. number_of_removed_data_list.append(number_of_removed_data)
  235. print("Mean number of removed data : " + str(mean(number_of_removed_data_list)))
  236. print(
  237. "In Test: *** MAE - roc_score - precision - recall - F_Measure : " + str(
  238. mean(mae_list)) + " _ "
  239. + str(mean(roc_score_list)) + " _ "
  240. + str(mean(precision_list)) + " _ " + str(mean(recall_list)) + " _ "
  241. + str(mean(F_Measure_list)))