import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.decomposition import PCA
import logging
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from time import gmtime, strftime
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from random import shuffle
import pickle
from tensorboard_logger import configure, log_value
import scipy.misc
import time as tm

from utils import *
from model import *
from data import *
from args import Args
import create_graphs


def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort input by decreasing sequence length (required by pack_padded_sequence)
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        # use ground truth as input (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        y_pred, z_mu, z_lsgms = output(h)
        y_pred = torch.sigmoid(y_pred)
        # clean the padded positions
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True)
        z_mu = pad_packed_sequence(z_mu, batch_first=True)[0]
        z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True)
        z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0]
        # use cross entropy loss plus the KL term
        loss_bce = binary_cross_entropy_weight(y_pred, y)
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= y.size(0) * y.size(1) * sum(y_len)  # normalize
        loss = loss_bce + loss_kl
        loss.backward()

        # update both the output module and the lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        z_mu_mean = torch.mean(z_mu.data)
        z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
        z_mu_min = torch.min(z_mu.data)
        z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data)
        z_mu_max = torch.max(z_mu.data)
        z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print the first batch's statistics
            print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss_bce.item(), loss_kl.item(),
                args.graph_type, args.num_layers, args.hidden_size_rnn))
            print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max,
                  'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

        # logging
        log_value('bce_loss_' + args.fname, loss_bce.item(), epoch * args.batch_ratio + batch_idx)
        log_value('kl_loss_' + args.fname, loss_kl.item(), epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
        log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

        loss_sum += loss.item()
    return loss_sum / (batch_idx + 1)
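
# The loss_kl term above is the closed-form KL divergence between the encoder's
# diagonal Gaussian q(z|x) = N(z_mu, exp(z_lsgms)) and the standard normal
# prior: KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). A minimal,
# self-contained sketch of that formula (illustrative only; the training loop
# does not call this helper):
def _kl_to_standard_normal(z_mu, z_lsgms):
    # z_mu: per-element means; z_lsgms: per-element log-variances
    return -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())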

def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)
        y_pred[:, i:i + 1, :] = torch.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path + args.fname_pred + str(epoch) + '.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list


def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step, _, _ = output(h)
            y_pred[:, i:i + 1, :] = torch.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(),
                                               current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort input by decreasing sequence length
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = torch.sigmoid(y_pred)
        # clean the padded positions
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, y)
        loss.backward()
        # update both the output module and the lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.item(), epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.item()
    return loss_sum / (batch_idx + 1)
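
# Every training loop in this file follows the same preprocessing pattern:
# pack_padded_sequence requires the batch to be sorted by decreasing sequence
# length, so each loop sorts by data['len'] before packing, then packs and
# unpacks the predictions again to zero out the padded positions. A minimal
# CPU-only sketch of the pattern (hypothetical helper, not used elsewhere in
# this file):
def _sort_and_pack(x_unsorted, y_len_unsorted):
    y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
    x = torch.index_select(x_unsorted, 0, sort_index)
    return pack_padded_sequence(x, y_len.numpy().tolist(), batch_first=True)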

def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = torch.sigmoid(y_pred_step)
        x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_data = y_pred.data
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    #                               fname_pred=args.figure_prediction_save_path + args.fname_pred + str(epoch) + '.jpg',
    #                               max_num_node=max_num_node)
    return G_pred_list


def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = torch.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(),
                                               current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list
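
# The generation loops treat each sigmoid output as an independent Bernoulli
# probability over candidate edges to earlier nodes; sample_sigmoid (imported
# from utils) draws the binary edge vector for the next node, while the
# *_supervised variants clamp part of the sample to ground truth. A rough
# sketch of the unsupervised case, assuming a single draw per step:
def _bernoulli_edge_sample(y_pred_step):
    # y_pred_step: raw scores of shape (batch, 1, max_prev_node) -> {0, 1} edges
    return torch.bernoulli(torch.sigmoid(y_pred_step))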

def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    for batch_idx, data in enumerate(data_loader):
        x = data['x'].float()
        y = data['y'].float()
        y_len = data['len']
        test_batch_size = x.size(0)
        rnn.hidden = rnn.init_hidden(test_batch_size)
        # generate graphs
        max_num_node = int(args.max_num_node)
        y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
        y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
        x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
        for i in range(max_num_node):
            print('finish node', i)
            h = rnn(x_step)
            y_pred_step = output(h)
            y_pred[:, i:i + 1, :] = torch.sigmoid(y_pred_step)
            x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(),
                                                      current=i, y_len=y_len, sample_time=sample_time)
            y_pred_long[:, i:i + 1, :] = x_step
            rnn.hidden = Variable(rnn.hidden.data).cuda()
        y_pred_data = y_pred.data
        y_pred_long_data = y_pred_long.data.long()

        # save graphs as pickle
        for i in range(test_batch_size):
            adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
            G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
            G_pred_list.append(G_pred)
    return G_pred_list


def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

        # sort input by decreasing sequence length
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = torch.sigmoid(y_pred)
        # clean the padded positions
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss, weighting each step by its number of valid entries
        loss = 0
        for j in range(y.size(1)):
            # print('y_pred', y_pred[0, j, :], 'y', y[0, j, :])
            end_idx = min(j + 1, y.size(2))
            loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.item(), epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.item()
    return loss_sum / (batch_idx + 1)


## too complicated, deprecated
# def test_mlp_partial_bfs_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
#     rnn.eval()
#     output.eval()
#     G_pred_list = []
#     for batch_idx, data in enumerate(data_loader):
#         x = data['x'].float()
#         y = data['y'].float()
#         y_len = data['len']
#         test_batch_size = x.size(0)
#         rnn.hidden = rnn.init_hidden(test_batch_size)
#         # generate graphs
#         max_num_node = int(args.max_num_node)
#         y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
#         y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
#         x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
#         for i in range(max_num_node):
#             # 1 back up hidden state
#             hidden_prev = Variable(rnn.hidden.data).cuda()
#             h = rnn(x_step)
#             y_pred_step = output(h)
#             y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
#             x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, sample_time=sample_time)
#             y_pred_long[:, i:i + 1, :] = x_step
#
#             # rnn.hidden = Variable(rnn.hidden.data).cuda()
#
#             # print('finish node', i)
#         y_pred_data = y_pred.data
#         y_pred_long_data = y_pred_long.data.long()
#
#         # save graphs as pickle
#         for i in range(test_batch_size):
#             adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
#             G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
#             G_pred_list.append(G_pred)
#     return G_pred_list
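
# train_rnn_epoch below trains the full GraphRNN: a graph-level RNN emits one
# hidden state per generated node, and an edge-level RNN conditioned on that
# state emits the node's adjacency entries one at a time. The packed adjacency
# rows become the edge-level RNN's "batch": pack_padded_sequence flattens y in
# the order b1_t1, b2_t1, b1_t2, b2_t2, ..., so reversing that flat order
# sorts the rows by decreasing edge-sequence length, exactly as packing
# requires. Worked example (hypothetical lengths): y_len = [3, 2] packs five
# rows as [b1_t1, b2_t1, b1_t2, b2_t2, b1_t3]; reversed, the rows sit at
# timesteps [3, 2, 2, 1, 1], which is the output_y_len that the bincount loop
# computes (with each entry capped at y.size(2), i.e. max_prev_node).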

def train_rnn_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0) * x_unsorted.size(1))

        # sort input by decreasing sequence length
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # build input/output for the edge-level rnn module
        # a clever use of pytorch's builtin: packed data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape so that the rows are sorted by decreasing length, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for the output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # lengths must not exceed y.size(2)
        # pack into variables
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # use ground truth as input (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vectors
        # reverse h to match the row order of output_y
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = torch.sigmoid(y_pred)
        # clean the padded positions
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss
        loss = binary_cross_entropy_weight(y_pred, output_y)
        loss.backward()
        # update both the output module and the lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))

        # logging
        log_value('loss_' + args.fname, loss.item(), epoch * args.batch_ratio + batch_idx)
        feature_dim = y.size(1) * y.size(2)
        loss_sum += loss.item() * feature_dim
    return loss_sum / (batch_idx + 1)
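
# The edge-level RNN is seeded with the graph-level hidden state: h fills the
# first layer of output.hidden and the remaining args.num_layers - 1 layers
# start at zero, which is what the hidden_null concatenation above implements.
# A minimal sketch of that handoff, assuming h has shape (num_rows, hidden_size):
def _seed_edge_rnn_hidden(h, num_layers):
    hidden_null = torch.zeros(num_layers - 1, h.size(0), h.size(1))
    return torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)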

def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        # output.hidden = h.permute(1, 0, 2)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
        output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
        for j in range(min(args.max_prev_node, i + 1)):
            output_y_pred_step = output(output_x_step)
            output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
            x_step[:, :, j:j + 1] = output_x_step
            output.hidden = Variable(output.hidden.data).cuda()
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_long_data = y_pred_long.data.long()

    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list
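
# Note the inner loop bound above: at node i the edge-level RNN only needs
# min(args.max_prev_node, i + 1) steps, since node i can connect to at most i
# earlier nodes and the BFS ordering bounds the lookback to max_prev_node.
# Hypothetical usage of the generator (e.g., from a driver script):
#     G_list = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16)
#     nx.draw(G_list[0])  # inspect one generated graph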

def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        y_len_unsorted = data['len']
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
        # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0) * x_unsorted.size(1))

        # sort input by decreasing sequence length
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)

        # build input/output for the edge-level rnn module
        # a clever use of pytorch's builtin: packed data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
        y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
        # reverse y_reshape so that the rows are sorted by decreasing length, then add a feature dimension
        idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        y_reshape = y_reshape.index_select(0, idx)
        y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

        output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
        output_y = y_reshape
        # batch size for the output module: sum(y_len)
        output_y_len = []
        output_y_len_bin = np.bincount(np.array(y_len))
        for i in range(len(output_y_len_bin) - 1, 0, -1):
            count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len are >= i
            output_y_len.extend([min(i, y.size(2))] * count_temp)  # lengths must not exceed y.size(2)
        # pack into variables
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output_x = Variable(output_x).cuda()
        output_y = Variable(output_y).cuda()
        # print(output_y_len)
        # print('len', len(output_y_len))
        # print('y', y.size())
        # print('output_y', output_y.size())

        # use ground truth as input (teacher forcing)
        h = rnn(x, pack=True, input_len=y_len)
        h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vectors
        # reverse h to match the row order of output_y
        idx = [i for i in range(h.size(0) - 1, -1, -1)]
        idx = Variable(torch.LongTensor(idx)).cuda()
        h = h.index_select(0, idx)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
        output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = torch.sigmoid(y_pred)
        # clean the padded positions
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]
        # use cross entropy loss (forward pass only; no optimizer update here)
        loss = binary_cross_entropy_weight(y_pred, output_y)

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only print the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))
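
# train() below is the shared training driver: it dispatches on args.note
# ('GraphRNN_VAE*', 'GraphRNN_MLP*', 'GraphRNN_RNN*') to pick the per-epoch
# trainer, and can resume from checkpoints saved as
# <model_save_path><fname>lstm_<epoch>.dat and <model_save_path><fname>output_<epoch>.dat.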
        # logging
        log_value('loss_' + args.fname, loss.item(), epoch * args.batch_ratio + batch_idx)
        # print(y_pred.size())
        feature_dim = y_pred.size(0) * y_pred.size(1)
        loss_sum += loss.item() * feature_dim / y.size(0)
    return loss_sum / (batch_idx + 1)


########### train function for all GraphRNN variants (VAE / MLP / RNN)
def train(args, dataset_train, rnn, output):
    # check whether to load an existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded! lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizers
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)
    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred)