You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

graph_completion_test.py 8.5KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from sklearn.metrics import mean_absolute_error
  2. import sys
  3. sys.path.append('../')
  4. from train import *
  5. from args import Args
  6. from GraphCompletion.graph_completion_with_training import graph_show
  7. from GraphCompletion.graph_show import graph_save
  8. def test_completion(x_batch, y_len_unsorted, args, rnn, output, test_batch_size=32, sample_time=1):
  9. rnn.hidden = rnn.init_hidden(1)
  10. rnn.eval()
  11. output.eval()
  12. # generate graphs
  13. max_num_node = int(args.max_num_node)
  14. y_pred = Variable(
  15. torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
  16. y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
  17. x_step = Variable(torch.ones(1, 1, args.max_prev_node)).cuda()
  18. # number_of_missing_nodes = args.number_of_missing_nodes
  19. number_of_missing_nodes = 2
  20. for j in range(test_batch_size):
  21. incomplete_graph_size = y_len_unsorted.data[j] - number_of_missing_nodes
  22. for i in range(y_len_unsorted.data[j]-1):
  23. h = rnn(x_step)
  24. y_pred_step = output(h)
  25. y_pred[j:j+1, i:i + 1, :] = F.sigmoid(y_pred_step)
  26. if(i<incomplete_graph_size):
  27. x_step = (x_batch[j:j+1, i+1:i+2, :]).cuda()
  28. else:
  29. x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
  30. y_pred_long[j:j+1, i:i + 1, :] = x_step
  31. rnn.hidden = Variable(rnn.hidden.data).cuda()
  32. y_pred_long_data = y_pred_long.data.long()
  33. adj_true_list = []
  34. graph_true_list = []
  35. adj_incomplete_list = []
  36. graph_incomplete_list = []
  37. adj_pred_list = []
  38. graph_pred_list = []
  39. for i in range(test_batch_size):
  40. adj_true = decode_adj(x_batch[i, 1:y_len_unsorted.data[i], :].cpu().numpy())
  41. adj_true_list.append(adj_true)
  42. graph_true_list.append(nx.from_numpy_matrix(adj_true))
  43. adj_incomplete = decode_adj(x_batch[i, 1:y_len_unsorted.data[i] - number_of_missing_nodes, :].cpu().numpy())
  44. adj_incomplete_list.append(adj_incomplete)
  45. graph_incomplete_list.append(nx.from_numpy_matrix(adj_incomplete))
  46. adj_pred = decode_adj(y_pred_long_data[i, 0:y_len_unsorted.data[i]-1, :].cpu().numpy())
  47. adj_pred_list.append(adj_pred)
  48. graph_pred_list.append(nx.from_numpy_matrix(adj_pred))
  49. return adj_true_list, graph_true_list, adj_incomplete_list, graph_incomplete_list, adj_pred_list, graph_pred_list
  50. if __name__ == '__main__':
  51. # All necessary arguments are defined in args.py
  52. args = Args()
  53. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
  54. # check if necessary directories exist
  55. if not os.path.isdir(args.model_save_path):
  56. os.makedirs(args.model_save_path)
  57. if not os.path.isdir(args.graph_save_path):
  58. os.makedirs(args.graph_save_path)
  59. if not os.path.isdir(args.figure_save_path):
  60. os.makedirs(args.figure_save_path)
  61. if not os.path.isdir(args.timing_save_path):
  62. os.makedirs(args.timing_save_path)
  63. if not os.path.isdir(args.figure_prediction_save_path):
  64. os.makedirs(args.figure_prediction_save_path)
  65. if not os.path.isdir(args.nll_save_path):
  66. os.makedirs(args.nll_save_path)
  67. graphs = create_graphs.create(args)
  68. # split datasets
  69. random.seed(123)
  70. shuffle(graphs)
  71. graphs_len = len(graphs)
  72. args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  73. max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
  74. min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
  75. # show graphs statistics
  76. print('total graph num: {}'.format(len(graphs)))
  77. print('max number node: {}'.format(args.max_num_node))
  78. print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
  79. print('max previous node: {}'.format(args.max_prev_node))
  80. if 'nobfs' in args.note:
  81. args.max_prev_node = args.max_num_node-1
  82. dataset = Graph_sequence_sampler_pytorch_nobfs_for_completion(graphs,
  83. max_num_node=args.max_num_node)
  84. sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
  85. num_samples=args.batch_size * args.batch_ratio,
  86. replacement=True)
  87. dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
  88. sampler=sample_strategy)
  89. rnn_GraphRNN = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  90. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  91. has_output=False).cuda()
  92. output_GraphRNN = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
  93. y_size=args.max_prev_node).cuda()
  94. fname = args.model_save_path + args.fname_GraphRNN + 'lstm_' + str(args.load_epoch) + '.dat'
  95. rnn_GraphRNN.load_state_dict(torch.load(fname))
  96. fname = args.model_save_path + args.fname_GraphRNN + 'output_' + str(args.load_epoch) + '.dat'
  97. output_GraphRNN.load_state_dict(torch.load(fname))
  98. # ******************************************************************
  99. rnn_GraphCompletion = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  100. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  101. has_output=False).cuda()
  102. output_GraphCompletion = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
  103. y_size=args.max_prev_node).cuda()
  104. fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(args.epochs) + '.dat'
  105. rnn_GraphCompletion.load_state_dict(torch.load(fname))
  106. fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(args.epochs) + '.dat'
  107. output_GraphCompletion.load_state_dict(torch.load(fname))
  108. # ******************************************************************
  109. MAE_GraphRNN = []
  110. MAE_GraphCompletion = []
  111. for batch_idx, data in enumerate(dataset_loader):
  112. if True:
  113. x_unsorted = data['x'].float()
  114. y_unsorted = data['y'].float()
  115. y_len_unsorted = data['len']
  116. # *********************************
  117. adj_true_list, graph_true_list, adj_incomplete_list, graph_incomplete_list, adj_pred_list, graph_pred_list\
  118. = test_completion(x_unsorted, y_len_unsorted, args, rnn_GraphRNN, output_GraphRNN)
  119. adj_true_list, graph_true_list, adj_incomplete_list, graph_incomplete_list, adj_pred_list_completion_model\
  120. , graph_pred_list_completion_model \
  121. = test_completion(x_unsorted, y_len_unsorted, args, rnn_GraphCompletion, output_GraphCompletion)
  122. mae = np.sum(np.absolute((adj_pred_list[0].astype("float") - adj_true_list[0].astype("float"))))
  123. # print("adj_true: ")
  124. # print(adj_true_list[0])
  125. # graph_show(nx.from_numpy_matrix(adj_true_list[0]),"adj_true" )
  126. # graph_show(nx.from_numpy_matrix(adj_incomplete_list[0]), "adj_incomplete")
  127. # print("my error")
  128. # print(mae)
  129. # print(mean_absolute_error(adj_pred_list[0], adj_true_list[0]))
  130. # print("adj_pred_list:")
  131. # print(adj_pred_list[0])
  132. # graph_show(nx.from_numpy_matrix(adj_pred_list[0]), "pred_true")
  133. for i in range(len(graph_true_list)):
  134. # graph_save(graph_true_list[i], graph_incomplete_list[i],
  135. # graph_pred_list[i], graph_pred_list_completion_model[i], i, args.graph_save_path)
  136. mae = mean_absolute_error(adj_pred_list[i], adj_true_list[i])
  137. MAE_GraphRNN.append(mae)
  138. mae = mean_absolute_error(adj_pred_list_completion_model[i], adj_true_list[i])
  139. MAE_GraphCompletion.append(mae)
  140. # *********************************
  141. # print(MAE_GraphCompletion)
  142. # G_pred_step = test_mlp(x_unsorted, y_len_unsorted, epoch, args, rnn_for_graph_completion,
  143. # output_for_graph_completion)
  144. # nx.write_gpickle(G_pred_step, "completed_graphs_with_training.dat")
  145. print("MAE_GraphRNN:")
  146. print(np.mean(MAE_GraphRNN))
  147. # print(MAE_GraphRNN)
  148. print("MAE_GraphCompletion:")
  149. print(np.mean(MAE_GraphCompletion))