
train.py

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.decomposition import PCA
import logging
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from time import gmtime, strftime
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from random import shuffle
import pickle
from tensorboard_logger import configure, log_value
import scipy.misc
import time as tm

from data import *
# The helpers used below (binary_cross_entropy_weight, sample_sigmoid,
# sample_sigmoid_supervised, sample_sigmoid_supervised_simple, get_graph,
# save_graph_list) are assumed to come from the project's model/utils modules:
from model import *
from utils import *


def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
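    """Train the GraphRNN-VAE variant for one epoch with teacher forcing.

    The graph-level RNN consumes ground-truth adjacency-vector sequences,
    the VAE output head predicts edge probabilities along with the latent
    statistics (z_mu, z_lsgms), and the loss is the repo's weighted binary
    cross-entropy plus a normalized KL term. Returns the mean per-batch loss.
    """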
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort input by descending sequence length, as pack_padded_sequence requires
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        y_pred, z_mu, z_lsgms = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean: zero out entries beyond each sequence's true length
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True)
        z_mu = pad_packed_sequence(z_mu, batch_first=True)[0]
        z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True)
        z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0]

        # use cross entropy loss
        loss_bce = binary_cross_entropy_weight(y_pred, y)
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= y.size(0) * y.size(1) * sum(y_len)  # normalize
        loss = loss_bce + loss_kl
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        z_mu_mean = torch.mean(z_mu.data)
        z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
        z_mu_min = torch.min(z_mu.data)
        z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data)
        z_mu_max = torch.max(z_mu.data)
        z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers,
                args.hidden_size_rnn))
            print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max,
                  'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

        # logging
        log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
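    """Sample test_batch_size graphs from the trained VAE model.

    Generation is autoregressive over nodes: each step feeds the previously
    sampled adjacency row back into the graph-level RNN and binarizes the
    sigmoid scores with the repo's sample_sigmoid helper. Returns a list of
    networkx graphs decoded from the sampled sequences.
    """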
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path + args.fname_pred + str(epoch) + '.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list


def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
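    """Graph completion with the VAE model: at each node, sampling is
    conditioned on the ground-truth row via the repo's
    sample_sigmoid_supervised helper, so the observed part of each test
    graph constrains the generated adjacency sequence. Returns a list of
    networkx graphs.
    """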
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step, _, _ = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                               sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
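    """Train the GraphRNN-MLP variant for one epoch with teacher forcing:
    the graph-level RNN encodes ground-truth sequences, an MLP head predicts
    each adjacency row, and the weighted binary cross-entropy is minimized.
    Returns the mean per-batch loss.
    """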
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, y)
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
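    """Sample test_batch_size graphs from the trained MLP-variant model,
    autoregressively feeding each sampled adjacency row back into the RNN.
    Returns a list of networkx graphs.
    """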
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path + args.fname_pred + str(epoch) + '.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list


def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
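    """Graph completion with the MLP variant, conditioning each step on the
    ground-truth row via the repo's sample_sigmoid_supervised helper.
    Returns a list of networkx graphs.
    """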
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                               sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
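    """Same as test_mlp_partial_epoch, but conditions sampling with the
    simpler sample_sigmoid_supervised_simple rule. Used by
    train_graph_completion below.
    """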
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                                      sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
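    """Forward passes only (no backward/optimizer step) for the MLP variant:
    accumulates a length-weighted binary cross-entropy per batch, so
    train_nll below can use it as a negative log-likelihood estimate on a
    fixed model.
    """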
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss
        loss = 0
        for j in range(y.size(1)):
            # print('y_pred', y_pred[0, j, :], 'y', y[0, j, :])
            end_idx = min(j + 1, y.size(2))
            loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


## too complicated, deprecated
# def test_mlp_partial_bfs_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
#     rnn.eval()
#     output.eval()
#     G_pred_list = []
#     for batch_idx, data in enumerate(data_loader):
#         x = data['x'].float()
#         y = data['y'].float()
#         y_len = data['len']
#         test_batch_size = x.size(0)
#         rnn.hidden = rnn.init_hidden(test_batch_size)
#         # generate graphs
#         max_num_node = int(args.max_num_node)
#         y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
#         y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
#         x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
#         for i in range(max_num_node):
#             # 1 back up hidden state
#             hidden_prev = Variable(rnn.hidden.data).cuda()
#             h = rnn(x_step)
#             y_pred_step = output(h)
#             y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
#             x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i+1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
#             y_pred_long[:, i:i + 1, :] = x_step
#
#             rnn.hidden = Variable(rnn.hidden.data).cuda()
#
#             print('finish node', i)
#         y_pred_data = y_pred.data
#         y_pred_long_data = y_pred_long.data.long()
#
#         # save graphs as pickle
#         for i in range(test_batch_size):
#             adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
#             G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
#             G_pred_list.append(G_pred)
#     return G_pred_list


def train_rnn_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
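    """Train the full GraphRNN variant for one epoch with teacher forcing.

    A graph-level RNN encodes the node sequence; its hidden states seed an
    edge-level RNN (output) that predicts each adjacency row one entry at a
    time. Ground-truth rows are flattened with pack_padded_sequence and
    reversed so the edge-level RNN can itself be packed (see the note after
    this function). Returns the mean per-batch weighted loss.
    """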
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for output rnn module
        # a smart use of pytorch's builtin pack function: packed rows are ordered b1_t1, b2_t1, ..., b1_t2, b2_t2, ...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape so the edge-sequence lengths are sorted in descending order, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
        # reverse h
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)
        feature_dim = y.size(1) * y.size(2)
        loss_sum += loss.data * feature_dim
    return loss_sum / (batch_idx + 1)
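
# Note on the packing trick in train_rnn_epoch / train_rnn_forward_epoch
# (an illustration, not executed): with batch_first=True and y_len = [3, 2],
# pack_padded_sequence(y, y_len).data stacks the rows as
# b1_t1, b2_t1, b1_t2, b2_t2, b1_t3 -- grouped by time step. Rows from later
# time steps have more previous nodes, hence longer edge sequences, so
# reversing the packed rows puts the longest edge sequences first. That is
# exactly the descending order described by output_y_len (built from
# np.bincount): for this toy batch output_y_len = [3, 2, 2, 1, 1], which
# lets the edge-level RNN be packed over those rows again.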


def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
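    """Sample test_batch_size graphs from the full GraphRNN variant: for each
    new node, the graph-level hidden state seeds the edge-level RNN, which
    samples edges to previous nodes one at a time. Returns a list of
    networkx graphs.
    """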
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        # output.hidden = h.permute(1,0,2)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
        output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
        for j in range(min(args.max_prev_node, i + 1)):
            output_y_pred_step = output(output_x_step)
            output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
            x_step[:, :, j:j + 1] = output_x_step
            output.hidden = Variable(output.hidden.data).cuda()
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list


def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
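    """Forward passes only (no backward/optimizer step) for the full GraphRNN
    variant; mirrors train_rnn_epoch's loss computation so train_nll below
    can use it as a negative log-likelihood estimate on a fixed model.
    """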
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for output rnn module
        # a smart use of pytorch's builtin pack function: packed rows are ordered b1_t1, b2_t1, ..., b1_t2, b2_t2, ...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape so the edge-sequence lengths are sorted in descending order, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
        # reverse h
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)
        # print(y_pred.size())
        feature_dim = y_pred.size(0) * y_pred.size(1)
        loss_sum += loss.data * feature_dim / y.size(0)
    return loss_sum / (batch_idx + 1)


########### train function for all GraphRNN variants (VAE / MLP / RNN output)
def train(args, dataset_train, rnn, output):
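    """Main training loop: dispatches on args.note to the epoch trainer for
    the GraphRNN_VAE, GraphRNN_MLP, or GraphRNN_RNN variant, periodically
    samples test graphs and saves them with save_graph_list, and checkpoints
    both modules. Optionally resumes from a saved checkpoint when args.load
    is set. Per-epoch wall-clock times are saved to args.timing_save_path.
    """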
    # check if load existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded! lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizer
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)
    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start

        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    if 'GraphRNN_VAE' in args.note:
                        G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                     sample_time=sample_time)
                    elif 'GraphRNN_MLP' in args.note:
                        G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                     sample_time=sample_time)
                    elif 'GraphRNN_RNN' in args.note:
                        G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                save_graph_list(G_pred, fname)
                if 'GraphRNN_RNN' in args.note:
                    break
            print('test done, graphs saved')

        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)


########### for graph completion task
def train_graph_completion(args, dataset_test, rnn, output):
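    """Load a trained checkpoint and run graph completion on dataset_test:
    the MLP and VAE variants fill in the unobserved part of each test
    graph's adjacency sequence, and the completed graphs are saved with
    save_graph_list.
    """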
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))
    epoch = args.load_epoch
    print('model loaded! epoch: {}'.format(args.load_epoch))

    for sample_time in range(1, 4):
        if 'GraphRNN_MLP' in args.note:
            G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        if 'GraphRNN_VAE' in args.note:
            G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        # save graphs
        fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat'
        save_graph_list(G_pred, fname)
    print('graph completion done, graphs saved')


########### for NLL evaluation
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000):
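    """Load a trained checkpoint and estimate negative log-likelihood by
    running forward-only epochs on dataset_train and dataset_test for
    max_iter iterations, writing one train,test pair per line to a CSV
    under args.nll_save_path.
    """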
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))
    epoch = args.load_epoch
    print('model loaded! epoch: {}'.format(args.load_epoch))

    fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
    with open(fname_output, 'w+') as f:
        f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n')
        f.write('train,test\n')
        for iter in range(max_iter):
            if 'GraphRNN_MLP' in args.note:
                nll_train = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_test)
            if 'GraphRNN_RNN' in args.note:
                nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test)
            print('train', nll_train, 'test', nll_test)
            f.write(str(nll_train) + ',' + str(nll_test) + '\n')
    print('NLL evaluation done')
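
# A minimal usage sketch (hypothetical helper names -- in this repo the real
# entry point is the project's main script, which constructs args, the models,
# and the data loader before calling train()):
#
#     args = Args()                      # hyperparameter container
#     rnn, output = build_models(args)   # graph-level RNN + output module
#     loader = build_dataloader(args)    # yields batches {'x', 'y', 'len'}
#     train(args, loader, rnn, output)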