# Training script for completing partially observed graphs with a GraphRNN-MLP model:
# a pre-trained GraphRNN encodes the observed part, and the completion model is trained
# to generate the missing node rows.
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import sys
sys.path.append('../')
from train import *  # also provides helpers such as get_graph, save_graph_list, create_graphs, GRU_plain, MLP_plain, ...
from args import Args
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.nn.init import *
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tensorboard_logger import configure, log_value
from baselines.graphvae.graphvae_model import GraphVAE
import itertools

device = torch.device("cuda:0")


def plot_grad_flow(named_parameters):
    # plot the average absolute gradient of every non-bias parameter, layer by layer
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            print(p.grad)
            print("**********************************")
            print(p.grad.abs())
            print("**********************************")
            print(p.grad.abs().mean())
            print("**********************************")
            ave_grads.append(p.grad.abs().mean())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(xmin=0, xmax=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()


def graph_show(G, title):
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, font_size=8)
    fig = plt.gcf()
    fig.canvas.set_window_title(title)
    plt.show()


def decode_adj(adj_output):
    '''
    recover adj from adj_output
    note: here adj_output has shape (n-1) * m
    '''
    max_prev_node = adj_output.shape[1]
    adj = np.zeros((adj_output.shape[0], adj_output.shape[0]))
    for i in range(adj_output.shape[0]):
        input_start = max(0, i - max_prev_node + 1)
        input_end = i + 1
        output_start = max_prev_node + max(0, i - max_prev_node + 1) - (i + 1)
        output_end = max_prev_node
        adj[i, input_start:input_end] = adj_output[i, ::-1][output_start:output_end]  # reverse order
    adj_full = np.zeros((adj_output.shape[0] + 1, adj_output.shape[0] + 1))
    n = adj_full.shape[0]
    adj_full[1:n, 0:n - 1] = np.tril(adj, 0)
    adj_full = adj_full + adj_full.T
    return adj_full


def tensor_to_graph_converter(tensor, title):
    adj_matrix = decode_adj(tensor.numpy())
    G = nx.from_numpy_matrix(adj_matrix)
    graph_show(G, title)
    return


def get_indices(number_of_missing_nodes, row_size):
    # build (row, column) index lists that select, for each missing-node row, a window of
    # row_size columns; the window shifts right by one per row, i.e. it keeps the entries
    # encoding edges back to the observed nodes and skips edges to other missing nodes
    # print("$$$$$$$$")
    # print(row_size)
    row_index = []
    column_index = []
    start_index = 0
    for i in range(number_of_missing_nodes):
        for j in range(row_size):
            row_index.append(i)
            column_index.append(j + start_index)
        start_index += 1
    return row_index, column_index


def GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y, batch_size, number_of_missing_nodes):
    # permutation-invariant BCE: for every ordering of the missing-node rows, compute the
    # binary cross entropy on the selected entries and keep the minimum per sample
    row_list = list(range(number_of_missing_nodes))
    permutation_indices = list(itertools.permutations(row_list))
    loss_sum = 0
    for data_index in range(batch_size):
        row_size = len(true_adj_batch[data_index]) - number_of_missing_nodes
        row_index, column_index = get_indices(number_of_missing_nodes, row_size)
        # print("true_adj_batch: ")
        # print(true_adj_batch[data_index])
        # print("y: ")
        # print(y[data_index])
        # print("predicted_adj_batch: ")
        # print(predicted_adj_batch[data_index])
        # print("y_pred:")
        # print(y_pred[data_index]+0)
        # print("selected_y_pred:")
        # print(y_pred[data_index][row_index, column_index])
        # print("selected_y:")
        # print(y[data_index][row_index, column_index])
        loss_list = []
        for row_order in permutation_indices:
            permuted_y_pred = y_pred[data_index][list(row_order)]
            loss_list.append(F.binary_cross_entropy(permuted_y_pred[row_index, column_index],
                                                    y[data_index][row_index, column_index]))
        loss_sum += min(loss_list)
    loss_sum /= batch_size
    return loss_sum
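
# Minimal sanity-check sketch (illustrative only, never called during training) for
# decode_adj: the toy 2 x 2 input below is assumed to be the GraphRNN-style encoding
# of a 3-node path 0-1-2, and decoding it should give back that path's adjacency matrix.
def _decode_adj_example():
    adj_output = np.array([[1.0, 0.0],
                           [1.0, 0.0]])  # row i encodes the edges of node i+1 to its previous nodes
    adj_full = decode_adj(adj_output)
    print(adj_full)  # expected: adjacency matrix of the path 0-1-2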
def GraphRNN_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # loss_sum = 0
    for data_index in range(batch_size):
        loss_sum = F.binary_cross_entropy(predicted_adj_batch[data_index], true_adj_batch[data_index])
        break
    # loss_sum /= batch_size
    # loss_sum.requires_grad = True
    return loss_sum


def matrix_completion_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # graph_show(nx.from_numpy_matrix(true_adj[0]), "true_adj")
    # graph_show(nx.from_numpy_matrix(predicted_adj[0]), "predicted_adj")
    loss_sum = F.binary_cross_entropy(torch.from_numpy(predicted_adj_batch[0]), torch.from_numpy(true_adj_batch[0]))
    # loss_sum = 0
    # for data_index in range(batch_size):
    #     # print("########## data_len: ")
    #     # print(data_len[data_index].item())
    #     # print("^^^^ data_index: " + str(data_index))
    #     adj_matrix = true_adj_batch[data_index]
    #     predicted_adj_matrix = predicted_adj_batch[data_index]
    #     number_of_nodes = np.shape(adj_matrix)[1]
    #     # loss_sum += F.binary_cross_entropy(torch.from_numpy(predicted_adj_matrix),
    #     #                                    torch.from_numpy(adj_matrix))
    #     sub_matrix = adj_matrix[number_of_nodes - number_of_missing_nodes:][:,
    #                  :number_of_nodes - number_of_missing_nodes]
    #     predicted_sub_matrix = predicted_adj_matrix[number_of_nodes - number_of_missing_nodes:][:,
    #                            :number_of_nodes - number_of_missing_nodes]
    #     min_permutation_loss = float("inf")
    #     row_list = list(range(number_of_missing_nodes))
    #     permutation_indices = list(itertools.permutations(row_list))
    #     number_of_rows = np.shape(predicted_sub_matrix)[0]
    #     number_of_columns = np.shape(predicted_sub_matrix)[1]
    #     matrix_size = number_of_rows * number_of_columns
    #     # print(str(number_of_rows) + " ### " + str(number_of_columns) + " @@@ " + str(matrix_size))
    #
    #     for row_order in permutation_indices:
    #         permuted_sub_matrix = sub_matrix[list(row_order)]
    #         # print("***** permuted_sub_matrix ")
    #         # print(permuted_sub_matrix)
    #         # print("***** predicted_sub_matrix ")
    #         # print(predicted_sub_matrix)
    #         adj_recon_loss = F.binary_cross_entropy(torch.from_numpy(predicted_sub_matrix),
    #                                                 torch.from_numpy(permuted_sub_matrix))
    #         # print(adj_recon_loss)
    #         min_permutation_loss = np.minimum(min_permutation_loss, adj_recon_loss.item())
    #         # print(adj_recon_loss)
    #     loss_sum += min_permutation_loss / matrix_size
    #     # print("^^^^^ ")
    #     # print(loss_sum)
    # loss_sum = torch.autograd.Variable(torch.from_numpy(np.array(loss_sum)))
    # # print(loss_sum)
    # loss_sum.requires_grad = True
    # print(loss_sum)
    # print(loss_sum/batch_size)
    return loss_sum / batch_size


def graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, batch_size):
    # graph_show(nx.from_numpy_matrix(true_adj[0]), "true_adj")
    # graph_show(nx.from_numpy_matrix(predicted_adj[0]), "predicted_adj")
    loss_sum = 0
    for data_index in range(batch_size):
        features = torch.zeros(data_len[data_index])
        number_of_nodes = data_len[data_index].item()
        # print("########## data_len: ")
        # print(data_len[data_index].item())
        if data_len[data_index].item() > 14:
            continue
        graphvae = GraphVAE(1, 64, 256, data_len[data_index].item())
        # print("^^^^ data_index: " + str(data_index))
        current_true_adj = true_adj_batch[data_index]
        current_predicted_adj = predicted_adj_batch[data_index]
        current_features = features
        S = graphvae.edge_similarity_matrix(current_true_adj, current_predicted_adj,
                                            current_features, current_features,
                                            graphvae.deg_feature_similarity)
        init_corr = 1 / number_of_nodes
        init_assignment = torch.ones(number_of_nodes, number_of_nodes) * init_corr
        assignment = graphvae.mpm(init_assignment, S)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        permuted_adj = graphvae.permute_adj(torch.from_numpy(current_true_adj).float(), row_ind, col_ind)
        # print("***** permuted_adj ")
        # print(permuted_adj)
        # print("***** current_predicted_adj ")
        # print(torch.from_numpy(current_predicted_adj).float())
        adj_recon_loss = graphvae.adj_recon_loss(permuted_adj, torch.from_numpy(current_predicted_adj).float())
        # print(adj_recon_loss)
        loss_sum += adj_recon_loss.data
    loss_sum.requires_grad = True
    # print(loss_sum/batch_size)
    # print("^^^^^ ")
    # print(loss_sum)
    return loss_sum / batch_size
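
# Illustrative sketch (not called during training) of the index selection used by
# GraphRNN_loss2: with 2 missing nodes and row_size = 3 observed nodes, the selected
# column window shifts right by one per missing-node row, so only the entries encoding
# edges back to the observed part of the graph are compared.
def _get_indices_example():
    row_index, column_index = get_indices(2, 3)
    print(row_index)     # [0, 0, 0, 1, 1, 1]
    print(column_index)  # [0, 1, 2, 1, 2, 3]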
def make_batch_of_matrices(initial_y, initial_y_len, y_train, y_pred, batch_size):
    # rebuild full adjacency matrices (with self-loops added via the identity) by
    # concatenating the observed prefix rows with the true / predicted suffix rows
    true_adj_batch = []
    predicted_adj_batch = []
    for data_index in range(batch_size):
        a = initial_y[data_index][:initial_y_len[data_index] - 1].cpu()
        b = y_train[data_index]
        true_adj = torch.cat((a, b), dim=0)
        true_adj = decode_adj(true_adj.numpy())
        true_adj_tensor = torch.from_numpy(true_adj + np.identity(len(true_adj[0]))).float().to(device)
        true_adj_batch.append(true_adj_tensor)
        b = y_pred[data_index].data
        # print("bbbbbbbbbbbbbbbbbbbbbbb :")
        # print(b)
        # print("b.data :")
        # print(b.data)
        # print("**********************************************************************************")
        predicted_adj = torch.cat((a, b.cpu()), dim=0)
        predicted_adj = decode_adj(predicted_adj.numpy())
        predicted_adj_tensor = torch.from_numpy(predicted_adj + np.identity(len(predicted_adj[0]))).float().to(device)
        predicted_adj_tensor.requires_grad = True
        # predicted_adj_tensor = predicted_adj_tensor + 1
        # print("*********** predicted_adj_tensor *******")
        # print(predicted_adj_tensor+0)
        predicted_adj_batch.append(predicted_adj_tensor + 0)
    return true_adj_batch, predicted_adj_batch
def train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization, output_for_initialization,
                                   data_loader, optimizer_rnn, optimizer_output, scheduler_rnn, scheduler_output,
                                   number_of_missing_nodes):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        data_len = y_len_unsorted = data['len']
        # data_index = 0
        # b = y_unsorted[data_index][:y_len_unsorted[data_index] - 1].cpu()
        # tensor_to_graph_converter(b, "Complete Graph")
        # ****
        x_unsorted, y_unsorted, y_len_unsorted, hidden, initial_y, initial_y_len = \
            data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted,
                                          args, rnn_for_initialization, output_for_initialization,
                                          number_of_missing_nodes, False, 0.96)
        # ____________________________________________________________________________________________
        # x_unsorted1, y_unsorted1, y_len_unsorted1, hidden, initial_y1, initial_y_len1 = \
        #     data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted,
        #                                   args, rnn_for_initialization, output_for_initialization,
        #                                   number_of_missing_nodes, False, 0.96)
        # ____________________________________________________________________________________________
        # a = initial_y[data_index][:initial_y_len[data_index]-1].cpu()
        # tensor_to_graph_converter(a, "Incomplete Graph")
        # b = y_unsorted[data_index][:-1]
        # recons = torch.cat((a, b), dim=0)
        # tensor_to_graph_converter(recons, "Merged Complete Graph")
        # ****
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # initialize lstm hidden state according to batch size
        # rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))  # ****
        rnn.hidden = hidden  # ****
        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = F.sigmoid(y_pred)
        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use cross entropy loss
        # print("y_unsorted: ")
        # print(y_unsorted[0])
        # print("y: ")
        # print(y[0].cpu())
        true_adj_batch, predicted_adj_batch = make_batch_of_matrices(initial_y, initial_y_len, y.cpu(),
                                                                     y_pred, len(y_pred))
        # graph_show(nx.from_numpy_matrix(true_adj_batch[data_index]), "Complete2")
        # import time
        # start = time.time()
        # print("@@@@@@@@@@@@@@@@@@@@@@@@@")
        # print(args.loss)
        # note: only the "GraphRNN_loss" branch currently assigns `loss`; the other branches are stubs
        if args.loss == "graph_matching_loss":
            print(graph_matching_loss)
            # loss = graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, len(data_len))
        elif args.loss == "matrix_completion_loss":
            print(matrix_completion_loss)
            # loss = matrix_completion_loss(true_adj_batch, predicted_adj_batch, len(data_len),
            #                               args.number_of_missing_nodes)
        elif args.loss == "GraphRNN_loss":
            # print("GraphRNN_loss")
            # print(y)
            # print(y_pred)
            # loss = F.binary_cross_entropy(torch.from_numpy(predicted_adj_batch[0]), torch.from_numpy(true_adj_batch[0]))
            # loss.requires_grad = True
            # print("true_adj_batch: ")
            # print(true_adj_batch[0])
            # print("y: ")
            # print(y[0])
            # print("predicted_adj_batch: ")
            # print(predicted_adj_batch[0])
            # print("y_pred:")
            # print(y_pred[0]+0)
            # print("**************************************************************")
            # loss = GraphRNN_loss(true_adj_batch, predicted_adj_batch, len(data_len), args.number_of_missing_nodes)
            # loss = binary_cross_entropy_weight(predicted_adj_batch[0], true_adj_batch[0])
            # print("**** GraphRNN_loss: ")
            # print(loss)
            # row_index = [0, 2]
            # col_index = [1, 3]
            # print("**************************")
            # print(y_pred[0][row_index, col_index])
            # print(np.shape(y_pred[0])[1])
            loss = GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y, len(y_pred),
                                  args.number_of_missing_nodes)
            # loss = binary_cross_entropy_weight(y_pred[0][row_index, col_index], y[0][row_index, col_index])
            # print("**** main_loss: ")
            # print(loss)
        # print("*** epoch : " + str(epoch) + "*** loss: " + str(loss.data))
        # break
        # end = time.time()
        # print(end - start)
        # print("y_pred")
        # print(y_pred[0])
        # print("y")
        # print(y[0])
        # print("***************************************************")
        # loss = binary_cross_entropy_weight(y_pred, y)
        loss.backward()
        # update deterministic and lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()
        # for p in rnn.parameters():
        #     print(p)
        #     print(p.grad)
        #     if p.grad is not None:
        #         print("________________________________________________________")
        #         print(p.grad.data)
        # plot_grad_flow(rnn.named_parameters())
        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))
        log_value(args.loss + args.fname + args.training_mode + "_ " + "number_of_missing_nodes_" +
                  str(number_of_missing_nodes), loss.data, epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.data
    return loss_sum / (batch_idx + 1)
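
# Minimal sketch (illustrative only, not called anywhere) of the sort -> pack -> pad
# pattern used in train_mlp_epoch_for_completion: sequences are sorted by length so
# pack_padded_sequence can batch variable-length graph sequences, and
# pad_packed_sequence restores a dense tensor for the element-wise loss.
def _pack_pad_example():
    lengths = [3, 2]                  # already sorted in descending order
    x = torch.zeros(2, 3, 4)          # (batch, max_len, max_prev_node); second sample is padded
    packed = pack_padded_sequence(x, lengths, batch_first=True)
    unpacked, unpacked_len = pad_packed_sequence(packed, batch_first=True)
    print(unpacked.shape, unpacked_len)  # torch.Size([2, 3, 4]) tensor([3, 2])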
def train_for_completion(args, dataset_train, rnn, output, rnn_for_initialization, output_for_initialization):
    # check if load existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))
        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded!, lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizer
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)
    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization,
                                           output_for_initialization, dataset_train,
                                           optimizer_rnn, optimizer_output,
                                           scheduler_rnn, scheduler_output,
                                           args.number_of_missing_nodes)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)


def hidden_initializer(x, rnn, output):
    # run the observed (incomplete) graph sequence through the pre-trained RNN one step
    # at a time so its final hidden state can seed the completion model
    test_batch_size = len(x)
    max_num_node = len(x[0])
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    for i in range(max_num_node):
        h = rnn((x[:, i:i + 1, :]).cuda())
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    return rnn.hidden
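
# Illustrative note (the values are only an example): for a 5-node graph with
# number_of_missing_nodes = 2, data_initializer_for_training below keeps the first
# 3 node-rows as the observed prefix (initial_x / initial_y, fed to the pre-trained
# RNN via hidden_initializer) and uses the remaining rows as supervision for the
# completion model (train_x / train_y), so initial_y_len = 3 and train_y_len = 2.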
def data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted, args, rnn, output,
                                  number_of_missing_nodes, ratio_mode, incompleteness_ratio):
    # print(number_of_missing_nodes)
    batch_size = len(x_unsorted)
    max_prev_node = len(x_unsorted[0][0])
    max_number_of_nodes = int(max(y_len_unsorted))
    if ratio_mode:
        max_incomplete_num_node = int(max_number_of_nodes * incompleteness_ratio)
    else:
        max_incomplete_num_node = max_number_of_nodes - number_of_missing_nodes
    initial_x = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y_len = Variable(torch.zeros(len(y_len_unsorted))).type(torch.LongTensor)
    for i in range(len(y_len_unsorted)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        initial_y_len[i] = incomplete_graph_size
        initial_x[i] = torch.cat((x_unsorted[i, :incomplete_graph_size],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size, max_prev_node])), dim=0)
        initial_y[i] = torch.cat((y_unsorted[i, :incomplete_graph_size - 1],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size + 1, max_prev_node])), dim=0)

    train_y_len = y_len_unsorted - initial_y_len
    max_train_num_node = int(max(train_y_len)) + 1
    temp_size_reduction = 1
    train_x = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))
    train_y = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))
    for i in range(len(train_y_len)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        train_graph_size = int(train_y_len.data[i])
        train_x[i, 1 - temp_size_reduction:, :] = torch.cat(
            (x_unsorted[i, incomplete_graph_size:incomplete_graph_size + train_graph_size],
             torch.zeros([max_train_num_node - train_graph_size - 1, max_prev_node])), dim=0)
        train_y[i, :, :] = torch.cat(
            (x_unsorted[i, incomplete_graph_size + temp_size_reduction:incomplete_graph_size + train_graph_size],
             torch.zeros([max_train_num_node - train_graph_size, max_prev_node])), dim=0)
    hidden = hidden_initializer(initial_x, rnn, output)

    # just for testing
    # print(y_len_unsorted)
    # print(initial_y_len)
    # print(train_y_len)
    # print("Complete Data : 00000000000000000000000000000000000000000000000000")
    # print('x:')
    # print(x_unsorted[0].int())
    # print('y:')
    # print(y_unsorted[0].int())
    # print("Data for Initialization : 11111111111111111111111111111111111111111111111111")
    # print('x:')
    # print(initial_x[0].int())
    # print('y:')
    # print(initial_y[0].int())
    # print('Data for Training : 22222222222222222222222222222222222222222222222222')
    # print('x:')
    # print(train_x[0].int())
    # print('y:')
    # print(train_y[0].int())
    # print('33333333333333333333333333333333333333333333333333')
    # print(hidden.size())
    return train_x, train_y, train_y_len, hidden, initial_y, initial_y_len


def save_graph(graph, name):
    adj_pred = decode_adj(graph.cpu().numpy())
    G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
    G = np.asarray(nx.to_numpy_matrix(G_pred))
    np.savetxt(name + '.txt', G, fmt='%d')


def data_to_graph_converter(data):
    G_list = []
    for i in range(len(data)):
        x = data[i].numpy()
        x = x.astype(int)
        adj_pred = decode_adj(x)
        G = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_list.append(G)
    return G_list
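
# Minimal usage sketch (not called anywhere): data_to_graph_converter expects a list of
# tensors in the same (n-1) x max_prev_node encoding consumed by decode_adj; the toy
# encoding below is assumed to describe a 3-node path, as in _decode_adj_example above.
def _data_to_graph_converter_example():
    toy = torch.tensor([[1.0, 0.0],
                        [1.0, 0.0]])
    graphs = data_to_graph_converter([toy])
    print(graphs[0].number_of_nodes())  # expected 3 nodes for this toy encoding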
if __name__ == '__main__':
    # All necessary arguments are defined in args.py
    args = Args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)
    # check if necessary directories exist
    if not os.path.isdir(args.model_save_path):
        os.makedirs(args.model_save_path)
    if not os.path.isdir(args.graph_save_path):
        os.makedirs(args.graph_save_path)
    if not os.path.isdir(args.figure_save_path):
        os.makedirs(args.figure_save_path)
    if not os.path.isdir(args.timing_save_path):
        os.makedirs(args.timing_save_path)
    if not os.path.isdir(args.figure_prediction_save_path):
        os.makedirs(args.figure_prediction_save_path)
    if not os.path.isdir(args.nll_save_path):
        os.makedirs(args.nll_save_path)

    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
    if args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)

    graphs = create_graphs.create(args)
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    graphs_validate = graphs[0:int(0.2 * graphs_len)]

    graph_validate_len = 0
    for graph in graphs_validate:
        graph_validate_len += graph.number_of_nodes()
    graph_validate_len /= len(graphs_validate)
    print('graph_validate_len', graph_validate_len)

    graph_test_len = 0
    for graph in graphs_test:
        graph_test_len += graph.number_of_nodes()
    graph_test_len /= len(graphs_test)
    print('graph_test_len', graph_test_len)

    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
    min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
    # args.max_num_node = 2000

    # show graphs statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))

    # save ground truth graphs
    ## To get train and test set, after loading you need to manually slice
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')

    if 'nobfs' in args.note:
        print('nobfs')
        args.max_prev_node = args.max_num_node - 1
        dataset = Graph_sequence_sampler_pytorch_nobfs_for_completion(graphs_train, max_num_node=args.max_num_node)
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
                                                 max_num_node=args.max_num_node)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset) for i in range(len(dataset))],
        num_samples=args.batch_size * args.batch_ratio, replacement=True)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                                 num_workers=args.num_workers, sampler=sample_strategy)

    rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                    hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                    has_output=False).cuda()
    output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                       y_size=args.max_prev_node).cuda()
    rnn_for_initialization = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                                       hidden_size=args.hidden_size_rnn, num_layers=args.num_layers,
                                       has_input=True, has_output=False).cuda()
    output_for_initialization = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                                          y_size=args.max_prev_node).cuda()

    # load the pre-trained GraphRNN weights into both the trainable and the frozen initialization models
    fname = args.model_save_path + args.initial_fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn_for_initialization.load_state_dict(torch.load(fname))
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.initial_fname + 'output_' + str(args.load_epoch) + '.dat'
    output_for_initialization.load_state_dict(torch.load(fname))
    output.load_state_dict(torch.load(fname))

    # weight initialization
    # rnn_for_initialization_dictionary = dict(rnn_for_initialization.named_parameters())
    # for name, param in rnn.named_parameters():
    #     # print(name)
    #     print(1)
    #     print(param)
    #     param.data = rnn_for_initialization_dictionary[name]
    #     print(2)
    #     print(param)
    # output_for_initialization_dictionary = dict(output_for_initialization.named_parameters())
    # for name, param in output.named_parameters():
    #     # print(name)
    #     print(1)
    #     print(param)
    #     param.data = output_for_initialization_dictionary[name]
    #     print(1)
    #     print(param)

    # args.lr = 0.00001
    # epoch = args.load_epoch
    print('model loaded!, lr: {}'.format(args.lr))

    # train_for_completion(args, dataset_loader, rnn, output, rnn,
    #                      output)
    train_for_completion(args, dataset_loader, rnn, output, rnn_for_initialization, output_for_initialization)

    # for batch_idx, data in enumerate(dataset_loader):
    #     if batch_idx == 0:
    #         rnn.zero_grad()
    #         output.zero_grad()
    #         x_unsorted = data['x'].float()
    #         y_unsorted = data['y'].float()
    #         y_len_unsorted = data['len']
    #         train_x, train_y, train_y_len, hidden = \
    #             data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted,
    #                                           args, rnn, output)