# NOTE: some of these names (np, os, shuffle, create_graphs, ...) may also be
# re-exported by `from train import *`; they are imported explicitly here so
# the file stands on its own.
import os
import shutil
import random
import itertools
from random import shuffle
from time import gmtime, strftime
import time as tm

import numpy as np
import scipy.optimize
import torch
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.nn.init import *
import matplotlib.pyplot as plt
import networkx as nx
from tensorboard_logger import configure, log_value

import sys
sys.path.append('../')
from train import *
from args import Args
import create_graphs
from baselines.graphvae.graphvae_model import GraphVAE

device = torch.device("cuda:0")


def plot_grad_flow(named_parameters):
    # Plot the mean absolute gradient of every trainable non-bias parameter,
    # layer by layer, to diagnose vanishing or exploding gradients.
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()
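

# Usage sketch (illustrative, not wired into the training loop): gradients
# must exist, so call plot_grad_flow only after a backward pass has
# populated p.grad.
def _demo_plot_grad_flow(model, loss):
    loss.backward()
    plot_grad_flow(model.named_parameters())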


def graph_show(G, title):
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, font_size=8)
    fig = plt.gcf()
    # set_window_title lives on the canvas manager in recent Matplotlib.
    fig.canvas.manager.set_window_title(title)
    plt.show()

def decode_adj(adj_output):
    '''
    Recover the full adjacency matrix from the GraphRNN sequence encoding.
    Note: adj_output has shape (n-1) x max_prev_node; row i stores the
    connections of node i+1 to its max_prev_node predecessors.
    '''
    max_prev_node = adj_output.shape[1]
    adj = np.zeros((adj_output.shape[0], adj_output.shape[0]))
    for i in range(adj_output.shape[0]):
        input_start = max(0, i - max_prev_node + 1)
        input_end = i + 1
        output_start = max_prev_node + max(0, i - max_prev_node + 1) - (i + 1)
        output_end = max_prev_node
        adj[i, input_start:input_end] = adj_output[i, ::-1][output_start:output_end]  # reverse order
    adj_full = np.zeros((adj_output.shape[0] + 1, adj_output.shape[0] + 1))
    n = adj_full.shape[0]
    adj_full[1:n, 0:n-1] = np.tril(adj, 0)
    adj_full = adj_full + adj_full.T

    return adj_full
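

# Worked example (toy values, illustrative only): a two-row sequence with
# max_prev_node = 2 decodes to the adjacency matrix of a 3-node triangle.
def _demo_decode_adj():
    seq = np.array([[1., 0.],   # node 1 connects to node 0
                    [1., 1.]])  # node 2 connects to nodes 1 and 0
    return decode_adj(seq)
    # array([[0., 1., 1.],
    #        [1., 0., 1.],
    #        [1., 1., 0.]])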


def tensor_to_graph_converter(tensor, title):
    adj_matrix = decode_adj(tensor.numpy())
    G = nx.from_numpy_matrix(adj_matrix)
    graph_show(G, title)


def get_indices(number_of_missing_nodes, row_size):
    # Build the (row, column) index pairs that select, for each missing node i,
    # its row_size candidate edges in the staggered sequence encoding.
    row_index = []
    column_index = []
    start_index = 0
    for i in range(number_of_missing_nodes):
        for j in range(row_size):
            row_index.append(i)
            column_index.append(j + start_index)
        start_index += 1
    return row_index, column_index
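

# Worked example (illustrative only): with 2 missing nodes and row_size = 3,
# the selected pairs form a band that shifts right by one column per row.
# >>> get_indices(2, 3)
# ([0, 0, 0, 1, 1, 1], [0, 1, 2, 1, 2, 3])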


def GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y, batch_size, number_of_missing_nodes):
    # Permutation-invariant BCE: for each sample, score every ordering of the
    # missing-node rows and keep the smallest loss. Note that the number of
    # permutations grows factorially with number_of_missing_nodes.
    row_list = list(range(number_of_missing_nodes))
    permutation_indices = list(itertools.permutations(row_list))
    loss_sum = 0
    for data_index in range(batch_size):
        row_size = len(true_adj_batch[data_index]) - number_of_missing_nodes
        row_index, column_index = get_indices(number_of_missing_nodes, row_size)
        loss_list = []
        for row_order in permutation_indices:
            permuted_y_pred = y_pred[data_index][list(row_order)]
            loss_list.append(F.binary_cross_entropy(permuted_y_pred[row_index, column_index],
                                                    y[data_index][row_index, column_index]))
        loss_sum += min(loss_list)
    loss_sum /= batch_size
    return loss_sum
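

# Minimal sketch (toy tensors, hypothetical values) of the permutation trick
# above: both row orderings are scored and the smaller BCE wins, so the loss
# does not penalize the arbitrary ordering of the missing nodes.
def _demo_permutation_invariant_bce():
    y_pred_toy = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
    y_toy = torch.tensor([[0.1, 0.9], [0.9, 0.1]])
    losses = [F.binary_cross_entropy(y_pred_toy[list(p)], y_toy)
              for p in itertools.permutations(range(2))]
    return min(losses)  # the swapped order (1, 0) matches y_toy better here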


def GraphRNN_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # Currently scores only the first sample of the batch with a plain BCE over
    # the full adjacency matrix; the per-batch averaged version is disabled.
    loss_sum = F.binary_cross_entropy(predicted_adj_batch[0], true_adj_batch[0])
    return loss_sum


def matrix_completion_loss(true_adj_batch, predicted_adj_batch, batch_size, number_of_missing_nodes):
    # Currently scores only the first sample with a plain BCE over the full
    # adjacency matrix. A permutation-matching variant (minimum BCE over all
    # row orderings of the missing-node sub-matrix, normalized by sub-matrix
    # size) was prototyped here but is disabled.
    loss_sum = F.binary_cross_entropy(torch.from_numpy(predicted_adj_batch[0]),
                                      torch.from_numpy(true_adj_batch[0]))
    return loss_sum / batch_size


def graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, batch_size):
    loss_sum = 0
    for data_index in range(batch_size):
        features = torch.zeros(data_len[data_index])
        number_of_nodes = data_len[data_index].item()
        # Graph matching is too expensive for larger graphs; skip them.
        if data_len[data_index].item() > 14:
            continue
        graphvae = GraphVAE(1, 64, 256, data_len[data_index].item())
        current_true_adj = true_adj_batch[data_index]
        current_predicted_adj = predicted_adj_batch[data_index]
        current_features = features
        S = graphvae.edge_similarity_matrix(current_true_adj, current_predicted_adj,
                                            current_features, current_features,
                                            graphvae.deg_feature_similarity)
        init_corr = 1 / number_of_nodes
        init_assignment = torch.ones(number_of_nodes, number_of_nodes) * init_corr
        assignment = graphvae.mpm(init_assignment, S)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        permuted_adj = graphvae.permute_adj(torch.from_numpy(current_true_adj).float(), row_ind, col_ind)
        adj_recon_loss = graphvae.adj_recon_loss(permuted_adj, torch.from_numpy(current_predicted_adj).float())
        # Caveat: .data detaches the loss from the autograd graph, so setting
        # requires_grad afterwards does not restore gradients to the model.
        loss_sum += adj_recon_loss.data
        loss_sum.requires_grad = True
    return loss_sum / batch_size


def make_batch_of_matrices(initial_y, initial_y_len, y_train, y_pred, batch_size):
    true_adj_batch = []
    predicted_adj_batch = []
    for data_index in range(batch_size):
        # Concatenate the known (incomplete-graph) rows with the ground-truth
        # rows, decode to a full adjacency matrix, and add self-loops.
        a = initial_y[data_index][:initial_y_len[data_index] - 1].cpu()
        b = y_train[data_index]
        true_adj = torch.cat((a, b), dim=0)
        true_adj = decode_adj(true_adj.numpy())
        true_adj_tensor = torch.from_numpy(true_adj + np.identity(len(true_adj[0]))).float().to(device)
        true_adj_batch.append(true_adj_tensor)
        # Same construction with the predicted rows.
        b = y_pred[data_index].data
        predicted_adj = torch.cat((a, b.cpu()), dim=0)
        predicted_adj = decode_adj(predicted_adj.numpy())
        predicted_adj_tensor = torch.from_numpy(predicted_adj + np.identity(len(predicted_adj[0]))).float().to(device)
        predicted_adj_tensor.requires_grad = True
        # The +0 stores a differentiable (non-leaf) copy of the leaf tensor.
        predicted_adj_batch.append(predicted_adj_tensor + 0)
    return true_adj_batch, predicted_adj_batch
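

# Shape sketch (hypothetical sizes, not part of the pipeline): the known
# prefix rows are concatenated with model-generated rows, then decoded back
# into one full adjacency matrix.
def _demo_merge_prefix_and_prediction():
    prefix = torch.zeros(4, 3)    # 4 known sequence rows, max_prev_node = 3
    predicted = torch.rand(2, 3)  # 2 generated rows for the missing nodes
    merged = torch.cat((prefix, predicted), dim=0)
    return decode_adj(merged.numpy())  # a (7 x 7) adjacency matrix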


def train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization,
                                   output_for_initialization, data_loader,
                                   optimizer_rnn, optimizer_output,
                                   scheduler_rnn, scheduler_output, number_of_missing_nodes):
    rnn.train()
    output.train()
    loss_sum = 0
    for batch_idx, data in enumerate(data_loader):
        rnn.zero_grad()
        output.zero_grad()
        x_unsorted = data['x'].float()
        y_unsorted = data['y'].float()
        data_len = y_len_unsorted = data['len']

        # Split each graph into a known prefix (used to warm up the hidden
        # state) and the rows to be completed during training.
        x_unsorted, y_unsorted, y_len_unsorted, hidden, initial_y, initial_y_len = \
            data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted,
                                          args, rnn_for_initialization, output_for_initialization,
                                          number_of_missing_nodes, False, 0.96)
        # The incomplete / merged graphs can be visualized here with
        # tensor_to_graph_converter(...).
        y_len_max = max(y_len_unsorted)
        x_unsorted = x_unsorted[:, 0:y_len_max, :]
        y_unsorted = y_unsorted[:, 0:y_len_max, :]
        # Initialize the recurrent hidden state from the incomplete graph
        # instead of rnn.init_hidden(batch_size=x_unsorted.size(0)).
        rnn.hidden = hidden

        # sort input
        y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
        y_len = y_len.numpy().tolist()
        x = torch.index_select(x_unsorted, 0, sort_index)
        y = torch.index_select(y_unsorted, 0, sort_index)
        x = Variable(x).cuda()
        y = Variable(y).cuda()

        h = rnn(x, pack=True, input_len=y_len)
        y_pred = output(h)
        y_pred = torch.sigmoid(y_pred)

        # clean
        y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        # use a cross-entropy loss on the reconstructed adjacency matrices
        true_adj_batch, predicted_adj_batch = make_batch_of_matrices(initial_y, initial_y_len,
                                                                     y.cpu(), y_pred, len(y_pred))
        if args.loss == "graph_matching_loss":
            # Disabled: only logs the selected mode. The intended call is
            # loss = graph_matching_loss(true_adj_batch, predicted_adj_batch, data_len, len(data_len))
            print(graph_matching_loss)
        elif args.loss == "matrix_completion_loss":
            # Disabled: only logs the selected mode. The intended call is
            # loss = matrix_completion_loss(true_adj_batch, predicted_adj_batch, len(data_len),
            #                               args.number_of_missing_nodes)
            print(matrix_completion_loss)
        elif args.loss == "GraphRNN_loss":
            loss = GraphRNN_loss2(true_adj_batch, predicted_adj_batch, y_pred, y,
                                  len(y_pred), args.number_of_missing_nodes)
        # Note: only the "GraphRNN_loss" branch assigns `loss`; the other two
        # modes will raise a NameError below until their calls are re-enabled.
        loss.backward()
        # update the deterministic output head and the lstm
        optimizer_output.step()
        optimizer_rnn.step()
        scheduler_output.step()
        scheduler_rnn.step()
        # gradient flow can be inspected here with plot_grad_flow(rnn.named_parameters())

        if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output the first batch's statistics
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

        log_value(args.loss + args.fname + args.training_mode + "_number_of_missing_nodes_" +
                  str(number_of_missing_nodes), loss.data, epoch * args.batch_ratio + batch_idx)
        loss_sum += loss.data
    return loss_sum / (batch_idx + 1)


def train_for_completion(args, dataset_train, rnn, output, rnn_for_initialization, output_for_initialization):
    # check whether an existing model should be loaded
    if args.load:
        fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        rnn.load_state_dict(torch.load(fname))
        fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(args.load_epoch) + '.dat'
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded!, lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizers
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)

    scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if 'GraphRNN_VAE' in args.note:
            train_vae_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        elif 'GraphRNN_MLP' in args.note:
            train_mlp_epoch_for_completion(epoch, args, rnn, output, rnn_for_initialization,
                                           output_for_initialization, dataset_train,
                                           optimizer_rnn, optimizer_output,
                                           scheduler_rnn, scheduler_output, args.number_of_missing_nodes)
        elif 'GraphRNN_RNN' in args.note:
            train_rnn_epoch(epoch, args, rnn, output, dataset_train,
                            optimizer_rnn, optimizer_output,
                            scheduler_rnn, scheduler_output)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'lstm_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(rnn.state_dict(), fname)
                fname = args.model_save_path + args.fname + 'output_' + args.graph_completion_string + str(epoch) + '.dat'
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)


def hidden_initializer(x, rnn, output):
    # Feed the incomplete graph's sequence one timestep at a time so the
    # recurrent state summarizes the known part of the graph.
    test_batch_size = len(x)
    max_num_node = len(x[0])
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    for i in range(max_num_node):
        h = rnn((x[:, i:i+1, :]).cuda())
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    return rnn.hidden
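

# Shape sketch (hypothetical sizes): with x of shape
# (batch, num_known_nodes, max_prev_node), the returned hidden state has shape
# (num_layers, batch, hidden_size) and is used to seed rnn.hidden before the
# missing rows are generated.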


def data_initializer_for_training(x_unsorted, y_unsorted, y_len_unsorted, args, rnn, output,
                                  number_of_missing_nodes, ratio_mode, incompleteness_ratio):
    # Split each graph sequence into an "incomplete" prefix (initial_x /
    # initial_y) and the remaining rows to train on (train_x / train_y).
    batch_size = len(x_unsorted)
    max_prev_node = len(x_unsorted[0][0])
    max_number_of_nodes = int(max(y_len_unsorted))
    if ratio_mode:
        max_incomplete_num_node = int(max_number_of_nodes * incompleteness_ratio)
    else:
        max_incomplete_num_node = max_number_of_nodes - number_of_missing_nodes
    initial_x = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y = Variable(torch.zeros(batch_size, max_incomplete_num_node, max_prev_node)).cuda()
    initial_y_len = Variable(torch.zeros(len(y_len_unsorted))).type(torch.LongTensor)
    for i in range(len(y_len_unsorted)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        initial_y_len[i] = incomplete_graph_size
        initial_x[i] = torch.cat((x_unsorted[i, :incomplete_graph_size],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size, max_prev_node])), dim=0)
        initial_y[i] = torch.cat((y_unsorted[i, :incomplete_graph_size - 1],
                                  torch.zeros([max_incomplete_num_node - incomplete_graph_size + 1, max_prev_node])), dim=0)
    train_y_len = y_len_unsorted - initial_y_len
    max_train_num_node = int(max(train_y_len)) + 1
    temp_size_reduction = 1
    train_x = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))
    train_y = Variable(torch.ones(batch_size, max_train_num_node - temp_size_reduction, max_prev_node))

    for i in range(len(train_y_len)):
        if ratio_mode:
            incomplete_graph_size = int(int(y_len_unsorted.data[i]) * incompleteness_ratio)
        else:
            incomplete_graph_size = int(y_len_unsorted.data[i]) - number_of_missing_nodes
        train_graph_size = int(train_y_len.data[i])
        train_x[i, 1 - temp_size_reduction:, :] = torch.cat((x_unsorted[i, incomplete_graph_size:incomplete_graph_size + train_graph_size],
                                                             torch.zeros([max_train_num_node - train_graph_size - 1, max_prev_node])), dim=0)
        train_y[i, :, :] = torch.cat((x_unsorted[i, incomplete_graph_size + temp_size_reduction:incomplete_graph_size + train_graph_size],
                                      torch.zeros([max_train_num_node - train_graph_size, max_prev_node])), dim=0)

    hidden = hidden_initializer(initial_x, rnn, output)

    return train_x, train_y, train_y_len, hidden, initial_y, initial_y_len
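

# Worked example (illustrative numbers): with y_len_unsorted[i] = 10 and
# number_of_missing_nodes = 3 (ratio_mode off), the incomplete prefix keeps
# 7 nodes (initial_y_len[i] = 7), train_y_len[i] = 3, and
# max_train_num_node = max(train_y_len) + 1 = 4.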


def save_graph(graph, name):
    adj_pred = decode_adj(graph.cpu().numpy())
    G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
    G = np.asarray(nx.to_numpy_matrix(G_pred))
    np.savetxt(name + '.txt', G, fmt='%d')


def data_to_graph_converter(data):
    G_list = []
    for i in range(len(data)):
        x = data[i].numpy()
        x = x.astype(int)
        adj_pred = decode_adj(x)
        G = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_list.append(G)
    return G_list
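

# Usage sketch (toy batch, illustrative only): convert a batch of 0/1 sequence
# tensors back into networkx graphs.
def _demo_data_to_graph_converter():
    batch = [torch.tensor([[1., 0.], [1., 1.]])]
    return data_to_graph_converter(batch)  # a list with one 3-node graph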


if __name__ == '__main__':
    # All necessary arguments are defined in args.py
    args = Args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)
    # check that the necessary directories exist
    if not os.path.isdir(args.model_save_path):
        os.makedirs(args.model_save_path)
    if not os.path.isdir(args.graph_save_path):
        os.makedirs(args.graph_save_path)
    if not os.path.isdir(args.figure_save_path):
        os.makedirs(args.figure_save_path)
    if not os.path.isdir(args.timing_save_path):
        os.makedirs(args.timing_save_path)
    if not os.path.isdir(args.figure_prediction_save_path):
        os.makedirs(args.figure_prediction_save_path)
    if not os.path.isdir(args.nll_save_path):
        os.makedirs(args.nll_save_path)

    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
    if args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)

    graphs = create_graphs.create(args)

    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    # Note: as in the original GraphRNN setup, the validation graphs are drawn
    # from the beginning of the shuffled list and overlap the training split.
    graphs_validate = graphs[0:int(0.2 * graphs_len)]

    graph_validate_len = 0
    for graph in graphs_validate:
        graph_validate_len += graph.number_of_nodes()
    graph_validate_len /= len(graphs_validate)
    print('graph_validate_len', graph_validate_len)

    graph_test_len = 0
    for graph in graphs_test:
        graph_test_len += graph.number_of_nodes()
    graph_test_len /= len(graphs_test)
    print('graph_test_len', graph_test_len)

    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
    min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])

    # show graph statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))

    # save ground-truth graphs
    # to get the train and test sets, slice manually after loading
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')

    if 'nobfs' in args.note:
        print('nobfs')
        args.max_prev_node = args.max_num_node - 1
        dataset = Graph_sequence_sampler_pytorch_nobfs_for_completion(graphs_train,
                                                                      max_num_node=args.max_num_node)
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
                                                 max_num_node=args.max_num_node)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
                                                                     num_samples=args.batch_size * args.batch_ratio,
                                                                     replacement=True)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
                                                 sampler=sample_strategy)

    rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                    hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                    has_output=False).cuda()
    output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                       y_size=args.max_prev_node).cuda()

    rnn_for_initialization = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                                       hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                                       has_output=False).cuda()
    output_for_initialization = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                                          y_size=args.max_prev_node).cuda()

    # Warm-start both model pairs from the same pretrained checkpoint; the
    # *_for_initialization pair stays frozen while rnn/output are fine-tuned.
    fname = args.model_save_path + args.initial_fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn_for_initialization.load_state_dict(torch.load(fname))
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.initial_fname + 'output_' + str(args.load_epoch) + '.dat'
    output_for_initialization.load_state_dict(torch.load(fname))
    output.load_state_dict(torch.load(fname))

    print('model loaded!, lr: {}'.format(args.lr))

    train_for_completion(args, dataset_loader, rnn, output, rnn_for_initialization,
                         output_for_initialization)
|