
graph_completion_with_training.py 30KB

import itertools
import sys

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import scipy.optimize
import torch
import torch.nn.functional as F
from tensorboard_logger import configure, log_value
from torch import optim
from torch.autograd import Variable
from torch.nn.init import *
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import MultiStepLR

sys.path.append('../')
from train import *  # training helpers (train_vae_epoch, train_rnn_epoch, models, data samplers, ...)
from args import Args
from baselines.graphvae.graphvae_model import GraphVAE

device = torch.device("cuda:0")

def plot_grad_flow(named_parameters):
    # Plot the mean absolute gradient of each non-bias parameter to check for
    # vanishing/exploding gradients after a backward pass.
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()
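# Usage: call plot_grad_flow(rnn.named_parameters()) right after loss.backward()
# (see the commented-out hook in train_mlp_epoch_for_completion below), while the
# .grad fields are still populated.
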
def graph_show(G, title):
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, font_size=8)
    fig = plt.gcf()
    fig.canvas.set_window_title(title)
    plt.show()

def decode_adj(adj_output):
    '''
    Recover the full adjacency matrix from the GraphRNN sequence encoding.
    Note: adj_output has shape (n-1) x m, where m = max_prev_node; row i holds
    the connections of node i+1 to its m most recent predecessors, nearest first.
    '''
    max_prev_node = adj_output.shape[1]
    adj = np.zeros((adj_output.shape[0], adj_output.shape[0]))
    for i in range(adj_output.shape[0]):
        input_start = max(0, i - max_prev_node + 1)
        input_end = i + 1
        output_start = max_prev_node + max(0, i - max_prev_node + 1) - (i + 1)
        output_end = max_prev_node
        adj[i, input_start:input_end] = adj_output[i, ::-1][output_start:output_end]  # reverse order
    adj_full = np.zeros((adj_output.shape[0] + 1, adj_output.shape[0] + 1))
    n = adj_full.shape[0]
    adj_full[1:n, 0:n - 1] = np.tril(adj, 0)
    adj_full = adj_full + adj_full.T
    return adj_full
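# Worked example (hand-checked, not from the original source): with
# max_prev_node = 2, decode_adj(np.array([[1., 0.], [1., 0.]])) yields the
# adjacency matrix of the 3-node path 0-1-2:
#   [[0., 1., 0.],
#    [1., 0., 1.],
#    [0., 1., 0.]]
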
def tensor_to_graph_converter(tensor, title):
    adj_matrix = decode_adj(tensor.numpy())
    G = nx.from_numpy_matrix(adj_matrix)
    graph_show(G, title)
    return

def get_indices(number_of_missing_nodes, row_size):
    # Build (row, column) index pairs selecting, for each missing node's row of
    # the sequence encoding, a window of row_size entries that shifts right by
    # one for each successive row (the staircase layout of the encoding).
    row_index = []
    column_index = []
    start_index = 0
    for i in range(number_of_missing_nodes):
        for j in range(row_size):
            row_index.append(i)
            column_index.append(j + start_index)
        start_index += 1
    return row_index, column_index
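# Example (illustrative): get_indices(2, 3) returns
#   row_index    = [0, 0, 0, 1, 1, 1]
#   column_index = [0, 1, 2, 1, 2, 3]
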
def GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y, batch_size, number_of_missing_nodes):
    # Permutation-invariant BCE on the rows belonging to the missing nodes: try
    # every ordering of the predicted rows and keep the smallest loss per graph.
    row_list = list(range(number_of_missing_nodes))
    permutation_indices = list(itertools.permutations(row_list))
    loss_sum = 0
    for data_index in range(batch_size):
        row_size = len(true_adj_batch[data_index]) - number_of_missing_nodes
        row_index, column_index = get_indices(number_of_missing_nodes, row_size)
        loss_list = []
        for row_order in permutation_indices:
            permuted_y_pred = y_pred[data_index][list(row_order)]
            loss_list.append(F.binary_cross_entropy(permuted_y_pred[row_index, column_index],
                                                    y[data_index][row_index, column_index]))
        loss_sum += min(loss_list)
    loss_sum /= batch_size
    return loss_sum
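# Caveat: the search over itertools.permutations is factorial in
# number_of_missing_nodes (5 missing nodes -> 120 orderings, 10 -> ~3.6M),
# so this loss is only practical when few nodes are missing.
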
def GraphRNN_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # Plain BCE, computed only on the first graph of the batch (note the break).
    for data_index in range(batch_size):
        loss_sum = F.binary_cross_entropy(predicted_adj_batch[data_index],
                                          true_adj_batch[data_index])
        break
    return loss_sum

def matrix_completion_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # BCE on the first reconstructed adjacency matrix of the batch. The batches
    # built by make_batch_of_matrices are already torch tensors, so no numpy
    # conversion is needed here.
    loss_sum = F.binary_cross_entropy(predicted_adj_batch[0], true_adj_batch[0])
    # Earlier permutation-matching variant, kept for reference:
    # loss_sum = 0
    # for data_index in range(batch_size):
    #     adj_matrix = true_adj_batch[data_index]
    #     predicted_adj_matrix = predicted_adj_batch[data_index]
    #     number_of_nodes = np.shape(adj_matrix)[1]
    #     sub_matrix = adj_matrix[number_of_nodes - number_of_missing_nodes:][:,
    #                  :number_of_nodes - number_of_missing_nodes]
    #     predicted_sub_matrix = predicted_adj_matrix[number_of_nodes - number_of_missing_nodes:][:,
    #                            :number_of_nodes - number_of_missing_nodes]
    #     min_permutation_loss = float("inf")
    #     row_list = list(range(number_of_missing_nodes))
    #     permutation_indices = list(itertools.permutations(row_list))
    #     matrix_size = np.shape(predicted_sub_matrix)[0] * np.shape(predicted_sub_matrix)[1]
    #     for row_order in permutation_indices:
    #         permuted_sub_matrix = sub_matrix[list(row_order)]
    #         adj_recon_loss = F.binary_cross_entropy(torch.from_numpy(predicted_sub_matrix),
    #                                                 torch.from_numpy(permuted_sub_matrix))
    #         min_permutation_loss = np.minimum(min_permutation_loss, adj_recon_loss.item())
    #     loss_sum += min_permutation_loss / matrix_size
    # loss_sum = torch.autograd.Variable(torch.from_numpy(np.array(loss_sum)))
    # loss_sum.requires_grad = True
    return loss_sum / batch_size

def graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, batch_size):
    loss_sum = 0
    for data_index in range(batch_size):
        number_of_nodes = data_len[data_index].item()
        # the matching below scales poorly, so skip larger graphs
        if number_of_nodes > 14:
            continue
        features = torch.zeros(number_of_nodes)
        graphvae = GraphVAE(1, 64, 256, number_of_nodes)
        # the adjacency batches are torch tensors on the GPU; move detached
        # copies to the CPU for the GraphVAE matching routines
        current_true_adj = true_adj_batch[data_index].detach().cpu()
        current_predicted_adj = predicted_adj_batch[data_index].detach().cpu()
        current_features = features
        S = graphvae.edge_similarity_matrix(current_true_adj, current_predicted_adj,
                                            current_features, current_features,
                                            graphvae.deg_feature_similarity)
        init_corr = 1 / number_of_nodes
        init_assignment = torch.ones(number_of_nodes, number_of_nodes) * init_corr
        assignment = graphvae.mpm(init_assignment, S)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        permuted_adj = graphvae.permute_adj(current_true_adj.float(), row_ind, col_ind)
        adj_recon_loss = graphvae.adj_recon_loss(permuted_adj, current_predicted_adj.float())
        loss_sum += adj_recon_loss.data
    loss_sum.requires_grad = True
    return loss_sum / batch_size
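# The matching step follows the GraphVAE recipe from baselines.graphvae: mpm()
# refines a uniform soft assignment by max-pooling matching over the edge
# similarity matrix S, and the Hungarian algorithm (linear_sum_assignment on the
# negated scores) extracts the hard node permutation used to align the graphs.
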
def make_batch_of_matrices(initial_y, initial_y_len, y_train, y_pred, batch_size):
    # For each graph, concatenate the encoding of the observed part with the
    # true (y_train) or predicted (y_pred) rows for the missing part, decode both
    # into full adjacency matrices, and add the identity so diagonals are 1.
    true_adj_batch = []
    predicted_adj_batch = []
    for data_index in range(batch_size):
        a = initial_y[data_index][:initial_y_len[data_index] - 1].cpu()
        b = y_train[data_index]
        true_adj = torch.cat((a, b), dim=0)
        true_adj = decode_adj(true_adj.numpy())
        true_adj_tensor = torch.from_numpy(true_adj + np.identity(len(true_adj[0]))).float().to(device)
        true_adj_batch.append(true_adj_tensor)
        b = y_pred[data_index].data
        predicted_adj = torch.cat((a, b.cpu()), dim=0)
        predicted_adj = decode_adj(predicted_adj.numpy())
        predicted_adj_tensor = torch.from_numpy(predicted_adj + np.identity(len(predicted_adj[0]))).float().to(device)
        predicted_adj_tensor.requires_grad = True
        predicted_adj_batch.append(predicted_adj_tensor + 0)
    return true_adj_batch, predicted_adj_batch
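# Caveat: predicted_adj_tensor is rebuilt from numpy, so it is detached from the
# autograd graph of y_pred; requires_grad plus the "+ 0" copy make it a tensor
# that can carry gradients, but nothing propagates back through decode_adj.
# GraphRNN_loss2 therefore computes its loss directly on y_pred instead.
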
def train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization,
                                   output_for_initialization, data_loader,
                                   optimizer_rnn, optimizer_output,
                                   scheduler_rnn, scheduler_output, number_of_missing_nodes):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        data_len = y_len_unsorted = data['len']
        # Split each graph into an observed part (used to warm up the hidden
        # state) and the part with the missing nodes (the training target).
        x_unsorted, y_unsorted, y_len_unsorted, hidden, initial_y, initial_y_len = \
            data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted,
                                          args, rnn_for_initialization, output_for_initialization,
                                          number_of_missing_nodes, False, 0.96)
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize the lstm hidden state from the observed subgraph, instead of
        # rnn.init_hidden(batch_size=x_unsorted.size(0))
        rnn.hidden = hidden
        # sort input by decreasing length (required by pack_padded_sequence)
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = torch.sigmoid(y_pred)
        # clean: zero the padded time steps
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        true_adj_batch, predicted_adj_batch = make_batch_of_matrices(initial_y, initial_y_len,
                                                                     y.cpu(), y_pred, len(y_pred))
        if args.loss == "graph_matching_loss":
            loss = graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, len(data_len))
        elif args.loss == "matrix_completion_loss":
            loss = matrix_completion_loss(true_adj_batch, predicted_adj_batch, len(data_len),
                                          args.number_of_missing_nodes)
        elif args.loss == "GraphRNN_loss":
            loss = GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y,
                                  len(y_pred), args.number_of_missing_nodes)
        loss.backward()
        # update the output MLP and the rnn
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()
        # plot_grad_flow(rnn.named_parameters())  # optional gradient-flow check
        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))
        log_value(args.loss + args.fname + args.training_mode + "_number_of_missing_nodes_" +
                  str(number_of_missing_nodes), loss.data, epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.data
    return loss_sum / (batch_idx + 1)

def train_for_completion(args, dataset_train, rnn, output, rnn_for_initialization, output_for_initialization):
    # check if loading an existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))
        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded!, lr: {}'.format(args.lr))
    else:
        epoch = 1
    # initialize optimizers and schedulers
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)
    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)
    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization,
                                           output_for_initialization, dataset_train,
                                           optimizer_rnn, optimizer_output,
                                           scheduler_rnn, scheduler_output, args.number_of_missing_nodes)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # save model checkpoints
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)

def hidden_initializer(x, rnn, output):
    # Feed the observed part of each graph sequence through the rnn one step at
    # a time and return the resulting hidden state, which seeds the completion.
    test_batch_size = len(x)
    max_num_node = len(x[0])
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    for i in range(max_num_node):
        h = rnn((x[:, i:i + 1, :]).cuda())
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    return rnn.hidden

def data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted, args, rnn, output,
                                  number_of_missing_nodes, ratio_mode, incompleteness_ratio):
    # Split each graph sequence into an "initial" part (the observed, incomplete
    # graph) and a "train" part (the rows for the nodes to be completed). The
    # hidden state returned is computed from the initial part.
    batch_size = len(x_unsorted)
    max_prev_node = len(x_unsorted[0][0])
    max_number_of_nodes = int(max(y_len_unsorted))
    if ratio_mode:
        max_incomplete_num_node = int(max_number_of_nodes * incompleteness_ratio)
    else:
        max_incomplete_num_node = max_number_of_nodes - number_of_missing_nodes
    initial_x = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y_len = Variable(torch.zeros(len(y_len_unsorted))).type(torch.LongTensor)
    for i in range(len(y_len_unsorted)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        initial_y_len[i] = incomplete_graph_size
        initial_x[i] = torch.cat((x_unsorted[i, :incomplete_graph_size],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size, max_prev_node])), dim=0)
        initial_y[i] = torch.cat((y_unsorted[i, :incomplete_graph_size - 1],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size + 1, max_prev_node])), dim=0)
    train_y_len = y_len_unsorted - initial_y_len
    max_train_num_node = int(max(train_y_len)) + 1
    temp_size_reduction = 1
    train_x = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))
    train_y = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))
    for i in range(len(train_y_len)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        train_graph_size = int(train_y_len.data[i])
        train_x[i, 1 - temp_size_reduction:, :] = torch.cat(
            (x_unsorted[i, incomplete_graph_size:incomplete_graph_size + train_graph_size],
             torch.zeros([max_train_num_node - train_graph_size - 1, max_prev_node])), dim=0)
        train_y[i, :, :] = torch.cat(
            (x_unsorted[i, incomplete_graph_size + temp_size_reduction:incomplete_graph_size + train_graph_size],
             torch.zeros([max_train_num_node - train_graph_size, max_prev_node])), dim=0)
    hidden = hidden_initializer(initial_x, rnn, output)
    return train_x, train_y, train_y_len, hidden, initial_y, initial_y_len

def save_graph(graph, name):
    adj_pred = decode_adj(graph.cpu().numpy())
    G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
    G = np.asarray(nx.to_numpy_matrix(G_pred))
    np.savetxt(name + '.txt', G, fmt='%d')

def data_to_graph_converter(data):
    G_list = []
    for i in range(len(data)):
        x = data[i].numpy()
        x = x.astype(int)
        adj_pred = decode_adj(x)
        G = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_list.append(G)
    return G_list

if __name__ == '__main__':
    # All necessary arguments are defined in args.py
    args = Args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)
    # create necessary directories if they do not exist
    if not os.path.isdir(args.model_save_path):
        os.makedirs(args.model_save_path)
    if not os.path.isdir(args.graph_save_path):
        os.makedirs(args.graph_save_path)
    if not os.path.isdir(args.figure_save_path):
        os.makedirs(args.figure_save_path)
    if not os.path.isdir(args.timing_save_path):
        os.makedirs(args.timing_save_path)
    if not os.path.isdir(args.figure_prediction_save_path):
        os.makedirs(args.figure_prediction_save_path)
    if not os.path.isdir(args.nll_save_path):
        os.makedirs(args.nll_save_path)
    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
    if args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)
    graphs = create_graphs.create(args)
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    graphs_validate = graphs[0:int(0.2 * graphs_len)]
    graph_validate_len = 0
    for graph in graphs_validate:
        graph_validate_len += graph.number_of_nodes()
    graph_validate_len /= len(graphs_validate)
    print('graph_validate_len', graph_validate_len)
    graph_test_len = 0
    for graph in graphs_test:
        graph_test_len += graph.number_of_nodes()
    graph_test_len /= len(graphs_test)
    print('graph_test_len', graph_test_len)
    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
    min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
    # args.max_num_node = 2000
    # show graph statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))
    # save ground truth graphs
    # (to get the train and test sets, slice manually after loading)
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
    if 'nobfs' in args.note:
        print('nobfs')
        args.max_prev_node = args.max_num_node - 1
        dataset = Graph_sequence_sampler_pytorch_nobfs_for_completion(graphs_train,
                                                                      max_num_node=args.max_num_node)
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
                                                 max_num_node=args.max_num_node)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset) for i in range(len(dataset))],
        num_samples=args.batch_size * args.batch_ratio, replacement=True)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                                 num_workers=args.num_workers,
                                                 sampler=sample_strategy)
    rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                    hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                    has_output=False).cuda()
    output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                       y_size=args.max_prev_node).cuda()
    rnn_for_initialization = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                                       hidden_size=args.hidden_size_rnn, num_layers=args.num_layers,
                                       has_input=True, has_output=False).cuda()
    output_for_initialization = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                                          y_size=args.max_prev_node).cuda()
    # warm-start both model copies from a pretrained GraphRNN checkpoint
    fname = args.model_save_path + args.initial_fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn_for_initialization.load_state_dict(torch.load(fname))
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.initial_fname + 'output_' + str(args.load_epoch) + '.dat'
    output_for_initialization.load_state_dict(torch.load(fname))
    output.load_state_dict(torch.load(fname))
    print('model loaded!, lr: {}'.format(args.lr))
    train_for_completion(args, dataset_loader, rnn, output, rnn_for_initialization,
                         output_for_initialization)