| @@ -5,7 +5,7 @@ class Args(): | |||
### whether to clean the existing TensorBoard log directory before training
| self.clean_tensorboard = False | |||
| ### Which CUDA GPU device is used for training | |||
| self.cuda = 1 | |||
| self.cuda = 0 | |||
| ### Which GraphRNN model variant is used. | |||
| # The simple version of Graph RNN | |||
| @@ -88,7 +88,7 @@ class Args(): | |||
| self.load = False # if load model, default lr is very low | |||
| self.load_epoch = 3000 | |||
| self.load_epoch = 100 | |||
| self.save = True | |||
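For context, a minimal sketch (not part of this diff) of how the cuda index chosen above is typically applied; the helper name select_device and its usage are assumptions.

import torch

def select_device(args):
    # Pin computation to the GPU index configured in Args (args.cuda = 0 above),
    # falling back to CPU when no GPU is available.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.cuda)
        return torch.device('cuda', args.cuda)
    return torch.device('cpu')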
| @@ -1,14 +1,15 @@ | |||
| import argparse | |||
| import numpy as np | |||
| import logging | |||
| import os | |||
| import re | |||
| from random import shuffle | |||
| from time import strftime, gmtime | |||
| import eval.stats | |||
| import utils | |||
| # import main.Args | |||
| from args import Args | |||
| from baselines.baseline_simple import * | |||
| class Args_evaluate(): | |||
| def __init__(self): | |||
| # loop over the settings | |||
| @@ -19,19 +20,22 @@ class Args_evaluate(): | |||
| # list of dataset to evaluate | |||
| # use a list of 1 element to evaluate a single dataset | |||
| self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD'] | |||
| # self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD'] | |||
| self.dataset_name_all = ['caveman_small'] | |||
| # self.dataset_name_all = ['citeseer_small','caveman_small'] | |||
| # self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10'] | |||
| # self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small'] | |||
| self.epoch_start=100 | |||
| self.epoch_end=3001 | |||
| self.epoch_step=100 | |||
| self.epoch_start = 10 | |||
| self.epoch_end = 101 | |||
| self.epoch_step = 10 | |||
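With the new settings, the evaluation loop below (range(epoch_start, epoch_end, epoch_step)) visits every 10th checkpoint up to epoch 100; a quick check:

epochs = list(range(10, 101, 10))
print(epochs)  # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]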
| def find_nearest_idx(array,value): | |||
| idx = (np.abs(array-value)).argmin() | |||
| def find_nearest_idx(array, value): | |||
| idx = (np.abs(array - value)).argmin() | |||
| return idx | |||
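A small usage example of find_nearest_idx as defined above:

import numpy as np
arr = np.array([10, 20, 30, 40])
assert find_nearest_idx(arr, 27) == 2  # 30 is the entry closest to 27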
| def extract_result_id_and_epoch(name, prefix, suffix): | |||
| ''' | |||
| Args: | |||
| @@ -68,7 +72,7 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every): | |||
| if result_id not in pred_graphs_dict: | |||
| pred_graphs_dict[result_id] = {} | |||
| pred_graphs_dict[result_id][epochs] = fname | |||
| for result_id in real_graphs_dict.keys(): | |||
| for epochs in sorted(real_graphs_dict[result_id]): | |||
| real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs]) | |||
| @@ -77,16 +81,16 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every): | |||
| shuffle(pred_g_list) | |||
| perturbed_g_list = perturb(real_g_list, 0.05) | |||
| #dist = eval.stats.degree_stats(real_g_list, pred_g_list) | |||
| # dist = eval.stats.degree_stats(real_g_list, pred_g_list) | |||
| dist = eval.stats.clustering_stats(real_g_list, pred_g_list) | |||
| print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist) | |||
| #dist = eval.stats.degree_stats(real_g_list, perturbed_g_list) | |||
| # dist = eval.stats.degree_stats(real_g_list, perturbed_g_list) | |||
| dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list) | |||
| print('dist between real and perturbed: ', dist) | |||
| mid = len(real_g_list) // 2 | |||
| #dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| # dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| print('dist among real: ', dist) | |||
| @@ -97,7 +101,6 @@ def compute_basic_stats(real_g_list, target_g_list): | |||
| return dist_degree, dist_clustering | |||
| def clean_graphs(graph_real, graph_pred): | |||
''' Select generated graphs whose sizes are similar to the real graphs.
This is usually necessary for the GraphRNN-S variant, but not for the full GraphRNN model.
| @@ -120,6 +123,7 @@ def clean_graphs(graph_real, graph_pred): | |||
| return graph_real, pred_graph_new | |||
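The body of clean_graphs is not shown in this hunk; below is a minimal sketch of the size-matching idea it describes, reusing find_nearest_idx from above (the sketch name is hypothetical):

def clean_graphs_sketch(graph_real, graph_pred):
    # For each real graph, keep the predicted graph whose node count is closest.
    pred_len = np.array([g.number_of_nodes() for g in graph_pred])
    pred_graph_new = []
    for g in graph_real:
        idx = find_nearest_idx(pred_len, g.number_of_nodes())
        pred_graph_new.append(graph_pred[idx])
    return graph_real, pred_graph_new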
| def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'): | |||
| ''' Read ground truth graphs. | |||
| ''' | |||
| @@ -127,20 +131,21 @@ def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'): | |||
| hidden = 128 | |||
| else: | |||
| hidden = 64 | |||
| if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||
| if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R': | |||
| fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| else: | |||
| fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| try: | |||
| graph_test = utils.load_graph_list(fname_test,is_real=True) | |||
| graph_test = utils.load_graph_list(fname_test, is_real=True) | |||
| except: | |||
| print('Not found: ' + fname_test) | |||
| logging.warning('Not found: ' + fname_test) | |||
| return None | |||
| return graph_test | |||
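The ground-truth filename assembled above follows a fixed convention; an illustrative example (dataset, layer count, and hidden size are placeholders):

# dir_input + model_name + '_' + dataset_name + '_' + num_layers + '_' + hidden + '_test_0.dat'
fname_example = 'graphs/' + 'GraphRNN_RNN' + '_' + 'caveman_small' + '_' + str(4) + '_' + str(128) + '_test_' + str(0) + '.dat'
# -> 'graphs/GraphRNN_RNN_caveman_small_4_128_test_0.dat'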
| def eval_single_list(graphs, dir_input, dataset_name): | |||
| ''' Evaluate a list of graphs by comparing with graphs in directory dir_input. | |||
| Args: | |||
| @@ -149,7 +154,7 @@ def eval_single_list(graphs, dir_input, dataset_name): | |||
| ''' | |||
| graph_test = load_ground_truth(dir_input, dataset_name) | |||
| graph_test_len = len(graph_test) | |||
| graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
| graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
| mmd_degree = eval.stats.degree_stats(graph_test, graphs) | |||
| mmd_clustering = eval.stats.clustering_stats(graph_test, graphs) | |||
| try: | |||
| @@ -160,9 +165,12 @@ def eval_single_list(graphs, dir_input, dataset_name): | |||
| print('clustering: ', mmd_clustering) | |||
| print('orbits: ', mmd_4orbits) | |||
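A usage sketch of eval_single_list (the .dat path is illustrative), mirroring the --testfile branch in __main__ below:

graphs = utils.load_graph_list('graphs/my_generated_graphs.dat')
eval_single_list(graphs, dir_input='graphs/', dataset_name='grid')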
| def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,epoch_end=3001,epoch_step=100): | |||
| def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000, | |||
| epoch_end=3001, epoch_step=100): | |||
| with open(fname_output, 'w+') as f: | |||
| f.write('sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n') | |||
| f.write( | |||
| 'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n') | |||
| # TODO: Maybe refactor into a separate file/function that specifies THE naming convention | |||
| # across main and evaluate | |||
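For GraphRNN models, the loop below writes one row per (sample_time, epoch) pair; an illustrative row under the header written above (values are invented):

# sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test
# 1,10,0.012,0.034,0.005,0.015,0.041,0.007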
| @@ -171,7 +179,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| else: | |||
| hidden = 64 | |||
| # read real graph | |||
| if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||
| if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R': | |||
| fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| elif 'Baseline' in model_name: | |||
| @@ -180,37 +188,37 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
| hidden) + '_test_' + str(0) + '.dat' | |||
| try: | |||
| graph_test = utils.load_graph_list(fname_test,is_real=True) | |||
| graph_test = utils.load_graph_list(fname_test, is_real=True) | |||
| except: | |||
| print('Not found: ' + fname_test) | |||
| logging.warning('Not found: ' + fname_test) | |||
| return None | |||
| graph_test_len = len(graph_test) | |||
| graph_train = graph_test[0:int(0.8 * graph_test_len)] # train | |||
| graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate | |||
| graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
| graph_train = graph_test[0:int(0.8 * graph_test_len)] # train | |||
| graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate | |||
| graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
| graph_test_aver = 0 | |||
| for graph in graph_test: | |||
| graph_test_aver+=graph.number_of_nodes() | |||
| graph_test_aver += graph.number_of_nodes() | |||
| graph_test_aver /= len(graph_test) | |||
| print('test average len',graph_test_aver) | |||
| print('test average len', graph_test_aver) | |||
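The split above in concrete terms (an illustrative count): with 500 loaded graphs, indices 0-399 form the train set, 0-99 the validation set (a subset of train), and 400-499 the held-out test set.

n = 500
train_slice = (0, int(0.8 * n))      # (0, 400)
validate_slice = (0, int(0.2 * n))   # (0, 100)
test_slice = (int(0.8 * n), n)       # (400, 500)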
| # get performance for proposed approaches | |||
| if 'GraphRNN' in model_name: | |||
| # read test graph | |||
| for epoch in range(epoch_start,epoch_end,epoch_step): | |||
| for sample_time in range(1,4): | |||
| for epoch in range(epoch_start, epoch_end, epoch_step): | |||
| for sample_time in range(1, 4): | |||
| # get filename | |||
| fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat' | |||
| fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
| hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat' | |||
| # load graphs | |||
| try: | |||
| graph_pred = utils.load_graph_list(fname_pred,is_real=False) # default False | |||
| graph_pred = utils.load_graph_list(fname_pred, is_real=False) # default False | |||
| except: | |||
| print('Not found: '+ fname_pred) | |||
| logging.warning('Not found: '+ fname_pred) | |||
| print('Not found: ' + fname_pred) | |||
| logging.warning('Not found: ' + fname_pred) | |||
| continue | |||
| # clean graphs | |||
| if is_clean: | |||
| @@ -243,15 +251,15 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| except: | |||
| mmd_4orbits_validate = -1 | |||
| # write results | |||
| f.write(str(sample_time)+','+ | |||
| str(epoch)+','+ | |||
| str(mmd_degree_validate)+','+ | |||
| str(mmd_clustering_validate)+','+ | |||
| str(mmd_4orbits_validate)+','+ | |||
| str(mmd_degree)+','+ | |||
| str(mmd_clustering)+','+ | |||
| str(mmd_4orbits)+'\n') | |||
| print('degree',mmd_degree,'clustering',mmd_clustering,'orbits',mmd_4orbits) | |||
| f.write(str(sample_time) + ',' + | |||
| str(epoch) + ',' + | |||
| str(mmd_degree_validate) + ',' + | |||
| str(mmd_clustering_validate) + ',' + | |||
| str(mmd_4orbits_validate) + ',' + | |||
| str(mmd_degree) + ',' + | |||
| str(mmd_clustering) + ',' + | |||
| str(mmd_4orbits) + '\n') | |||
| print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits) | |||
| # get internal MMD (MMD between ground truth validation and test sets) | |||
| if model_name == 'Internal': | |||
| @@ -265,7 +273,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| mmd_clustering_validate) + ',' + str(mmd_4orbits_validate) | |||
| + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n') | |||
| # get MMD between ground truth and its perturbed graphs | |||
| if model_name == 'Noise': | |||
| graph_validate_perturbed = perturb(graph_validate, 0.05) | |||
| @@ -281,7 +288,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| # get E-R MMD | |||
| if model_name == 'E-R': | |||
| graph_pred = Graph_generator_baseline(graph_train,generator='Gnp') | |||
| graph_pred = Graph_generator_baseline(graph_train, generator='Gnp') | |||
| # clean graphs | |||
| if is_clean: | |||
| graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||
| @@ -296,7 +303,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) | |||
| + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n') | |||
| # get B-A MMD | |||
| if model_name == 'B-A': | |||
| graph_pred = Graph_generator_baseline(graph_train, generator='BA') | |||
| @@ -364,33 +370,29 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
| + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n') | |||
| print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits) | |||
| return True | |||
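A usage sketch of evaluation_epoch with the evaluation settings configured above (directory and file names are illustrative):

evaluation_epoch('graphs/', 'eval_results/GraphRNN_RNN_caveman_small.csv',
                 'GraphRNN_RNN', 'caveman_small', args, is_clean=True,
                 epoch_start=10, epoch_end=101, epoch_step=10)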
| def evaluation(args_evaluate,dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite = True): | |||
| def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True): | |||
| ''' Evaluate the performance of a set of models on a set of datasets. | |||
| ''' | |||
| for model_name in model_name_all: | |||
| for dataset_name in dataset_name_all: | |||
| # check output exist | |||
| fname_output = dir_output+model_name+'_'+dataset_name+'.csv' | |||
| print('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||
| logging.info('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||
| if overwrite==False and os.path.isfile(fname_output): | |||
| print(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||
| logging.info(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||
| fname_output = dir_output + model_name + '_' + dataset_name + '.csv' | |||
| print('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv') | |||
| logging.info('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv') | |||
if not overwrite and os.path.isfile(fname_output):
| print(dir_output + model_name + '_' + dataset_name + '.csv exists!') | |||
| logging.info(dir_output + model_name + '_' + dataset_name + '.csv exists!') | |||
| continue | |||
| evaluation_epoch(dir_input,fname_output,model_name,dataset_name,args,is_clean=True, epoch_start=args_evaluate.epoch_start,epoch_end=args_evaluate.epoch_end,epoch_step=args_evaluate.epoch_step) | |||
| evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, | |||
| epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end, | |||
| epoch_step=args_evaluate.epoch_step) | |||
| def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| eval_every, epoch_range=None, out_file_prefix=None): | |||
| eval_every, epoch_range=None, out_file_prefix=None): | |||
| ''' Evaluate list of predicted graphs compared to ground truth, stored in files. | |||
| Args: | |||
| baselines: dict mapping name of the baseline to list of generated graphs. | |||
| @@ -398,12 +400,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| if out_file_prefix is not None: | |||
| out_files = { | |||
| 'train': open(out_file_prefix + '_train.txt', 'w+'), | |||
| 'compare': open(out_file_prefix + '_compare.txt', 'w+') | |||
| 'train': open(out_file_prefix + '_train.txt', 'w+'), | |||
| 'compare': open(out_file_prefix + '_compare.txt', 'w+') | |||
| } | |||
| out_files['train'].write('degree,clustering,orbits4\n') | |||
| line = 'metric,real,ours,perturbed' | |||
| for bl in baselines: | |||
| line += ',' + bl | |||
| @@ -411,34 +413,33 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| out_files['compare'].write(line) | |||
| results = { | |||
| 'deg': { | |||
| 'real': 0, | |||
| 'ours': 100, # take min over all training epochs | |||
| 'perturbed': 0, | |||
| 'kron': 0}, | |||
| 'clustering': { | |||
| 'real': 0, | |||
| 'ours': 100, | |||
| 'perturbed': 0, | |||
| 'kron': 0}, | |||
| 'orbits4': { | |||
| 'real': 0, | |||
| 'ours': 100, | |||
| 'perturbed': 0, | |||
| 'kron': 0} | |||
| 'deg': { | |||
| 'real': 0, | |||
| 'ours': 100, # take min over all training epochs | |||
| 'perturbed': 0, | |||
| 'kron': 0}, | |||
| 'clustering': { | |||
| 'real': 0, | |||
| 'ours': 100, | |||
| 'perturbed': 0, | |||
| 'kron': 0}, | |||
| 'orbits4': { | |||
| 'real': 0, | |||
| 'ours': 100, | |||
| 'perturbed': 0, | |||
| 'kron': 0} | |||
| } | |||
| num_evals = len(pred_graphs_filename) | |||
| if epoch_range is None: | |||
| epoch_range = [i * eval_every for i in range(num_evals)] | |||
| epoch_range = [i * eval_every for i in range(num_evals)] | |||
| for i in range(num_evals): | |||
| real_g_list = utils.load_graph_list(real_graph_filename) | |||
| #pred_g_list = utils.load_graph_list(pred_graphs_filename[i]) | |||
| # pred_g_list = utils.load_graph_list(pred_graphs_filename[i]) | |||
| # contains all predicted G | |||
| pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i]) | |||
| if len(real_g_list)>200: | |||
| if len(real_g_list) > 200: | |||
| real_g_list = real_g_list[0:200] | |||
| shuffle(real_g_list) | |||
| @@ -448,10 +449,9 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))]) | |||
| pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))]) | |||
| # get perturb real | |||
| #perturbed_g_list_001 = perturb(real_g_list, 0.01) | |||
| # perturbed_g_list_001 = perturb(real_g_list, 0.01) | |||
| perturbed_g_list_005 = perturb(real_g_list, 0.05) | |||
| #perturbed_g_list_010 = perturb(real_g_list, 0.10) | |||
| # perturbed_g_list_010 = perturb(real_g_list, 0.10) | |||
| # select pred samples | |||
| # The number of nodes are sampled from the similar distribution as the training set | |||
| @@ -482,22 +482,22 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| # ======================================== | |||
| mid = len(real_g_list) // 2 | |||
| dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| #dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| # dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:]) | |||
| dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:]) | |||
| print('degree dist among real: ', dist_degree) | |||
| print('clustering dist among real: ', dist_clustering) | |||
| #print('4 cycle dist among real: ', dist_4cycle) | |||
| # print('4 cycle dist among real: ', dist_4cycle) | |||
| print('orbits dist among real: ', dist_4orbits) | |||
| results['deg']['real'] += dist_degree | |||
| results['clustering']['real'] += dist_clustering | |||
| results['orbits4']['real'] += dist_4orbits | |||
| dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list) | |||
| #dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list) | |||
| # dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list) | |||
| dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list) | |||
| print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree) | |||
| print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering) | |||
| #print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle) | |||
| # print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle) | |||
| print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits) | |||
| results['deg']['ours'] = min(dist_degree, results['deg']['ours']) | |||
| results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours']) | |||
| @@ -509,11 +509,11 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| out_files['train'].write(str(dist_4orbits) + ',') | |||
| dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005) | |||
| #dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005) | |||
| # dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005) | |||
| dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005) | |||
| print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree) | |||
| print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering) | |||
| #print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle) | |||
| # print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle) | |||
| print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits) | |||
| results['deg']['perturbed'] += dist_degree | |||
| results['clustering']['perturbed'] += dist_clustering | |||
| @@ -527,8 +527,8 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| results['deg'][baseline] = dist_degree | |||
| results['clustering'][baseline] = dist_clustering | |||
| results['orbits4'][baseline] = dist_4orbits | |||
| print('Kron: deg=', dist_degree, ', clustering=', dist_clustering, | |||
| ', orbits4=', dist_4orbits) | |||
| print('Kron: deg=', dist_degree, ', clustering=', dist_clustering, | |||
| ', orbits4=', dist_4orbits) | |||
| out_files['train'].write('\n') | |||
| @@ -538,10 +538,10 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| # Write results | |||
| for metric, methods in results.items(): | |||
| line = metric+','+ \ | |||
| str(methods['real'])+','+ \ | |||
| str(methods['ours'])+','+ \ | |||
| str(methods['perturbed']) | |||
| line = metric + ',' + \ | |||
| str(methods['real']) + ',' + \ | |||
| str(methods['ours']) + ',' + \ | |||
| str(methods['perturbed']) | |||
| for baseline in baselines: | |||
| line += ',' + str(methods[baseline]) | |||
| line += '\n' | |||
| @@ -553,12 +553,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None, | |||
| sample_time = 2, baselines={}): | |||
| sample_time=2, baselines={}): | |||
| if args is None: | |||
| real_graphs_filename = [datadir + f for f in os.listdir(datadir) | |||
| if re.match(prefix + '.*real.*\.dat', f)] | |||
if re.match(prefix + r'.*real.*\.dat', f)]
| pred_graphs_filename = [datadir + f for f in os.listdir(datadir) | |||
| if re.match(prefix + '.*pred.*\.dat', f)] | |||
if re.match(prefix + r'.*pred.*\.dat', f)]
| eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200) | |||
| else: | |||
| @@ -567,28 +567,30 @@ def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_p | |||
| # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||
| # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
| # str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||
| real_graph_filename = datadir+args.graph_save_path + args.fname_test + '0.dat' | |||
| real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat' | |||
| # for proposed model | |||
| end_epoch = 3001 | |||
| epoch_range = range(eval_every, end_epoch, eval_every) | |||
| pred_graphs_filename = [datadir+args.graph_save_path + args.fname_pred+str(epoch)+'_'+str(sample_time)+'.dat' | |||
| for epoch in epoch_range] | |||
| pred_graphs_filename = [ | |||
| datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat' | |||
| for epoch in epoch_range] | |||
| # for baseline model | |||
| #pred_graphs_filename = [datadir+args.fname_baseline+'.dat'] | |||
| # pred_graphs_filename = [datadir+args.fname_baseline+'.dat'] | |||
| #real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
| # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
| # str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||
| # args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||
| #pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
| # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
| # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||
| # args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||
| eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
| epoch_range=epoch_range, | |||
| epoch_range=epoch_range, | |||
| eval_every=eval_every, | |||
| out_file_prefix=out_file_prefix) | |||
| def process_kron(kron_dir): | |||
| txt_files = [] | |||
| for f in os.listdir(kron_dir): | |||
| @@ -602,7 +604,7 @@ def process_kron(kron_dir): | |||
| G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename))) | |||
| return G_list | |||
| if __name__ == '__main__': | |||
| args = Args() | |||
| @@ -612,18 +614,18 @@ if __name__ == '__main__': | |||
| feature_parser = parser.add_mutually_exclusive_group(required=False) | |||
| feature_parser.add_argument('--export-real', dest='export', action='store_true') | |||
| feature_parser.add_argument('--no-export-real', dest='export', action='store_false') | |||
| feature_parser.add_argument('--kron-dir', dest='kron_dir', | |||
| help='Directory where graphs generated by kronecker method is stored.') | |||
feature_parser.add_argument('--kron-dir', dest='kron_dir',
help='Directory where graphs generated by the Kronecker method are stored.')
| parser.add_argument('--testfile', dest='test_file', | |||
| help='The file that stores list of graphs to be evaluated. Only used when 1 list of ' | |||
| 'graphs is to be evaluated.') | |||
help='The file that stores the list of graphs to be evaluated. Only used when a '
'single list of graphs is to be evaluated.')
| parser.add_argument('--dir-prefix', dest='dir_prefix', | |||
| help='The file that stores list of graphs to be evaluated. Can be used when evaluating multiple' | |||
| 'models on multiple datasets.') | |||
help='Prefix of the directory that stores input graphs and evaluation results. Can be used when '
'evaluating multiple models on multiple datasets.')
| parser.add_argument('--graph-type', dest='graph_type', | |||
| help='Type of graphs / dataset.') | |||
| help='Type of graphs / dataset.') | |||
| parser.set_defaults(export=False, kron_dir='', test_file='', | |||
| dir_prefix='', | |||
| graph_type=args.graph_type) | |||
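Illustrative invocations using the flags defined above (the script and file names are assumptions):

# python evaluate.py --export-real --graph-type grid
# python evaluate.py --testfile graphs/my_test_graphs.dat
# python evaluate.py --dir-prefix /path/to/experiment/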
| @@ -633,7 +635,6 @@ if __name__ == '__main__': | |||
| # dir_prefix = "/dfs/scratch0/jiaxuany0/" | |||
| dir_prefix = args.dir_input | |||
| time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime()) | |||
| if not os.path.isdir('logs/'): | |||
| os.makedirs('logs/') | |||
| @@ -652,16 +653,16 @@ if __name__ == '__main__': | |||
| if prog_args.graph_type == 'grid': | |||
| graphs = [] | |||
| for i in range(10,20): | |||
| for j in range(10,20): | |||
| graphs.append(nx.grid_2d_graph(i,j)) | |||
| for i in range(10, 20): | |||
| for j in range(10, 20): | |||
| graphs.append(nx.grid_2d_graph(i, j)) | |||
| utils.export_graphs_to_txt(graphs, output_prefix) | |||
| elif prog_args.graph_type == 'caveman': | |||
| graphs = [] | |||
| for i in range(2, 3): | |||
| for j in range(30, 81): | |||
| for k in range(10): | |||
| graphs.append(caveman_special(i,j, p_edge=0.3)) | |||
| graphs.append(caveman_special(i, j, p_edge=0.3)) | |||
| utils.export_graphs_to_txt(graphs, output_prefix) | |||
| elif prog_args.graph_type == 'citeseer': | |||
| graphs = utils.citeseer_ego() | |||
| @@ -679,14 +680,11 @@ if __name__ == '__main__': | |||
| elif not prog_args.test_file == '': | |||
| # evaluate single .dat file containing list of test graphs (networkx format) | |||
| graphs = utils.load_graph_list(prog_args.test_file) | |||
| eval_single_list(graphs, dir_input=dir_prefix+'graphs/', dataset_name='grid') | |||
| eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid') | |||
## If you are not using the Kronecker baseline, only the following part is needed.
| else: | |||
| if not os.path.isdir(dir_prefix+'eval_results'): | |||
| os.makedirs(dir_prefix+'eval_results') | |||
| evaluation(args_evaluate,dir_input=dir_prefix+"graphs/", dir_output=dir_prefix+"eval_results/", | |||
| model_name_all=args_evaluate.model_name_all,dataset_name_all=args_evaluate.dataset_name_all,args=args,overwrite=True) | |||
| if not os.path.isdir(dir_prefix + 'eval_results'): | |||
| os.makedirs(dir_prefix + 'eval_results') | |||
| evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/", | |||
| model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all, | |||
| args=args, overwrite=True) | |||
| @@ -21,11 +21,9 @@ from tensorboard_logger import configure, log_value | |||
| import scipy.misc | |||
| import time as tm | |||
| from utils import * | |||
| from model import * | |||
| from tensorboard_logger import log_value | |||
| from data import * | |||
| from args import Args | |||
| import create_graphs | |||
| def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
| @@ -47,16 +45,16 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
| rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
| # sort input | |||
| y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
| y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
| y_len = y_len.numpy().tolist() | |||
| x = torch.index_select(x_unsorted,0,sort_index) | |||
| y = torch.index_select(y_unsorted,0,sort_index) | |||
| x = torch.index_select(x_unsorted, 0, sort_index) | |||
| y = torch.index_select(y_unsorted, 0, sort_index) | |||
| x = Variable(x).cuda() | |||
| y = Variable(y).cuda() | |||
| # if using ground truth to train | |||
| h = rnn(x, pack=True, input_len=y_len) | |||
| y_pred,z_mu,z_lsgms = output(h) | |||
| y_pred, z_mu, z_lsgms = output(h) | |||
| y_pred = F.sigmoid(y_pred) | |||
| # clean | |||
| y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) | |||
| @@ -68,7 +66,7 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
| # use cross entropy loss | |||
| loss_bce = binary_cross_entropy_weight(y_pred, y) | |||
| loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp()) | |||
| loss_kl /= y.size(0)*y.size(1)*sum(y_len) # normalize | |||
| loss_kl /= y.size(0) * y.size(1) * sum(y_len) # normalize | |||
| loss = loss_bce + loss_kl | |||
| loss.backward() | |||
| # update deterministic and lstm | |||
| @@ -77,7 +75,6 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
| scheduler_output.step() | |||
| scheduler_rnn.step() | |||
| z_mu_mean = torch.mean(z_mu.data) | |||
| z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data) | |||
| z_mu_min = torch.min(z_mu.data) | |||
| @@ -85,35 +82,39 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
| z_mu_max = torch.max(z_mu.data) | |||
| z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data) | |||
| if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
| print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs,loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max) | |||
| if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
| print( | |||
| 'Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, | |||
| args.hidden_size_rnn)) | |||
| print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, | |||
| 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max) | |||
| # logging | |||
| log_value('bce_loss_'+args.fname, loss_bce.data[0], epoch*args.batch_ratio+batch_idx) | |||
| log_value('kl_loss_' +args.fname, loss_kl.data[0], epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_mu_mean_'+args.fname, z_mu_mean, epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_mu_min_'+args.fname, z_mu_min, epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_mu_max_'+args.fname, z_mu_max, epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_mean_'+args.fname, z_sgm_mean, epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_min_'+args.fname, z_sgm_min, epoch*args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_max_'+args.fname, z_sgm_max, epoch*args.batch_ratio + batch_idx) | |||
| log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx) | |||
| log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx) | |||
| log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx) | |||
| loss_sum += loss.data[0] | |||
| return loss_sum/(batch_idx+1) | |||
| return loss_sum / (batch_idx + 1) | |||
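The sort-by-length pattern used at the top of train_vae_epoch (and in the other training loops below), shown in isolation on toy tensors; sequences must be sorted by decreasing length before pack_padded_sequence:

import torch
y_len_unsorted = torch.LongTensor([3, 5, 2])
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
# y_len -> [5, 3, 2]; sort_index -> [1, 0, 2]
# x and y are then reordered with torch.index_select(x_unsorted, 0, sort_index)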
| def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time = 1): | |||
| def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1): | |||
| rnn.hidden = rnn.init_hidden(test_batch_size) | |||
| rnn.eval() | |||
| output.eval() | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| h = rnn(x_step) | |||
| y_pred_step, _, _ = output(h) | |||
| @@ -128,7 +129,7 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
| G_pred_list = [] | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| # save prediction histograms, plot histogram over each time step | |||
| @@ -137,11 +138,10 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
| # fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg', | |||
| # max_num_node=max_num_node) | |||
| return G_pred_list | |||
| def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
| def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
| rnn.eval() | |||
| output.eval() | |||
| G_pred_list = [] | |||
| @@ -153,15 +153,18 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
| rnn.hidden = rnn.init_hidden(test_batch_size) | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| print('finish node',i) | |||
| print('finish node', i) | |||
| h = rnn(x_step) | |||
| y_pred_step, _, _ = output(h) | |||
| y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
| x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
| x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
| sample_time=sample_time) | |||
| y_pred_long[:, i:i + 1, :] = x_step | |||
| rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
| @@ -171,12 +174,11 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
| # save graphs as pickle | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| return G_pred_list | |||
| def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
| optimizer_rnn, optimizer_output, | |||
| scheduler_rnn, scheduler_output): | |||
| @@ -196,10 +198,10 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
| rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
| # sort input | |||
| y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
| y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
| y_len = y_len.numpy().tolist() | |||
| x = torch.index_select(x_unsorted,0,sort_index) | |||
| y = torch.index_select(y_unsorted,0,sort_index) | |||
| x = torch.index_select(x_unsorted, 0, sort_index) | |||
| y = torch.index_select(y_unsorted, 0, sort_index) | |||
| x = Variable(x).cuda() | |||
| y = Variable(y).cuda() | |||
| @@ -218,28 +220,28 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
| scheduler_output.step() | |||
| scheduler_rnn.step() | |||
| if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
| if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
| print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| # logging | |||
| log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
| log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx) | |||
| loss_sum += loss.data[0] | |||
| return loss_sum/(batch_idx+1) | |||
| return loss_sum / (batch_idx + 1) | |||
| def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False,sample_time=1): | |||
| def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1): | |||
| rnn.hidden = rnn.init_hidden(test_batch_size) | |||
| rnn.eval() | |||
| output.eval() | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| h = rnn(x_step) | |||
| y_pred_step = output(h) | |||
| @@ -254,10 +256,9 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
| G_pred_list = [] | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| # # save prediction histograms, plot histogram over each time step | |||
| # if save_histogram: | |||
| # save_prediction_histogram(y_pred_data.cpu().numpy(), | |||
| @@ -266,8 +267,7 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
| return G_pred_list | |||
| def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
| def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
| rnn.eval() | |||
| output.eval() | |||
| G_pred_list = [] | |||
| @@ -279,15 +279,18 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
| rnn.hidden = rnn.init_hidden(test_batch_size) | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| print('finish node',i) | |||
| print('finish node', i) | |||
| h = rnn(x_step) | |||
| y_pred_step = output(h) | |||
| y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
| x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
| x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
| sample_time=sample_time) | |||
| y_pred_long[:, i:i + 1, :] = x_step | |||
| rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
| @@ -297,12 +300,12 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
| # save graphs as pickle | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| return G_pred_list | |||
| def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
| def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
| rnn.eval() | |||
| output.eval() | |||
| G_pred_list = [] | |||
| @@ -314,15 +317,18 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi | |||
| rnn.hidden = rnn.init_hidden(test_batch_size) | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
| y_pred_long = Variable( | |||
| torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| print('finish node',i) | |||
| print('finish node', i) | |||
| h = rnn(x_step) | |||
| y_pred_step = output(h) | |||
| y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
| x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
| x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
| sample_time=sample_time) | |||
| y_pred_long[:, i:i + 1, :] = x_step | |||
| rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
| @@ -332,7 +338,7 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi | |||
| # save graphs as pickle | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| return G_pred_list | |||
| @@ -354,10 +360,10 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader): | |||
| rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
| # sort input | |||
| y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
| y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
| y_len = y_len.numpy().tolist() | |||
| x = torch.index_select(x_unsorted,0,sort_index) | |||
| y = torch.index_select(y_unsorted,0,sort_index) | |||
| x = torch.index_select(x_unsorted, 0, sort_index) | |||
| y = torch.index_select(y_unsorted, 0, sort_index) | |||
| x = Variable(x).cuda() | |||
| y = Variable(y).cuda() | |||
| @@ -372,22 +378,18 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader): | |||
| loss = 0 | |||
| for j in range(y.size(1)): | |||
| # print('y_pred',y_pred[0,j,:],'y',y[0,j,:]) | |||
| end_idx = min(j+1,y.size(2)) | |||
| loss += binary_cross_entropy_weight(y_pred[:,j,0:end_idx], y[:,j,0:end_idx])*end_idx | |||
| end_idx = min(j + 1, y.size(2)) | |||
| loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx | |||
| if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
| if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
| print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| # logging | |||
| log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
| log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx) | |||
| loss_sum += loss.data[0] | |||
| return loss_sum/(batch_idx+1) | |||
| return loss_sum / (batch_idx + 1) | |||
| ## too complicated, deprecated | |||
| @@ -449,28 +451,29 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
| # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||
| # sort input | |||
| y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
| y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
| y_len = y_len.numpy().tolist() | |||
| x = torch.index_select(x_unsorted,0,sort_index) | |||
| y = torch.index_select(y_unsorted,0,sort_index) | |||
| x = torch.index_select(x_unsorted, 0, sort_index) | |||
| y = torch.index_select(y_unsorted, 0, sort_index) | |||
| # input, output for output rnn module | |||
# a neat use of PyTorch's pack_padded_sequence: the packed .data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
| y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||
| y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data | |||
| # reverse y_reshape, so that their lengths are sorted, add dimension | |||
| idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||
| idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)] | |||
| idx = torch.LongTensor(idx) | |||
| y_reshape = y_reshape.index_select(0, idx) | |||
| y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||
| y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1) | |||
| output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||
| output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1) | |||
| output_y = y_reshape | |||
| # batch size for output module: sum(y_len) | |||
| output_y_len = [] | |||
| output_y_len_bin = np.bincount(np.array(y_len)) | |||
| for i in range(len(output_y_len_bin)-1,0,-1): | |||
| count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i | |||
| output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
| for i in range(len(output_y_len_bin) - 1, 0, -1): | |||
count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i
| output_y_len.extend( | |||
| [min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
| # pack into variable | |||
| x = Variable(x).cuda() | |||
| y = Variable(y).cuda() | |||
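What the bincount loop above computes, on a toy example: with y_len = [3, 3, 2], the output module sees sum(y_len) = 8 packed rows, and output_y_len records one (capped) length per packed row.

import numpy as np
y_len = [3, 3, 2]
max_prev_node = 4                                  # stands in for y.size(2)
output_y_len_bin = np.bincount(np.array(y_len))    # [0, 0, 1, 2]
output_y_len = []
for i in range(len(output_y_len_bin) - 1, 0, -1):
    count_temp = np.sum(output_y_len_bin[i:])      # how many y_len values are >= i
    output_y_len.extend([min(i, max_prev_node)] * count_temp)
print(output_y_len)                                # [3, 3, 2, 2, 2, 1, 1, 1]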
| @@ -481,23 +484,23 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
| # print('y',y.size()) | |||
| # print('output_y',output_y.size()) | |||
| # if using ground truth to train | |||
| h = rnn(x, pack=True, input_len=y_len) | |||
| h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||
| h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector | |||
| # reverse h | |||
| idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||
| idx = Variable(torch.LongTensor(idx)).cuda() | |||
| h = h.index_select(0, idx) | |||
| hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||
| output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||
| hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda() | |||
| output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), | |||
| dim=0) # num_layers, batch_size, hidden_size | |||
| y_pred = output(output_x, pack=True, input_len=output_y_len) | |||
| y_pred = F.sigmoid(y_pred) | |||
| # clean | |||
| y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||
| y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||
| output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||
| output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||
| output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True) | |||
| output_y = pad_packed_sequence(output_y, batch_first=True)[0] | |||
| # use cross entropy loss | |||
| loss = binary_cross_entropy_weight(y_pred, output_y) | |||
| loss.backward() | |||
| @@ -507,17 +510,15 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
| scheduler_output.step() | |||
| scheduler_rnn.step() | |||
| if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
| if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
| print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))
| # logging | |||
| log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
| feature_dim = y.size(1)*y.size(2) | |||
| loss_sum += loss.data[0]*feature_dim | |||
| return loss_sum/(batch_idx+1) | |||
log_value('loss_' + args.fname, loss.item(), epoch * args.batch_ratio + batch_idx)
feature_dim = y.size(1) * y.size(2)
loss_sum += loss.item() * feature_dim
| return loss_sum / (batch_idx + 1) | |||
| def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
| @@ -527,20 +528,20 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
| # generate graphs | |||
| max_num_node = int(args.max_num_node) | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
| y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
| x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| for i in range(max_num_node): | |||
| h = rnn(x_step) | |||
| # output.hidden = h.permute(1,0,2) | |||
| hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda() | |||
| output.hidden = torch.cat((h.permute(1,0,2), hidden_null), | |||
| output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), | |||
| dim=0) # num_layers, batch_size, hidden_size | |||
| x_step = Variable(torch.zeros(test_batch_size,1,args.max_prev_node)).cuda() | |||
| output_x_step = Variable(torch.ones(test_batch_size,1,1)).cuda() | |||
| for j in range(min(args.max_prev_node,i+1)): | |||
| x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda() | |||
| output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda() | |||
| for j in range(min(args.max_prev_node, i + 1)): | |||
| output_y_pred_step = output(output_x_step) | |||
| output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1) | |||
| x_step[:,:,j:j+1] = output_x_step | |||
| x_step[:, :, j:j + 1] = output_x_step | |||
| output.hidden = Variable(output.hidden.data).cuda() | |||
| y_pred_long[:, i:i + 1, :] = x_step | |||
| rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
| @@ -550,14 +551,12 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
| G_pred_list = [] | |||
| for i in range(test_batch_size): | |||
| adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
| G_pred_list.append(G_pred) | |||
| return G_pred_list | |||
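A usage sketch (illustrative; the save helper name is an assumption, not confirmed by this diff): sample graphs from a trained GraphRNN at a given epoch and save them under the fname_pred convention that evaluate.py reads back.

# graphs_pred = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16)
# utils.save_graph_list(graphs_pred, args.graph_save_path + args.fname_pred + str(epoch) + '_1.dat')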
| def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
| rnn.train() | |||
| output.train() | |||
| @@ -576,28 +575,29 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
| # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||
| # sort input | |||
| y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
| y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
| y_len = y_len.numpy().tolist() | |||
| x = torch.index_select(x_unsorted,0,sort_index) | |||
| y = torch.index_select(y_unsorted,0,sort_index) | |||
| x = torch.index_select(x_unsorted, 0, sort_index) | |||
| y = torch.index_select(y_unsorted, 0, sort_index) | |||
| # input, output for output rnn module | |||
# a neat use of PyTorch's pack_padded_sequence: the packed .data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
| y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||
| y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data | |||
| # reverse y_reshape, so that their lengths are sorted, add dimension | |||
| idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||
| idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)] | |||
| idx = torch.LongTensor(idx) | |||
| y_reshape = y_reshape.index_select(0, idx) | |||
| y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||
| y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1) | |||
| output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||
| output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1) | |||
| output_y = y_reshape | |||
| # batch size for output module: sum(y_len) | |||
| output_y_len = [] | |||
| output_y_len_bin = np.bincount(np.array(y_len)) | |||
| for i in range(len(output_y_len_bin)-1,0,-1): | |||
| count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i | |||
| output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
| for i in range(len(output_y_len_bin) - 1, 0, -1): | |||
| count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i | |||
| output_y_len.extend( | |||
| [min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
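A tiny worked example of the bincount trick above (values are illustrative): with sorted lengths y_len = [4, 3, 3, 1] and y.size(2) = 3, the packed tensor has sum(y_len) = 11 rows, and the loop produces one capped edge-sequence length per packed row, in the non-increasing order required after the reversal.

    import numpy as np

    y_len = [4, 3, 3, 1]                     # sorted sequence lengths of a toy batch
    max_prev_node = 3                        # stands in for y.size(2)
    bins = np.bincount(np.array(y_len))      # -> [0, 1, 0, 2, 1]
    out_len = []
    for i in range(len(bins) - 1, 0, -1):
        out_len.extend([min(i, max_prev_node)] * int(np.sum(bins[i:])))
    print(out_len)        # [3, 3, 3, 3, 2, 2, 2, 1, 1, 1, 1]
    print(sum(y_len))     # 11 packed rows, matching len(out_len)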
| # pack into variable | |||
| x = Variable(x).cuda() | |||
| y = Variable(y).cuda() | |||
| @@ -608,37 +608,36 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
| # print('y',y.size()) | |||
| # print('output_y',output_y.size()) | |||
| # if using ground truth to train | |||
| h = rnn(x, pack=True, input_len=y_len) | |||
| h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||
| h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector | |||
| # reverse h | |||
| idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||
| idx = Variable(torch.LongTensor(idx)).cuda() | |||
| h = h.index_select(0, idx) | |||
| hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||
| output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||
| hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda() | |||
| output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), | |||
| dim=0) # num_layers, batch_size, hidden_size | |||
| y_pred = output(output_x, pack=True, input_len=output_y_len) | |||
| y_pred = F.sigmoid(y_pred) | |||
| # clean | |||
| y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||
| y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||
| output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||
| output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||
| output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True) | |||
| output_y = pad_packed_sequence(output_y, batch_first=True)[0] | |||
| # use cross entropy loss | |||
| loss = binary_cross_entropy_weight(y_pred, output_y) | |||
| if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
| if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
| print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
| epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
| # logging | |||
| log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
| log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx) | |||
| # print(y_pred.size()) | |||
| feature_dim = y_pred.size(0)*y_pred.size(1) | |||
| loss_sum += loss.data[0]*feature_dim/y.size(0) | |||
| return loss_sum/(batch_idx+1) | |||
| feature_dim = y_pred.size(0) * y_pred.size(1) | |||
| loss_sum += loss.data * feature_dim / y.size(0) | |||
| return loss_sum / (batch_idx + 1) | |||
| ########### train function for LSTM + VAE | |||
| @@ -665,7 +664,7 @@ def train(args, dataset_train, rnn, output): | |||
| # start main loop | |||
| time_all = np.zeros(args.epochs) | |||
| while epoch<=args.epochs: | |||
| while epoch <= args.epochs: | |||
| time_start = tm.time() | |||
| # train | |||
| if 'GraphRNN_VAE' in args.note: | |||
| @@ -683,25 +682,26 @@ def train(args, dataset_train, rnn, output): | |||
| time_end = tm.time() | |||
| time_all[epoch - 1] = time_end - time_start | |||
| # test | |||
| if epoch % args.epochs_test == 0 and epoch>=args.epochs_test_start: | |||
| for sample_time in range(1,4): | |||
| if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start: | |||
| for sample_time in range(1, 4): | |||
| G_pred = [] | |||
| while len(G_pred)<args.test_total_size: | |||
| while len(G_pred) < args.test_total_size: | |||
| if 'GraphRNN_VAE' in args.note: | |||
| G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||
| G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, | |||
| sample_time=sample_time) | |||
| elif 'GraphRNN_MLP' in args.note: | |||
| G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||
| G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, | |||
| sample_time=sample_time) | |||
| elif 'GraphRNN_RNN' in args.note: | |||
| G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size) | |||
| G_pred.extend(G_pred_step) | |||
| # save graphs | |||
| fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + '.dat' | |||
| fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat' | |||
| save_graph_list(G_pred, fname) | |||
| if 'GraphRNN_RNN' in args.note: | |||
| break | |||
| print('test done, graphs saved') | |||
| # save model checkpoint | |||
| if args.save: | |||
| if epoch % args.epochs_save == 0: | |||
| @@ -710,7 +710,7 @@ def train(args, dataset_train, rnn, output): | |||
| fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat' | |||
| torch.save(output.state_dict(), fname) | |||
| epoch += 1 | |||
| np.save(args.timing_save_path+args.fname,time_all) | |||
| np.save(args.timing_save_path + args.fname, time_all) | |||
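The per-epoch prediction files written above are pickled lists of networkx graphs, so they can be inspected offline with the loader from utils; a minimal sketch (epoch, sample_time, and the path/prefix are illustrative and depend on args):

    import utils

    epoch, sample_time = 100, 1   # illustrative values
    fname = 'graphs/' + 'GraphRNN_RNN_grid_4_128_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
    G_pred = utils.load_graph_list(fname)
    print(len(G_pred), 'graphs; avg nodes:',
          sum(g.number_of_nodes() for g in G_pred) / len(G_pred))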
| ########### for graph completion task | |||
| @@ -723,19 +723,19 @@ def train_graph_completion(args, dataset_test, rnn, output): | |||
| epoch = args.load_epoch | |||
| print('model loaded! epoch: {}'.format(args.load_epoch)) | |||
| for sample_time in range(1,4): | |||
| for sample_time in range(1, 4): | |||
| if 'GraphRNN_MLP' in args.note: | |||
| G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||
| G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time) | |||
| if 'GraphRNN_VAE' in args.note: | |||
| G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||
| G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time) | |||
| # save graphs | |||
| fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + 'graph_completion.dat' | |||
| fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat' | |||
| save_graph_list(G_pred, fname) | |||
| print('graph completion done, graphs saved') | |||
| ########### for NLL evaluation | |||
| def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,graph_test_len, max_iter = 1000): | |||
| def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000): | |||
| fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat' | |||
| rnn.load_state_dict(torch.load(fname)) | |||
| fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat' | |||
| @@ -745,7 +745,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len, | |||
| print('model loaded! epoch: {}'.format(args.load_epoch)) | |||
| fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv' | |||
| with open(fname_output, 'w+') as f: | |||
| f.write(str(graph_validate_len)+','+str(graph_test_len)+'\n') | |||
| f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n') | |||
| f.write('train,test\n') | |||
| for iter in range(max_iter): | |||
| if 'GraphRNN_MLP' in args.note: | |||
| @@ -754,7 +754,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len, | |||
| if 'GraphRNN_RNN' in args.note: | |||
| nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train) | |||
| nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test) | |||
| print('train',nll_train,'test',nll_test) | |||
| f.write(str(nll_train)+','+str(nll_test)+'\n') | |||
| print('train', nll_train, 'test', nll_test) | |||
| f.write(str(nll_train) + ',' + str(nll_test) + '\n') | |||
| print('NLL evaluation done') | |||
| @@ -16,6 +16,7 @@ import re | |||
| import data | |||
| def citeseer_ego(): | |||
| _, _, G = data.Graph_load(dataset='citeseer') | |||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||
| @@ -27,12 +28,13 @@ def citeseer_ego(): | |||
| graphs.append(G_ego) | |||
| return graphs | |||
| def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3): | |||
| def caveman_special(c=2, k=20, p_path=0.1, p_edge=0.3): | |||
| p = p_path | |||
| path_count = max(int(np.ceil(p * k)),1) | |||
| path_count = max(int(np.ceil(p * k)), 1) | |||
| G = nx.caveman_graph(c, k) | |||
| # remove a fraction (1 - p_edge) of the within-clique edges | |||
| p = 1-p_edge | |||
| p = 1 - p_edge | |||
| for (u, v) in list(G.edges()): | |||
| if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)): | |||
| G.remove_edge(u, v) | |||
| @@ -44,6 +46,7 @@ def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3): | |||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||
| return G | |||
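caveman_special can be called directly to build a quick smoke-test dataset; a minimal sketch (clique count and sizes are arbitrary illustration, and it assumes this module's networkx/numpy imports):

    graphs = []
    for k in range(8, 13):                       # clique size 8..12
        for _ in range(10):
            graphs.append(caveman_special(c=2, k=k, p_edge=0.3))
    print('generated', len(graphs), 'graphs; node counts',
          min(g.number_of_nodes() for g in graphs), '-',
          max(g.number_of_nodes() for g in graphs))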
| def n_community(c_sizes, p_inter=0.01): | |||
| graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))] | |||
| G = nx.disjoint_union_all(graphs) | |||
| @@ -51,7 +54,7 @@ def n_community(c_sizes, p_inter=0.01): | |||
| for i in range(len(communities)): | |||
| subG1 = communities[i] | |||
| nodes1 = list(subG1.nodes()) | |||
| for j in range(i+1, len(communities)): | |||
| for j in range(i + 1, len(communities)): | |||
| subG2 = communities[j] | |||
| nodes2 = list(subG2.nodes()) | |||
| has_inter_edge = False | |||
| @@ -62,9 +65,10 @@ def n_community(c_sizes, p_inter=0.01): | |||
| has_inter_edge = True | |||
| if not has_inter_edge: | |||
| G.add_edge(nodes1[0], nodes2[0]) | |||
| #print('connected comp: ', len(list(nx.connected_component_subgraphs(G)))) | |||
| # print('connected comp: ', len(list(nx.connected_component_subgraphs(G)))) | |||
| return G | |||
| def perturb(graph_list, p_del, p_add=None): | |||
| ''' Perturb the list of graphs by adding/removing edges. | |||
| Args: | |||
| @@ -87,7 +91,7 @@ def perturb(graph_list, p_del, p_add=None): | |||
| if p_add is None: | |||
| num_nodes = G.number_of_nodes() | |||
| p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 - | |||
| G.number_of_edges()) | |||
| G.number_of_edges()) | |||
| else: | |||
| p_add_est = p_add | |||
| @@ -97,7 +101,7 @@ def perturb(graph_list, p_del, p_add=None): | |||
| u = nodes[i] | |||
| trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes()) | |||
| j = 0 | |||
| for j in range(i+1, len(nodes)): | |||
| for j in range(i + 1, len(nodes)): | |||
| v = nodes[j] | |||
| if trials[j] == 1: | |||
| tmp += 1 | |||
| @@ -108,7 +112,6 @@ def perturb(graph_list, p_del, p_add=None): | |||
| return perturbed_graph_list | |||
| def perturb_new(graph_list, p): | |||
| ''' Perturb the list of graphs by adding/removing edges. | |||
| Args: | |||
| @@ -123,7 +126,7 @@ def perturb_new(graph_list, p): | |||
| G = G_original.copy() | |||
| edge_remove_count = 0 | |||
| for (u, v) in list(G.edges()): | |||
| if np.random.rand()<p: | |||
| if np.random.rand() < p: | |||
| G.remove_edge(u, v) | |||
| edge_remove_count += 1 | |||
| # randomly add the edges back | |||
| @@ -131,16 +134,13 @@ def perturb_new(graph_list, p): | |||
| while True: | |||
| u = np.random.randint(0, G.number_of_nodes()) | |||
| v = np.random.randint(0, G.number_of_nodes()) | |||
| if (not G.has_edge(u,v)) and (u!=v): | |||
| if (not G.has_edge(u, v)) and (u != v): | |||
| break | |||
| G.add_edge(u, v) | |||
| perturbed_graph_list.append(G) | |||
| return perturbed_graph_list | |||
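Both helpers are handy for the "perturbed real graphs" baseline used during evaluation: perturb removes edges with probability p_del and re-adds edges at an estimated rate, while perturb_new re-adds exactly as many edges as it removed. A minimal sketch with illustrative noise levels (assumes this module's networkx import):

    real_graphs = [nx.barabasi_albert_graph(100, 4) for _ in range(8)]
    noisy_est = perturb(real_graphs, p_del=0.05)     # re-add at an estimated probability
    noisy_exact = perturb_new(real_graphs, p=0.05)   # re-add exactly the removed count
    print([g.number_of_edges() for g in real_graphs])
    print([g.number_of_edges() for g in noisy_est])
    print([g.number_of_edges() for g in noisy_exact])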
| def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None): | |||
| from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |||
| from matplotlib.figure import Figure | |||
| @@ -162,7 +162,7 @@ def save_prediction_histogram(y_pred_data, fname_pred, max_num_node, bin_n=20): | |||
| # draw a single graph G | |||
| def draw_graph(G, prefix = 'test'): | |||
| def draw_graph(G, prefix='test'): | |||
| parts = community.best_partition(G) | |||
| values = [parts.get(node) for node in G.nodes()] | |||
| colors = [] | |||
| @@ -187,8 +187,7 @@ def draw_graph(G, prefix = 'test'): | |||
| plt.axis("off") | |||
| pos = nx.spring_layout(G) | |||
| nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors,pos=pos) | |||
| nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors, pos=pos) | |||
| # plt.switch_backend('agg') | |||
| # options = { | |||
| @@ -199,14 +198,14 @@ def draw_graph(G, prefix = 'test'): | |||
| # plt.figure() | |||
| # plt.subplot() | |||
| # nx.draw_networkx(G, **options) | |||
| plt.savefig('figures/graph_view_'+prefix+'.png', dpi=200) | |||
| plt.savefig('figures/graph_view_' + prefix + '.png', dpi=200) | |||
| plt.close() | |||
| plt.switch_backend('agg') | |||
| G_deg = nx.degree_histogram(G) | |||
| G_deg = np.array(G_deg) | |||
| # plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2) | |||
| plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2) | |||
| plt.loglog(np.arange(len(G_deg))[G_deg > 0], G_deg[G_deg > 0], 'r', linewidth=2) | |||
| plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||
| plt.close() | |||
| @@ -225,15 +224,16 @@ def draw_graph(G, prefix = 'test'): | |||
| # draw a list of graphs [G] | |||
| def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', is_single=False,k=1,node_size=55,alpha=1,width=1.3): | |||
| def draw_graph_list(G_list, row, col, fname='figures/test', layout='spring', is_single=False, k=1, node_size=55, | |||
| alpha=1, width=1.3): | |||
| # # draw graph view | |||
| # from pylab import rcParams | |||
| # rcParams['figure.figsize'] = 12,3 | |||
| plt.switch_backend('agg') | |||
| for i,G in enumerate(G_list): | |||
| plt.subplot(row,col,i+1) | |||
| for i, G in enumerate(G_list): | |||
| plt.subplot(row, col, i + 1) | |||
| plt.subplots_adjust(left=0, bottom=0, right=1, top=1, | |||
| wspace=0, hspace=0) | |||
| wspace=0, hspace=0) | |||
| # if i%2==0: | |||
| # plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4) | |||
| # else: | |||
| @@ -260,28 +260,29 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
| # if values[i] == 6: | |||
| # colors.append('black') | |||
| plt.axis("off") | |||
| if layout=='spring': | |||
| pos = nx.spring_layout(G,k=k/np.sqrt(G.number_of_nodes()),iterations=100) | |||
| if layout == 'spring': | |||
| pos = nx.spring_layout(G, k=k / np.sqrt(G.number_of_nodes()), iterations=100) | |||
| # pos = nx.spring_layout(G) | |||
| elif layout=='spectral': | |||
| elif layout == 'spectral': | |||
| pos = nx.spectral_layout(G) | |||
| # # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos) | |||
| # nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2) | |||
| if is_single: | |||
| # node_size default 60, edge_width default 1.5 | |||
| nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, font_size=0) | |||
| nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, | |||
| font_size=0) | |||
| nx.draw_networkx_edges(G, pos, alpha=alpha, width=width) | |||
| else: | |||
| nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699',alpha=1, linewidths=0.2, font_size = 1.5) | |||
| nx.draw_networkx_edges(G, pos, alpha=0.3,width=0.2) | |||
| nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699', alpha=1, linewidths=0.2, font_size=1.5) | |||
| nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.2) | |||
| # plt.axis('off') | |||
| # plt.title('Complete Graph of Odd-degree Nodes') | |||
| # plt.show() | |||
| plt.tight_layout() | |||
| plt.savefig(fname+'.png', dpi=600) | |||
| plt.savefig(fname + '.png', dpi=600) | |||
| plt.close() | |||
| # # draw degree distribution | |||
| @@ -376,8 +377,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
| # plt.savefig(fname+'_community.png', dpi=600) | |||
| # plt.close() | |||
| # plt.switch_backend('agg') | |||
| # G_deg = nx.degree_histogram(G) | |||
| # G_deg = np.array(G_deg) | |||
| @@ -395,7 +394,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
| # plt.close() | |||
| # directly get graph statistics from adj, obsoleted | |||
| def decode_graph(adj, prefix): | |||
| adj = np.asmatrix(adj) | |||
| @@ -433,6 +431,7 @@ def get_graph(adj): | |||
| G = nx.from_numpy_matrix(adj) | |||
| return G | |||
| # save a list of graphs | |||
| def save_graph_list(G_list, fname): | |||
| with open(fname, "wb") as f: | |||
| @@ -441,28 +440,30 @@ def save_graph_list(G_list, fname): | |||
| # pick the first connected component | |||
| def pick_connected_component(G): | |||
| node_list = nx.node_connected_component(G,0) | |||
| node_list = nx.node_connected_component(G, 0) | |||
| return G.subgraph(node_list) | |||
| def pick_connected_component_new(G): | |||
| adj_list = G.adjacency_list() | |||
| for id,adj in enumerate(adj_list): | |||
| for id, adj in enumerate(adj_list): | |||
| id_min = min(adj) | |||
| if id<id_min and id>=1: | |||
| # if id<id_min and id>=4: | |||
| if id < id_min and id >= 1: | |||
| # if id<id_min and id>=4: | |||
| break | |||
| node_list = list(range(id)) # only include node prior than node "id" | |||
| node_list = list(range(id))  # only include nodes prior to node "id" | |||
| G = G.subgraph(node_list) | |||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||
| return G | |||
| # load a list of graphs | |||
| def load_graph_list(fname,is_real=True): | |||
| def load_graph_list(fname, is_real=True): | |||
| with open(fname, "rb") as f: | |||
| graph_list = pickle.load(f) | |||
| for i in range(len(graph_list)): | |||
| edges_with_selfloops = graph_list[i].selfloop_edges() | |||
| if len(edges_with_selfloops)>0: | |||
| if len(edges_with_selfloops) > 0: | |||
| graph_list[i].remove_edges_from(edges_with_selfloops) | |||
| if is_real: | |||
| graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len) | |||
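A minimal round-trip sketch of the two helpers (the path is illustrative; it assumes save_graph_list pickles the list, which matches the pickle.load in load_graph_list): with is_real=True the loader strips self-loops and keeps only the largest connected component of each graph.

    gs = [nx.ladder_graph(10), nx.barabasi_albert_graph(30, 2)]
    save_graph_list(gs, '/tmp/example_graphs.dat')        # illustrative path
    gs_back = load_graph_list('/tmp/example_graphs.dat')
    print([g.number_of_nodes() for g in gs],
          [g.number_of_nodes() for g in gs_back])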
| @@ -482,6 +483,7 @@ def export_graphs_to_txt(g_list, output_filename_prefix): | |||
| f.write(str(idx_u) + '\t' + str(idx_v) + '\n') | |||
| i += 1 | |||
| def snap_txt_output_to_nx(in_fname): | |||
| G = nx.Graph() | |||
| with open(in_fname, 'r') as f: | |||
| @@ -496,23 +498,23 @@ def snap_txt_output_to_nx(in_fname): | |||
| G.add_edge(int(u), int(v)) | |||
| return G | |||
| def test_perturbed(): | |||
| graphs = [] | |||
| for i in range(100,101): | |||
| for j in range(4,5): | |||
| for i in range(100, 101): | |||
| for j in range(4, 5): | |||
| for k in range(500): | |||
| graphs.append(nx.barabasi_albert_graph(i,j)) | |||
| graphs.append(nx.barabasi_albert_graph(i, j)) | |||
| g_perturbed = perturb(graphs, 0.9) | |||
| print([g.number_of_edges() for g in graphs]) | |||
| print([g.number_of_edges() for g in g_perturbed]) | |||
| if __name__ == '__main__': | |||
| #test_perturbed() | |||
| #graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_train_0.dat') | |||
| #graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat') | |||
| graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat') | |||
| for i in range(0, 160, 16): | |||
| draw_graph_list(graphs[i:i+16], 4, 4, fname='figures/community4_' + str(i)) | |||
| # test_perturbed() | |||
| graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_grid_4_128_train_0.dat') | |||
| # graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat') | |||
| # graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat') | |||
| for i in range(0, 160, 16): | |||
| draw_graph_list(graphs[i:i + 16], 4, 4, fname='figures/community4_' + str(i)) | |||