
train.py

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.decomposition import PCA
import logging
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from time import gmtime, strftime
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from random import shuffle
import pickle
from tensorboard_logger import configure, log_value
import scipy.misc
import time as tm

from utils import *
from model import *
from data import *
from args import Args
import create_graphs
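
# train.py collects the training, evaluation, and generation loops for the three
# GraphRNN variants selected via args.note: 'GraphRNN_VAE' (VAE output head),
# 'GraphRNN_MLP' (MLP output head, GraphRNN-S) and 'GraphRNN_RNN' (edge-level
# RNN head, full GraphRNN). Helpers such as binary_cross_entropy_weight,
# sample_sigmoid, decode_adj, get_graph and save_graph_list come in through the
# star imports above.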

def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
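        # pack_padded_sequence requires the batch ordered by decreasing sequence
        # length, so sort by y_len first; e.g. y_len [2, 4, 3] becomes [4, 3, 2]
        # and the x/y rows are permuted with the same sort_index.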
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        y_pred, z_mu, z_lsgms = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True)
        z_mu = pad_packed_sequence(z_mu, batch_first=True)[0]
        z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True)
        z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0]
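        # VAE objective: reconstruction BCE plus the closed-form KL divergence
        # between q(z|x) = N(z_mu, sigma^2) and the N(0, I) prior,
        # KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2), where z_lsgms
        # holds log sigma^2.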
        # use cross entropy loss
        loss_bce = binary_cross_entropy_weight(y_pred, y)
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= y.size(0) * y.size(1) * sum(y_len)  # normalize
        loss = loss_bce + loss_kl
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        z_mu_mean = torch.mean(z_mu.data)
        z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
        z_mu_min = torch.min(z_mu.data)
        z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data)
        z_mu_max = torch.max(z_mu.data)
        z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
            print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max,
                  'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

        # logging
        log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)
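
# Generation is autoregressive: each step of the node-level RNN predicts the
# adjacency row of one new node, a discrete row is sampled from the sigmoid
# scores and fed back as the next input, and rnn.hidden is rewrapped from its
# .data so no autograd graph accumulates across steps.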

def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()
    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list
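
# The *_partial_epoch variants below serve graph completion: instead of sampling
# every entry freely, sample_sigmoid_supervised (from utils) keeps the
# ground-truth edges for the part of each graph that is treated as observed
# (tracked via current and y_len) and only samples the rest.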

def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step, _, _ = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()
        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list
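
# The MLP variant ('GraphRNN_MLP', GraphRNN-S in the paper): the output module
# maps the node-level RNN state to the whole adjacency row at once, so training
# reduces to a single weighted BCE between packed predictions and targets.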

def train_mlp_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, y)
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data
    return loss_sum / (batch_idx + 1)

def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()
    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    # # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list

def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()
        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list

def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()
        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list

def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
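        # Sum BCE column by column: row j has at most j+1 valid edge slots
        # (capped at max_prev_node = y.size(2)). Assuming
        # binary_cross_entropy_weight averages over its entries, the end_idx
        # factor turns that mean back into a per-row sum, so the total behaves
        # like a negative log-likelihood over edges.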
        # use cross entropy loss
        loss = 0
        for j in range(y.size(1)):
            # print('y_pred', y_pred[0, j, :], 'y', y[0, j, :])
            end_idx = min(j + 1, y.size(2))
            loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)

## too complicated, deprecated
# def test_mlp_partial_bfs_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
#     rnn.eval()
#     output.eval()
#     G_pred_list = []
#     for batch_idx, data in enumerate(data_loader):
#         x = data['x'].float()
#         y = data['y'].float()
#         y_len = data['len']
#         test_batch_size = x.size(0)
#         rnn.hidden = rnn.init_hidden(test_batch_size)
#         # generate graphs
#         max_num_node = int(args.max_num_node)
#         y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
#         y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
#         x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
#         for i in range(max_num_node):
#             # 1 back up hidden state
#             hidden_prev = Variable(rnn.hidden.data).cuda()
#             h = rnn(x_step)
#             y_pred_step = output(h)
#             y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
#             x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
#             y_pred_long[:, i:i + 1, :] = x_step
#
#             rnn.hidden = Variable(rnn.hidden.data).cuda()
#
#             print('finish node', i)
#         y_pred_data = y_pred.data
#         y_pred_long_data = y_pred_long.data.long()
#
#         # save graphs as pickle
#         for i in range(test_batch_size):
#             adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
#             G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
#             G_pred_list.append(G_pred)
#     return G_pred_list

def train_rnn_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for output rnn module
        # a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
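        # Worked example, assuming y_len = [3, 2]: the packed data is stacked
        # time-major as [b1_t1, b2_t1, b1_t2, b2_t2, b1_t3]; reversing it gives
        # [b1_t3, b2_t2, b1_t2, b2_t1, b1_t1], whose per-row edge-sequence
        # lengths [3, 2, 2, 1, 1] (capped at max_prev_node) are exactly the
        # descending output_y_len that the bincount loop below reconstructs,
        # as pack_padded_sequence requires.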
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape, so that their lengths are sorted, add dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
        # reverse h
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)
        feature_dim = y.size(1) * y.size(2)
        loss_sum += loss.data * feature_dim
    return loss_sum / (batch_idx + 1)
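
# Full-GraphRNN generation is hierarchical: each outer step advances the
# node-level RNN by one node and copies its output into layer 0 of the
# edge-level RNN's hidden state (remaining layers zeroed); the inner loop then
# samples the adjacency row one edge at a time, feeding each sampled bit back in.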

def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        # output.hidden = h.permute(1, 0, 2)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
                                  dim=0)  # num_layers, batch_size, hidden_size
        x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
        output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
        for j in range(min(args.max_prev_node, i + 1)):
            output_y_pred_step = output(output_x_step)
            output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
            x_step[:, :, j:j + 1] = output_x_step
            output.hidden = Variable(output.hidden.data).cuda()
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_long_data = y_pred_long.data.long()
    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list

def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for output rnn module
        # a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape, so that their lengths are sorted, add dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # if using ground truth to train
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
        # reverse h
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
        # print(y_pred.size())
        feature_dim = y_pred.size(0) * y_pred.size(1)
        loss_sum += loss.data[0] * feature_dim / y.size(0)
    return loss_sum / (batch_idx + 1)
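
# Driver: optionally resumes from a checkpoint, builds one Adam optimizer and a
# MultiStepLR scheduler per module, dispatches the epoch function by args.note,
# periodically samples args.test_total_size graphs to .dat files (three
# sample_time settings; the RNN variant only needs one), and saves checkpoints
# plus per-epoch wall-clock times.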

########### main train function (handles all three GraphRNN variants)
def train(args, dataset_train, rnn, output):
    # check if load existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded!, lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizer
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)
    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    if 'GraphRNN_VAE' in args.note:
                        G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, sample_time=sample_time)
                    elif 'GraphRNN_MLP' in args.note:
                        G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, sample_time=sample_time)
                    elif 'GraphRNN_RNN' in args.note:
                        G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                save_graph_list(G_pred, fname)
                if 'GraphRNN_RNN' in args.note:
                    break
            print('test done, graphs saved')
        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)

########### for graph completion task
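# Graph completion: reload a trained checkpoint and run the partial (supervised)
# samplers on the test set, which hold the observed part of each input graph
# fixed and generate the remainder.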
def train_graph_completion(args, dataset_test, rnn, output):
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))
    epoch = args.load_epoch
    print('model loaded!, epoch: {}'.format(args.load_epoch))

    for sample_time in range(1, 4):
        if 'GraphRNN_MLP' in args.note:
            G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        if 'GraphRNN_VAE' in args.note:
            G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        # save graphs
        fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat'
        save_graph_list(G_pred, fname)
    print('graph completion done, graphs saved')

########### for NLL evaluation
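# NLL evaluation: reload a checkpoint and repeatedly run the forward-only epoch
# functions (loss computation without optimizer steps) on the train and test
# sets, writing the losses, used as negative log-likelihood estimates, to a CSV.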
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000):
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))
    epoch = args.load_epoch
    print('model loaded!, epoch: {}'.format(args.load_epoch))

    fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
    with open(fname_output, 'w+') as f:
        f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n')
        f.write('train,test\n')
        for iter in range(max_iter):
            if 'GraphRNN_MLP' in args.note:
                nll_train = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_test)
            if 'GraphRNN_RNN' in args.note:
                nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test)
            print('train', nll_train, 'test', nll_test)
            f.write(str(nll_train) + ',' + str(nll_test) + '\n')
    print('NLL evaluation done')