import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.decomposition import PCA
import logging
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from time import gmtime, strftime
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from random import shuffle
import pickle
from tensorboard_logger import configure, log_value
import scipy.misc
import time as tm

# helper functions such as binary_cross_entropy_weight, sample_sigmoid,
# sample_sigmoid_supervised(_simple), decode_adj, get_graph and
# save_graph_list are expected from this project's companion modules
# (assumed layout: data.py, model.py, utils.py)
from data import *
from model import *
from utils import *

# NOTE: this file targets the pre-0.4 PyTorch API (Variable, loss.data[0],
# F.sigmoid); newer PyTorch versions need loss.item(), torch.sigmoid, etc.


def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
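    """Train the GraphRNN-VAE variant for one epoch.

    Feeds ground-truth sequences (teacher forcing), optimizes the graph-level
    RNN and the VAE output module jointly, and returns the mean per-batch
    BCE + KL loss.
    """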
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort sequences by length (descending), as required by pack_padded_sequence
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        # feed the ground-truth sequence (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        y_pred, z_mu, z_lsgms = output(h)
        y_pred = F.sigmoid(y_pred)
        # re-pack and pad predictions so padded positions are zeroed out
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True)
        z_mu = pad_packed_sequence(z_mu, batch_first=True)[0]
        z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True)
        z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0]
        # reconstruction term: weighted binary cross entropy
        loss_bce = binary_cross_entropy_weight(y_pred, y)
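        # KL term in closed form: KL(q(z|x) || N(0, I))
        #   = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
        # with z_lsgms holding log(sigma^2)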
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= y.size(0) * y.size(1) * sum(y_len)  # normalize by the number of predicted entries
        loss = loss_bce + loss_kl
        loss.backward()
        # update the output module and the graph-level rnn
        optimizer_output.step()
        optimizer_rnn.step()
        # note: the schedulers step once per batch, so MultiStepLR milestones count batches
        scheduler_output.step()
        scheduler_rnn.step()

        z_mu_mean = torch.mean(z_mu.data)
        z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
        z_mu_min = torch.min(z_mu.data)
        z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data)
        z_mu_max = torch.max(z_mu.data)
        z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # log only the first batch's statistics
            print(
                'Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                    epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers,
                    args.hidden_size_rnn))
            print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean,
                  'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

        # tensorboard logging
        log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
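    """Sample test_batch_size graphs from the trained GraphRNN-VAE model.

    Generation is autoregressive: each step feeds the previous step's sampled
    adjacency vector back into the graph-level RNN. Returns a list of
    networkx graphs decoded from the sampled adjacency sequences.
    """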
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
    #                               max_num_node=max_num_node)

    return G_pred_list


def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
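    """Graph completion with the VAE variant: known adjacency entries are
    fixed to ground truth via sample_sigmoid_supervised while the remainder
    is sampled. Returns a list of networkx graphs.
    """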
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('generating node', i)
            h = rnn(x_step)
            y_pred_step, _, _ = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                               sample_time=sample_time)

            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
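    """Train the GraphRNN-MLP variant (MLP output head) for one epoch and
    return the mean per-batch weighted BCE loss.
    """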
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort sequences by length (descending), as required by pack_padded_sequence
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # re-pack and pad predictions so padded positions are zeroed out
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # weighted binary cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, y)
        loss.backward()
        # update the output module and the graph-level rnn
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # log only the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # tensorboard logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
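    """Sample test_batch_size graphs from the GraphRNN-MLP variant
    autoregressively and return them as a list of networkx graphs.
    """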
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list


def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
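    """Graph completion with the MLP variant: known adjacency entries are
    fixed to ground truth via sample_sigmoid_supervised while the remainder
    is sampled. Returns a list of networkx graphs.
    """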
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('generating node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                               sample_time=sample_time)

            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
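    """Like test_mlp_partial_epoch, but uses the simpler supervised sampler
    sample_sigmoid_supervised_simple from the companion module. Returns a
    list of networkx graphs.
    """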
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(
            torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('generating node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                                      sample_time=sample_time)

            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
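    """Teacher-forced forward pass of the MLP variant: computes a per-step
    weighted BCE loss without any backward pass or optimizer step. Used by
    train_nll below as an NLL estimate.
    """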
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort sequences by length (descending), as required by pack_padded_sequence
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # re-pack and pad predictions so padded positions are zeroed out
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]

        # cross entropy summed over time steps; step j only has
        # min(j + 1, max_prev_node) valid entries, so each step's mean loss
        # is re-weighted by that count
        loss = 0
        for j in range(y.size(1)):
            # print('y_pred',y_pred[0,j,:],'y',y[0,j,:])
            end_idx = min(j + 1, y.size(2))
            loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # log only the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # tensorboard logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)


## too complicated, deprecated
# def test_mlp_partial_bfs_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
#     rnn.eval()
#     output.eval()
#     G_pred_list = []
#     for batch_idx, data in enumerate(data_loader):
#         x = data['x'].float()
#         y = data['y'].float()
#         y_len = data['len']
#         test_batch_size = x.size(0)
#         rnn.hidden = rnn.init_hidden(test_batch_size)
#         # generate graphs
#         max_num_node = int(args.max_num_node)
#         y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
#         y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
#         x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
#         for i in range(max_num_node):
#             # 1 back up hidden state
#             hidden_prev = Variable(rnn.hidden.data).cuda()
#             h = rnn(x_step)
#             y_pred_step = output(h)
#             y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
#             x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
#             y_pred_long[:, i:i + 1, :] = x_step
#
#             rnn.hidden = Variable(rnn.hidden.data).cuda()
#
#             print('finish node', i)
#         y_pred_data = y_pred.data
#         y_pred_long_data = y_pred_long.data.long()
#
#         # save graphs as pickle
#         for i in range(test_batch_size):
#             adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
#             G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
#             G_pred_list.append(G_pred)
#     return G_pred_list


def train_rnn_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
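    """Train the full GraphRNN variant for one epoch: the graph-level RNN
    emits a hidden state per node, which seeds an edge-level RNN that
    predicts that node's adjacency entries one by one. Returns the mean
    per-batch loss scaled by the per-graph feature dimension.
    """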
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

        # sort sequences by length (descending), as required by pack_padded_sequence
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for the edge-level rnn module
        # pack_padded_sequence interleaves the batch: b1_t1, b2_t1, ..., b1_t2, b2_t2, ...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
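        # e.g. y_len=[3, 2] packs to [b1_t1, b2_t1, b1_t2, b2_t2, b1_t3]:
        # nodes from the same time step sit next to each other, so the packed
        # tensor can be treated as one big batch of edge sequences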
        # reverse y_reshape so the packed rows are sorted by edge-sequence
        # length in descending order, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for the edge-level module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many sequences have length >= i
            output_y_len.extend(
                [min(i, y.size(2))] * count_temp)  # edge-sequence length is capped at y.size(2) (= max_prev_node)
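        # worked example: y_len=[3, 3, 2] gives bincount [0, 0, 1, 2]; looping
        # i = 3, 2, 1 yields output_y_len = [3, 3, 2, 2, 2, 1, 1, 1] before
        # capping -- sum(y_len) = 8 edge sequences, lengths in descending order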
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len',len(output_y_len))
        # print('y',y.size())
        # print('output_y',output_y.size())

        # feed the ground-truth sequence (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vectors
        # reverse h to match the ordering of output_x / output_y
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        # use h as the first-layer hidden state of the edge-level rnn; zero the rest
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
                                  dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # re-pack and pad so padded positions are zeroed out
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # weighted binary cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)
        loss.backward()
        # update the output module and the graph-level rnn
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # log only the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # tensorboard logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
        feature_dim = y.size(1) * y.size(2)
        loss_sum += loss.data[0] * feature_dim
    return loss_sum / (batch_idx + 1)


def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
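    """Sample test_batch_size graphs from the full GraphRNN variant: at each
    node step, the graph-level RNN's output seeds the edge-level RNN, which
    samples the node's adjacency vector entry by entry. Returns a list of
    networkx graphs.
    """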
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        # output.hidden = h.permute(1,0,2)
        # seed the edge-level rnn's first-layer hidden state with h; zero the rest
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
                                  dim=0)  # num_layers, batch_size, hidden_size
        x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
        output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
        for j in range(min(args.max_prev_node, i + 1)):
            output_y_pred_step = output(output_x_step)
            output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
            x_step[:, :, j:j + 1] = output_x_step
            output.hidden = Variable(output.hidden.data).cuda()  # detach hidden state between steps
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()  # detach hidden state between steps
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    return G_pred_list


def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
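    """Teacher-forced forward pass of the full GraphRNN variant without any
    parameter update; used by train_nll below as an NLL estimate.
    """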
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

        # sort sequences by length (descending), as required by pack_padded_sequence
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # input, output for the edge-level rnn module
        # pack_padded_sequence interleaves the batch (see the worked example in train_rnn_epoch)
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape so the packed rows are sorted by edge-sequence
        # length in descending order, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for the edge-level module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many sequences have length >= i
            output_y_len.extend(
                [min(i, y.size(2))] * count_temp)  # edge-sequence length is capped at y.size(2) (= max_prev_node)
        # pack into variable
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len',len(output_y_len))
        # print('y',y.size())
        # print('output_y',output_y.size())

        # feed the ground-truth sequence (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vectors
        # reverse h to match the ordering of output_x / output_y
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        # use h as the first-layer hidden state of the edge-level rnn; zero the rest
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
                                  dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = F.sigmoid(y_pred)
        # re-pack and pad so padded positions are zeroed out
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # weighted binary cross entropy loss; no backward pass or optimizer step here
        loss = binary_cross_entropy_weight(y_pred, output_y)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # log only the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

        # tensorboard logging
        log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
        # print(y_pred.size())
        feature_dim = y_pred.size(0) * y_pred.size(1)
        loss_sum += loss.data[0] * feature_dim / y.size(0)
    return loss_sum / (batch_idx + 1)


########### main train loop (dispatches on args.note: VAE / MLP / RNN variants)
def train(args, dataset_train, rnn, output):
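    """Main training driver for all three variants (selected via args.note):
    optionally resumes from a checkpoint, trains for args.epochs epochs,
    periodically samples and saves test graphs, saves model checkpoints, and
    finally stores per-epoch wall-clock times.
    """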
    # check whether an existing model should be loaded
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001  # resume fine-tuning with a small learning rate
        epoch = args.load_epoch
        print('model loaded! lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizers and lr schedulers
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)

    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    if 'GraphRNN_VAE' in args.note:
                        G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                     sample_time=sample_time)
                    elif 'GraphRNN_MLP' in args.note:
                        G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                     sample_time=sample_time)
                    elif 'GraphRNN_RNN' in args.note:
                        G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                save_graph_list(G_pred, fname)
                if 'GraphRNN_RNN' in args.note:
                    break  # the rnn variant ignores sample_time, so one round is enough
            print('test done, graphs saved')

        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)


########### for graph completion task
def train_graph_completion(args, dataset_test, rnn, output):
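    """Load a trained checkpoint and run graph completion on dataset_test,
    saving the completed graphs for each sample_time setting.
    """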
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))

    epoch = args.load_epoch
    print('model loaded! epoch: {}'.format(args.load_epoch))

    for sample_time in range(1, 4):
        if 'GraphRNN_MLP' in args.note:
            G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        if 'GraphRNN_VAE' in args.note:
            G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
        # save graphs
        fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat'
        save_graph_list(G_pred, fname)
    print('graph completion done, graphs saved')


########### for NLL evaluation
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000):
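    """Load a trained checkpoint and repeatedly evaluate the teacher-forced
    loss (an NLL estimate) on the train and test sets, writing one CSV row
    per iteration.
    """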
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))

    epoch = args.load_epoch
    print('model loaded! epoch: {}'.format(args.load_epoch))
    fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
    with open(fname_output, 'w+') as f:
        f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n')
        f.write('train,test\n')
        for iter in range(max_iter):
            if 'GraphRNN_MLP' in args.note:
                nll_train = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_test)
            if 'GraphRNN_RNN' in args.note:
                nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test)
            print('train', nll_train, 'test', nll_test)
            f.write(str(nll_train) + ',' + str(nll_test) + '\n')

    print('NLL evaluation done')
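

# Typical usage (illustrative sketch only; in this project the real entry
# point lives in main.py, and the Args / GRU_plain / MLP_plain names below
# are assumed from the companion args.py / model.py and may differ):
#
#     args = Args()
#     rnn = GRU_plain(...).cuda()       # graph-level rnn
#     output = MLP_plain(...).cuda()    # output head for the MLP variant
#     configure(args.log_path)          # tensorboard_logger setup
#     train(args, dataset_train, rnn, output)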