# graph_completion_test_without_training.py
# NOTE: web-viewer scrape artifacts (page chrome, file-size banner, fused
# line-number run) removed; code below reconstructed with proper indentation.
  1. import torch
  2. import torch.nn.functional as F
  3. from torch.autograd import Variable
  4. from sklearn.metrics import mean_absolute_error
  5. import sys
  6. sys.path.append('../')
  7. from data import *
  8. from train import *
  9. from args import Args
  10. def test_mlp(x_batch, y_len_unsorted, epoch, args, rnn, output, test_batch_size=32, save_histogram=False,sample_time=1):
  11. rnn.hidden = rnn.init_hidden(1)
  12. rnn.eval()
  13. output.eval()
  14. # generate graphs
  15. max_num_node = int(args.max_num_node)
  16. y_pred = Variable(
  17. torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
  18. y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
  19. x_step = Variable(torch.ones(1, 1, args.max_prev_node)).cuda()
  20. incompleteness_ratio = 0.96
  21. for j in range(test_batch_size):
  22. incomplete_graph_size = int(int(y_len_unsorted.data[j]) * incompleteness_ratio)
  23. # print(y_len_unsorted.data[j])
  24. for i in range(y_len_unsorted.data[j]-1):
  25. h = rnn(x_step)
  26. y_pred_step = output(h)
  27. y_pred[j:j+1, i:i + 1, :] = F.sigmoid(y_pred_step)
  28. if (i<incomplete_graph_size):
  29. x_step = (x_batch[j:j+1, i+1:i+2, :]).cuda()
  30. else:
  31. x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
  32. y_pred_long[j:j+1, i:i + 1, :] = x_step
  33. rnn.hidden = Variable(rnn.hidden.data).cuda()
  34. y_pred_long_data = y_pred_long.data.long()
  35. G_pred_list = []
  36. adj_true_list = []
  37. adj_pred_list = []
  38. for i in range(test_batch_size):
  39. adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
  40. adj_pred_list.append(adj_pred)
  41. adj_true_list.append(decode_adj(x_batch[i].cpu().numpy()))
  42. G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
  43. G_pred_list.append(G_pred)
  44. return G_pred_list, adj_true_list, adj_pred_list
  45. def save_graph(graph, name):
  46. adj_pred = decode_adj(graph.cpu().numpy())
  47. G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
  48. G = np.asarray(nx.to_numpy_matrix(G_pred))
  49. np.savetxt(name + '.txt', G, fmt='%d')
  50. def data_to_graph_converter(data):
  51. G_list = []
  52. for i in range(len(data)):
  53. x = data[i].numpy()
  54. x = x.astype(int)
  55. adj_pred = decode_adj(x)
  56. G = get_graph(adj_pred) # get a graph from zero-padded adj
  57. G_list.append(G)
  58. return G_list
  59. def get_incomplete_graph(x_batch, y_len_unsorted, incompleteness_ratio = 0.96):
  60. batch_size = len(x_batch)
  61. max_prev_node = len(x_batch[0][0])
  62. max_incomplete_num_node = int(int(max(y_len_unsorted)) * incompleteness_ratio)
  63. incomplete_graph = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node))
  64. for i in range(len(y_len_unsorted)):
  65. incomplete_graph_size = int(int(y_len_unsorted.data[i])*incompleteness_ratio)
  66. incomplete_graph[i] = torch.cat((x_batch[i,:incomplete_graph_size],
  67. torch.zeros([max_incomplete_num_node - incomplete_graph_size,
  68. max_prev_node])), dim=0)
  69. return incomplete_graph
  70. if __name__ == '__main__':
  71. # All necessary arguments are defined in args.py
  72. args = Args()
  73. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
  74. print('CUDA', args.cuda)
  75. print('File name prefix', args.fname)
  76. # check if necessary directories exist
  77. if not os.path.isdir(args.model_save_path):
  78. os.makedirs(args.model_save_path)
  79. if not os.path.isdir(args.graph_save_path):
  80. os.makedirs(args.graph_save_path)
  81. if not os.path.isdir(args.figure_save_path):
  82. os.makedirs(args.figure_save_path)
  83. if not os.path.isdir(args.timing_save_path):
  84. os.makedirs(args.timing_save_path)
  85. if not os.path.isdir(args.figure_prediction_save_path):
  86. os.makedirs(args.figure_prediction_save_path)
  87. if not os.path.isdir(args.nll_save_path):
  88. os.makedirs(args.nll_save_path)
  89. graphs = create_graphs.create(args)
  90. # split datasets
  91. random.seed(123)
  92. shuffle(graphs)
  93. graphs_len = len(graphs)
  94. graphs_test = graphs[int(0.8 * graphs_len):]
  95. graphs_train = graphs[0:int(0.8 * graphs_len)]
  96. graphs_validate = graphs[0:int(0.2 * graphs_len)]
  97. graph_validate_len = 0
  98. for graph in graphs_validate:
  99. graph_validate_len += graph.number_of_nodes()
  100. graph_validate_len /= len(graphs_validate)
  101. print('graph_validate_len', graph_validate_len)
  102. graph_test_len = 0
  103. for graph in graphs_test:
  104. graph_test_len += graph.number_of_nodes()
  105. graph_test_len /= len(graphs_test)
  106. print('graph_test_len', graph_test_len)
  107. args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  108. max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
  109. min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
  110. # args.max_num_node = 2000
  111. # show graphs statistics
  112. print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
  113. print('max number node: {}'.format(args.max_num_node))
  114. print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
  115. print('max previous node: {}'.format(args.max_prev_node))
  116. # save ground truth graphs
  117. ## To get train and test set, after loading you need to manually slice
  118. save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
  119. save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
  120. print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
  121. if 'nobfs' in args.note:
  122. args.max_prev_node = args.max_num_node-1
  123. dataset = Graph_sequence_sampler_pytorch_nobfs_for_completion(graphs_train,
  124. max_num_node=args.max_num_node)
  125. # dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
  126. # max_num_node=args.max_num_node)
  127. sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
  128. num_samples=args.batch_size * args.batch_ratio,
  129. replacement=True)
  130. dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
  131. sampler=sample_strategy)
  132. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  133. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  134. has_output=False).cuda()
  135. output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
  136. y_size=args.max_prev_node).cuda()
  137. name = "GraphRNN_MLP" + '_' + args.graph_type + '_' + str(args.num_layers) + '_' + str(args.hidden_size_rnn) + '_'
  138. fname = args.model_save_path + name + 'lstm_' + str(args.load_epoch) + '.dat'
  139. rnn.load_state_dict(torch.load(fname))
  140. fname = args.model_save_path + name + 'output_' + str(args.load_epoch) + '.dat'
  141. output.load_state_dict(torch.load(fname))
  142. # ******************************************************************
  143. rnn_for_graph_completion = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  144. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  145. has_output=False).cuda()
  146. output_for_graph_completion = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
  147. y_size=args.max_prev_node).cuda()
  148. graph_completion_string = 'graph_completion_one_node_'
  149. # fname = args.model_save_path + args.fname + 'lstm_' + graph_completion_string + str(args.load_epoch) + '.dat'
  150. # rnn_for_graph_completion.load_state_dict(torch.load(fname))
  151. # fname = args.model_save_path + args.fname + 'output_' + graph_completion_string + str(args.load_epoch) + '.dat'
  152. # output_for_graph_completion.load_state_dict(torch.load(fname))
  153. # ******************************************************************
  154. args.lr = 0.00001
  155. epoch = args.load_epoch
  156. print('model loaded!, lr: {}'.format(args.lr))
  157. for batch_idx, data in enumerate(dataset_loader):
  158. if batch_idx==0:
  159. rnn.zero_grad()
  160. output.zero_grad()
  161. x_unsorted = data['x'].float()
  162. G = data_to_graph_converter(x_unsorted[:,1:,:])
  163. nx.write_gpickle(G, "main_graphs.dat")
  164. y_unsorted = data['y'].float()
  165. y_len_unsorted = data['len']
  166. # *********************************
  167. G = get_incomplete_graph(x_unsorted, y_len_unsorted)
  168. G = data_to_graph_converter(G[:, 1:, :])
  169. nx.write_gpickle(G, "incomplete_graphs.dat")
  170. # *********************************
  171. G_pred_step, adj_true_list, adj_pred_list = test_mlp(x_unsorted, y_len_unsorted, epoch, args, rnn, output)
  172. nx.write_gpickle(G_pred_step, "completed_graphs.dat")
  173. mae = np.sum(np.absolute((adj_pred_list[0].astype("float") - adj_true_list[0].astype("float"))))
  174. print("adj_true: ")
  175. print(adj_true_list[0])
  176. print("my err")
  177. print(mae)
  178. print(mean_absolute_error(adj_pred_list[0], adj_true_list[0]))
  179. print("adj_pred_list:")
  180. print(adj_pred_list[0])
  181. # *********************************
  182. # G_pred_step = test_mlp(x_unsorted, y_len_unsorted, epoch, args, rnn_for_graph_completion,
  183. # output_for_graph_completion)
  184. # nx.write_gpickle(G_pred_step, "completed_graphs_with_training.dat")