@@ -5,7 +5,7 @@ class Args(): | |||
### whether to clean the tensorboard log directory
self.clean_tensorboard = False | |||
### Which CUDA GPU device is used for training | |||
self.cuda = 1 | |||
self.cuda = 0 | |||
### Which GraphRNN model variant is used. | |||
# The simple version of Graph RNN | |||
@@ -88,7 +88,7 @@ class Args(): | |||
self.load = False # if load model, default lr is very low | |||
self.load_epoch = 3000 | |||
self.load_epoch = 100 | |||
self.save = True | |||
@@ -1,14 +1,15 @@ | |||
import argparse | |||
import numpy as np | |||
import logging | |||
import os | |||
import re | |||
from random import shuffle | |||
from time import strftime, gmtime | |||
import eval.stats | |||
import utils | |||
# import main.Args | |||
from args import Args | |||
from baselines.baseline_simple import * | |||
class Args_evaluate(): | |||
def __init__(self): | |||
# loop over the settings | |||
@@ -19,19 +20,22 @@ class Args_evaluate(): | |||
# list of dataset to evaluate | |||
# use a list of 1 element to evaluate a single dataset | |||
self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD'] | |||
# self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD'] | |||
self.dataset_name_all = ['caveman_small'] | |||
# self.dataset_name_all = ['citeseer_small','caveman_small'] | |||
# self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10'] | |||
# self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small'] | |||
self.epoch_start=100 | |||
self.epoch_end=3001 | |||
self.epoch_step=100 | |||
self.epoch_start = 10 | |||
self.epoch_end = 101 | |||
self.epoch_step = 10 | |||
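# These values are consumed as range(epoch_start, epoch_end, epoch_step) in
# evaluation_epoch below, so the settings above evaluate checkpoints at epochs 10, 20, ..., 100.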
def find_nearest_idx(array,value): | |||
idx = (np.abs(array-value)).argmin() | |||
def find_nearest_idx(array, value): | |||
idx = (np.abs(array - value)).argmin() | |||
return idx | |||
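# For example, find_nearest_idx(np.array([10, 20, 30]), 24) returns 1,
# since |20 - 24| is the smallest absolute difference.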
def extract_result_id_and_epoch(name, prefix, suffix): | |||
''' | |||
Args: | |||
@@ -68,7 +72,7 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every): | |||
if result_id not in pred_graphs_dict: | |||
pred_graphs_dict[result_id] = {} | |||
pred_graphs_dict[result_id][epochs] = fname | |||
for result_id in real_graphs_dict.keys(): | |||
for epochs in sorted(real_graphs_dict[result_id]): | |||
real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs]) | |||
@@ -77,16 +81,16 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every): | |||
shuffle(pred_g_list) | |||
perturbed_g_list = perturb(real_g_list, 0.05) | |||
#dist = eval.stats.degree_stats(real_g_list, pred_g_list) | |||
# dist = eval.stats.degree_stats(real_g_list, pred_g_list) | |||
dist = eval.stats.clustering_stats(real_g_list, pred_g_list) | |||
print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist) | |||
#dist = eval.stats.degree_stats(real_g_list, perturbed_g_list) | |||
# dist = eval.stats.degree_stats(real_g_list, perturbed_g_list) | |||
dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list) | |||
print('dist between real and perturbed: ', dist) | |||
mid = len(real_g_list) // 2 | |||
#dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:]) | |||
# dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:]) | |||
dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:]) | |||
print('dist among real: ', dist) | |||
@@ -97,7 +101,6 @@ def compute_basic_stats(real_g_list, target_g_list): | |||
return dist_degree, dist_clustering | |||
def clean_graphs(graph_real, graph_pred): | |||
''' Select generated graphs whose sizes are similar to the real graphs.
It is usually necessary for the GraphRNN-S variant, but not for the full GraphRNN model.
@@ -120,6 +123,7 @@ def clean_graphs(graph_real, graph_pred): | |||
return graph_real, pred_graph_new | |||
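# A minimal sketch of the size-matching idea from the docstring above, reusing
# find_nearest_idx; the actual body of clean_graphs (elided here) may differ in detail:
#
#     pred_len = np.array([g.number_of_nodes() for g in graph_pred])
#     pred_graph_new = [graph_pred[find_nearest_idx(pred_len, g.number_of_nodes())]
#                       for g in graph_real]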
def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'): | |||
''' Read ground truth graphs. | |||
''' | |||
@@ -127,20 +131,21 @@ def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'): | |||
hidden = 128 | |||
else: | |||
hidden = 64 | |||
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||
if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R': | |||
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
hidden) + '_test_' + str(0) + '.dat' | |||
hidden) + '_test_' + str(0) + '.dat' | |||
else: | |||
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
hidden) + '_test_' + str(0) + '.dat' | |||
hidden) + '_test_' + str(0) + '.dat' | |||
try: | |||
graph_test = utils.load_graph_list(fname_test,is_real=True) | |||
graph_test = utils.load_graph_list(fname_test, is_real=True) | |||
except: | |||
print('Not found: ' + fname_test) | |||
logging.warning('Not found: ' + fname_test) | |||
return None | |||
return graph_test | |||
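# Usage sketch (directory and dataset name assumed, following the conventions above):
#
#     graph_test = load_ground_truth('graphs/', 'caveman_small', model_name='GraphRNN_RNN')
#     if graph_test is not None:
#         graph_test = graph_test[int(0.8 * len(graph_test)):]  # hold-out split, as in eval_single_list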
def eval_single_list(graphs, dir_input, dataset_name): | |||
''' Evaluate a list of graphs by comparing with graphs in directory dir_input. | |||
Args: | |||
@@ -149,7 +154,7 @@ def eval_single_list(graphs, dir_input, dataset_name): | |||
''' | |||
graph_test = load_ground_truth(dir_input, dataset_name) | |||
graph_test_len = len(graph_test) | |||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
mmd_degree = eval.stats.degree_stats(graph_test, graphs) | |||
mmd_clustering = eval.stats.clustering_stats(graph_test, graphs) | |||
try: | |||
@@ -160,9 +165,12 @@ def eval_single_list(graphs, dir_input, dataset_name): | |||
print('clustering: ', mmd_clustering) | |||
print('orbits: ', mmd_4orbits) | |||
def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,epoch_end=3001,epoch_step=100): | |||
def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000, | |||
epoch_end=3001, epoch_step=100): | |||
with open(fname_output, 'w+') as f: | |||
f.write('sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n') | |||
f.write( | |||
'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n') | |||
# TODO: Maybe refactor into a separate file/function that specifies THE naming convention | |||
# across main and evaluate | |||
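# One possible shape for the refactor mentioned in the TODO above; graph_save_fname and its
# arguments are a hypothetical sketch (not part of this codebase) that mirrors the string
# concatenations used for fname_test and fname_pred below.
def graph_save_fname(dir_input, model_name, dataset_name, num_layers, hidden,
                     tag='test', epoch=0, sample_time=None):
    suffix = str(epoch) if sample_time is None else str(epoch) + '_' + str(sample_time)
    return dir_input + model_name + '_' + dataset_name + '_' + str(num_layers) + \
           '_' + str(hidden) + '_' + tag + '_' + suffix + '.dat'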
@@ -171,7 +179,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
else: | |||
hidden = 64 | |||
# read real graph | |||
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||
if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R': | |||
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
hidden) + '_test_' + str(0) + '.dat' | |||
elif 'Baseline' in model_name: | |||
@@ -180,37 +188,37 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
hidden) + '_test_' + str(0) + '.dat' | |||
try: | |||
graph_test = utils.load_graph_list(fname_test,is_real=True) | |||
graph_test = utils.load_graph_list(fname_test, is_real=True) | |||
except: | |||
print('Not found: ' + fname_test) | |||
logging.warning('Not found: ' + fname_test) | |||
return None | |||
graph_test_len = len(graph_test) | |||
graph_train = graph_test[0:int(0.8 * graph_test_len)] # train | |||
graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate | |||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
graph_train = graph_test[0:int(0.8 * graph_test_len)] # train | |||
graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate | |||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||
graph_test_aver = 0 | |||
for graph in graph_test: | |||
graph_test_aver+=graph.number_of_nodes() | |||
graph_test_aver += graph.number_of_nodes() | |||
graph_test_aver /= len(graph_test) | |||
print('test average len',graph_test_aver) | |||
print('test average len', graph_test_aver) | |||
# get performance for proposed approaches | |||
if 'GraphRNN' in model_name: | |||
# read test graph | |||
for epoch in range(epoch_start,epoch_end,epoch_step): | |||
for sample_time in range(1,4): | |||
for epoch in range(epoch_start, epoch_end, epoch_step): | |||
for sample_time in range(1, 4): | |||
# get filename | |||
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat' | |||
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||
hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat' | |||
# load graphs | |||
try: | |||
graph_pred = utils.load_graph_list(fname_pred,is_real=False) # default False | |||
graph_pred = utils.load_graph_list(fname_pred, is_real=False) # default False | |||
except: | |||
print('Not found: '+ fname_pred) | |||
logging.warning('Not found: '+ fname_pred) | |||
print('Not found: ' + fname_pred) | |||
logging.warning('Not found: ' + fname_pred) | |||
continue | |||
# clean graphs | |||
if is_clean: | |||
@@ -243,15 +251,15 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
except: | |||
mmd_4orbits_validate = -1 | |||
# write results | |||
f.write(str(sample_time)+','+ | |||
str(epoch)+','+ | |||
str(mmd_degree_validate)+','+ | |||
str(mmd_clustering_validate)+','+ | |||
str(mmd_4orbits_validate)+','+ | |||
str(mmd_degree)+','+ | |||
str(mmd_clustering)+','+ | |||
str(mmd_4orbits)+'\n') | |||
print('degree',mmd_degree,'clustering',mmd_clustering,'orbits',mmd_4orbits) | |||
f.write(str(sample_time) + ',' + | |||
str(epoch) + ',' + | |||
str(mmd_degree_validate) + ',' + | |||
str(mmd_clustering_validate) + ',' + | |||
str(mmd_4orbits_validate) + ',' + | |||
str(mmd_degree) + ',' + | |||
str(mmd_clustering) + ',' + | |||
str(mmd_4orbits) + '\n') | |||
print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits) | |||
# get internal MMD (MMD between ground truth validation and test sets) | |||
if model_name == 'Internal': | |||
@@ -265,7 +273,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
mmd_clustering_validate) + ',' + str(mmd_4orbits_validate) | |||
+ ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n') | |||
# get MMD between ground truth and its perturbed graphs | |||
if model_name == 'Noise': | |||
graph_validate_perturbed = perturb(graph_validate, 0.05) | |||
@@ -281,7 +288,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
# get E-R MMD | |||
if model_name == 'E-R': | |||
graph_pred = Graph_generator_baseline(graph_train,generator='Gnp') | |||
graph_pred = Graph_generator_baseline(graph_train, generator='Gnp') | |||
# clean graphs | |||
if is_clean: | |||
graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||
@@ -296,7 +303,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) | |||
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n') | |||
# get B-A MMD | |||
if model_name == 'B-A': | |||
graph_pred = Graph_generator_baseline(graph_train, generator='BA') | |||
@@ -364,33 +370,29 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is | |||
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n') | |||
print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits) | |||
return True | |||
def evaluation(args_evaluate,dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite = True): | |||
def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True): | |||
''' Evaluate the performance of a set of models on a set of datasets. | |||
''' | |||
for model_name in model_name_all: | |||
for dataset_name in dataset_name_all: | |||
# check output exist | |||
fname_output = dir_output+model_name+'_'+dataset_name+'.csv' | |||
print('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||
logging.info('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||
if overwrite==False and os.path.isfile(fname_output): | |||
print(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||
logging.info(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||
fname_output = dir_output + model_name + '_' + dataset_name + '.csv' | |||
print('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv') | |||
logging.info('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv') | |||
if overwrite == False and os.path.isfile(fname_output): | |||
print(dir_output + model_name + '_' + dataset_name + '.csv exists!') | |||
logging.info(dir_output + model_name + '_' + dataset_name + '.csv exists!') | |||
continue | |||
evaluation_epoch(dir_input,fname_output,model_name,dataset_name,args,is_clean=True, epoch_start=args_evaluate.epoch_start,epoch_end=args_evaluate.epoch_end,epoch_step=args_evaluate.epoch_step) | |||
evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, | |||
epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end, | |||
epoch_step=args_evaluate.epoch_step) | |||
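# Usage sketch mirroring the __main__ block at the bottom of this file (directory names assumed):
#
#     evaluation(args_evaluate, dir_input='graphs/', dir_output='eval_results/',
#                model_name_all=['GraphRNN_RNN'], dataset_name_all=['caveman_small'],
#                args=args, overwrite=True)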
def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
eval_every, epoch_range=None, out_file_prefix=None): | |||
eval_every, epoch_range=None, out_file_prefix=None): | |||
''' Evaluate list of predicted graphs compared to ground truth, stored in files. | |||
Args: | |||
baselines: dict mapping name of the baseline to list of generated graphs. | |||
@@ -398,12 +400,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
if out_file_prefix is not None: | |||
out_files = { | |||
'train': open(out_file_prefix + '_train.txt', 'w+'), | |||
'compare': open(out_file_prefix + '_compare.txt', 'w+') | |||
'train': open(out_file_prefix + '_train.txt', 'w+'), | |||
'compare': open(out_file_prefix + '_compare.txt', 'w+') | |||
} | |||
out_files['train'].write('degree,clustering,orbits4\n') | |||
line = 'metric,real,ours,perturbed' | |||
for bl in baselines: | |||
line += ',' + bl | |||
@@ -411,34 +413,33 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
out_files['compare'].write(line) | |||
results = { | |||
'deg': { | |||
'real': 0, | |||
'ours': 100, # take min over all training epochs | |||
'perturbed': 0, | |||
'kron': 0}, | |||
'clustering': { | |||
'real': 0, | |||
'ours': 100, | |||
'perturbed': 0, | |||
'kron': 0}, | |||
'orbits4': { | |||
'real': 0, | |||
'ours': 100, | |||
'perturbed': 0, | |||
'kron': 0} | |||
'deg': { | |||
'real': 0, | |||
'ours': 100, # take min over all training epochs | |||
'perturbed': 0, | |||
'kron': 0}, | |||
'clustering': { | |||
'real': 0, | |||
'ours': 100, | |||
'perturbed': 0, | |||
'kron': 0}, | |||
'orbits4': { | |||
'real': 0, | |||
'ours': 100, | |||
'perturbed': 0, | |||
'kron': 0} | |||
} | |||
num_evals = len(pred_graphs_filename) | |||
if epoch_range is None: | |||
epoch_range = [i * eval_every for i in range(num_evals)] | |||
epoch_range = [i * eval_every for i in range(num_evals)] | |||
for i in range(num_evals): | |||
real_g_list = utils.load_graph_list(real_graph_filename) | |||
#pred_g_list = utils.load_graph_list(pred_graphs_filename[i]) | |||
# pred_g_list = utils.load_graph_list(pred_graphs_filename[i]) | |||
# contains all predicted G | |||
pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i]) | |||
if len(real_g_list)>200: | |||
if len(real_g_list) > 200: | |||
real_g_list = real_g_list[0:200] | |||
shuffle(real_g_list) | |||
@@ -448,10 +449,9 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))]) | |||
pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))]) | |||
# get perturb real | |||
#perturbed_g_list_001 = perturb(real_g_list, 0.01) | |||
# perturbed_g_list_001 = perturb(real_g_list, 0.01) | |||
perturbed_g_list_005 = perturb(real_g_list, 0.05) | |||
#perturbed_g_list_010 = perturb(real_g_list, 0.10) | |||
# perturbed_g_list_010 = perturb(real_g_list, 0.10) | |||
# select pred samples | |||
# The number of nodes is sampled from a distribution similar to that of the training set
@@ -482,22 +482,22 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
# ======================================== | |||
mid = len(real_g_list) // 2 | |||
dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:]) | |||
#dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:]) | |||
# dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:]) | |||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:]) | |||
print('degree dist among real: ', dist_degree) | |||
print('clustering dist among real: ', dist_clustering) | |||
#print('4 cycle dist among real: ', dist_4cycle) | |||
# print('4 cycle dist among real: ', dist_4cycle) | |||
print('orbits dist among real: ', dist_4orbits) | |||
results['deg']['real'] += dist_degree | |||
results['clustering']['real'] += dist_clustering | |||
results['orbits4']['real'] += dist_4orbits | |||
dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list) | |||
#dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list) | |||
# dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list) | |||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list) | |||
print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree) | |||
print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering) | |||
#print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle) | |||
# print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle) | |||
print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits) | |||
results['deg']['ours'] = min(dist_degree, results['deg']['ours']) | |||
results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours']) | |||
@@ -509,11 +509,11 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
out_files['train'].write(str(dist_4orbits) + ',') | |||
dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005) | |||
#dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005) | |||
# dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005) | |||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005) | |||
print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree) | |||
print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering) | |||
#print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle) | |||
# print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle) | |||
print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits) | |||
results['deg']['perturbed'] += dist_degree | |||
results['clustering']['perturbed'] += dist_clustering | |||
@@ -527,8 +527,8 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
results['deg'][baseline] = dist_degree | |||
results['clustering'][baseline] = dist_clustering | |||
results['orbits4'][baseline] = dist_4orbits | |||
print('Kron: deg=', dist_degree, ', clustering=', dist_clustering, | |||
', orbits4=', dist_4orbits) | |||
print('Kron: deg=', dist_degree, ', clustering=', dist_clustering, | |||
', orbits4=', dist_4orbits) | |||
out_files['train'].write('\n') | |||
@@ -538,10 +538,10 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
# Write results | |||
for metric, methods in results.items(): | |||
line = metric+','+ \ | |||
str(methods['real'])+','+ \ | |||
str(methods['ours'])+','+ \ | |||
str(methods['perturbed']) | |||
line = metric + ',' + \ | |||
str(methods['real']) + ',' + \ | |||
str(methods['ours']) + ',' + \ | |||
str(methods['perturbed']) | |||
for baseline in baselines: | |||
line += ',' + str(methods[baseline]) | |||
line += '\n' | |||
@@ -553,12 +553,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None, | |||
sample_time = 2, baselines={}): | |||
sample_time=2, baselines={}): | |||
if args is None: | |||
real_graphs_filename = [datadir + f for f in os.listdir(datadir) | |||
if re.match(prefix + '.*real.*\.dat', f)] | |||
if re.match(prefix + '.*real.*\.dat', f)] | |||
pred_graphs_filename = [datadir + f for f in os.listdir(datadir) | |||
if re.match(prefix + '.*pred.*\.dat', f)] | |||
if re.match(prefix + '.*pred.*\.dat', f)] | |||
eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200) | |||
else: | |||
@@ -567,28 +567,30 @@ def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_p | |||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||
# pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||
real_graph_filename = datadir+args.graph_save_path + args.fname_test + '0.dat' | |||
real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat' | |||
# for proposed model | |||
end_epoch = 3001 | |||
epoch_range = range(eval_every, end_epoch, eval_every) | |||
pred_graphs_filename = [datadir+args.graph_save_path + args.fname_pred+str(epoch)+'_'+str(sample_time)+'.dat' | |||
for epoch in epoch_range] | |||
pred_graphs_filename = [ | |||
datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat' | |||
for epoch in epoch_range] | |||
# for baseline model | |||
#pred_graphs_filename = [datadir+args.fname_baseline+'.dat'] | |||
# pred_graphs_filename = [datadir+args.fname_baseline+'.dat'] | |||
#real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
# real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||
#pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
# pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||
eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||
epoch_range=epoch_range, | |||
epoch_range=epoch_range, | |||
eval_every=eval_every, | |||
out_file_prefix=out_file_prefix) | |||
def process_kron(kron_dir): | |||
txt_files = [] | |||
for f in os.listdir(kron_dir): | |||
@@ -602,7 +604,7 @@ def process_kron(kron_dir): | |||
G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename))) | |||
return G_list | |||
if __name__ == '__main__': | |||
args = Args() | |||
@@ -612,18 +614,18 @@ if __name__ == '__main__': | |||
feature_parser = parser.add_mutually_exclusive_group(required=False) | |||
feature_parser.add_argument('--export-real', dest='export', action='store_true') | |||
feature_parser.add_argument('--no-export-real', dest='export', action='store_false') | |||
feature_parser.add_argument('--kron-dir', dest='kron_dir',
help='Directory where graphs generated by the Kronecker method are stored.')
feature_parser.add_argument('--kron-dir', dest='kron_dir',
help='Directory where graphs generated by the Kronecker method are stored.')
parser.add_argument('--testfile', dest='test_file', | |||
help='The file that stores the list of graphs to be evaluated. Only used when a single '
'list of graphs is to be evaluated.')
help='The file that stores the list of graphs to be evaluated. Only used when a single '
'list of graphs is to be evaluated.')
parser.add_argument('--dir-prefix', dest='dir_prefix', | |||
help='The file that stores the list of graphs to be evaluated. Can be used when evaluating '
'multiple models on multiple datasets.')
help='The file that stores the list of graphs to be evaluated. Can be used when evaluating '
'multiple models on multiple datasets.')
parser.add_argument('--graph-type', dest='graph_type', | |||
help='Type of graphs / dataset.') | |||
help='Type of graphs / dataset.') | |||
parser.set_defaults(export=False, kron_dir='', test_file='', | |||
dir_prefix='', | |||
graph_type=args.graph_type) | |||
@@ -633,7 +635,6 @@ if __name__ == '__main__': | |||
# dir_prefix = "/dfs/scratch0/jiaxuany0/" | |||
dir_prefix = args.dir_input | |||
time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime()) | |||
if not os.path.isdir('logs/'): | |||
os.makedirs('logs/') | |||
@@ -652,16 +653,16 @@ if __name__ == '__main__': | |||
if prog_args.graph_type == 'grid': | |||
graphs = [] | |||
for i in range(10,20): | |||
for j in range(10,20): | |||
graphs.append(nx.grid_2d_graph(i,j)) | |||
for i in range(10, 20): | |||
for j in range(10, 20): | |||
graphs.append(nx.grid_2d_graph(i, j)) | |||
utils.export_graphs_to_txt(graphs, output_prefix) | |||
elif prog_args.graph_type == 'caveman': | |||
graphs = [] | |||
for i in range(2, 3): | |||
for j in range(30, 81): | |||
for k in range(10): | |||
graphs.append(caveman_special(i,j, p_edge=0.3)) | |||
graphs.append(caveman_special(i, j, p_edge=0.3)) | |||
utils.export_graphs_to_txt(graphs, output_prefix) | |||
elif prog_args.graph_type == 'citeseer': | |||
graphs = utils.citeseer_ego() | |||
@@ -679,14 +680,11 @@ if __name__ == '__main__': | |||
elif not prog_args.test_file == '': | |||
# evaluate single .dat file containing list of test graphs (networkx format) | |||
graphs = utils.load_graph_list(prog_args.test_file) | |||
eval_single_list(graphs, dir_input=dir_prefix+'graphs/', dataset_name='grid') | |||
eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid') | |||
## if you are not evaluating the Kronecker baseline, only the following part is needed
else: | |||
if not os.path.isdir(dir_prefix+'eval_results'): | |||
os.makedirs(dir_prefix+'eval_results') | |||
evaluation(args_evaluate,dir_input=dir_prefix+"graphs/", dir_output=dir_prefix+"eval_results/", | |||
model_name_all=args_evaluate.model_name_all,dataset_name_all=args_evaluate.dataset_name_all,args=args,overwrite=True) | |||
if not os.path.isdir(dir_prefix + 'eval_results'): | |||
os.makedirs(dir_prefix + 'eval_results') | |||
evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/", | |||
model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all, | |||
args=args, overwrite=True) |
@@ -21,11 +21,9 @@ from tensorboard_logger import configure, log_value | |||
import scipy.misc | |||
import time as tm | |||
from utils import * | |||
from model import * | |||
from tensorboard_logger import log_value | |||
from data import * | |||
from args import Args | |||
import create_graphs | |||
def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
@@ -47,16 +45,16 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
# sort input | |||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
y_len = y_len.numpy().tolist() | |||
x = torch.index_select(x_unsorted,0,sort_index) | |||
y = torch.index_select(y_unsorted,0,sort_index) | |||
x = torch.index_select(x_unsorted, 0, sort_index) | |||
y = torch.index_select(y_unsorted, 0, sort_index) | |||
x = Variable(x).cuda() | |||
y = Variable(y).cuda() | |||
# if using ground truth to train | |||
h = rnn(x, pack=True, input_len=y_len) | |||
y_pred,z_mu,z_lsgms = output(h) | |||
y_pred, z_mu, z_lsgms = output(h) | |||
y_pred = F.sigmoid(y_pred) | |||
# clean | |||
y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) | |||
@@ -68,7 +66,7 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
# use cross entropy loss | |||
loss_bce = binary_cross_entropy_weight(y_pred, y) | |||
loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp()) | |||
loss_kl /= y.size(0)*y.size(1)*sum(y_len) # normalize | |||
loss_kl /= y.size(0) * y.size(1) * sum(y_len) # normalize | |||
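# This is the closed-form KL divergence KL(N(mu, sigma^2) || N(0, 1))
#   = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with z_lsgms = log(sigma^2).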
loss = loss_bce + loss_kl | |||
loss.backward() | |||
# update deterministic and lstm | |||
@@ -77,7 +75,6 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
scheduler_output.step() | |||
scheduler_rnn.step() | |||
z_mu_mean = torch.mean(z_mu.data) | |||
z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data) | |||
z_mu_min = torch.min(z_mu.data) | |||
@@ -85,35 +82,39 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||
z_mu_max = torch.max(z_mu.data) | |||
z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data) | |||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs,loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max) | |||
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
print( | |||
'Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, | |||
args.hidden_size_rnn)) | |||
print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, | |||
'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max) | |||
# logging | |||
log_value('bce_loss_'+args.fname, loss_bce.data[0], epoch*args.batch_ratio+batch_idx) | |||
log_value('kl_loss_' +args.fname, loss_kl.data[0], epoch*args.batch_ratio + batch_idx) | |||
log_value('z_mu_mean_'+args.fname, z_mu_mean, epoch*args.batch_ratio + batch_idx) | |||
log_value('z_mu_min_'+args.fname, z_mu_min, epoch*args.batch_ratio + batch_idx) | |||
log_value('z_mu_max_'+args.fname, z_mu_max, epoch*args.batch_ratio + batch_idx) | |||
log_value('z_sgm_mean_'+args.fname, z_sgm_mean, epoch*args.batch_ratio + batch_idx) | |||
log_value('z_sgm_min_'+args.fname, z_sgm_min, epoch*args.batch_ratio + batch_idx) | |||
log_value('z_sgm_max_'+args.fname, z_sgm_max, epoch*args.batch_ratio + batch_idx) | |||
log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx) | |||
log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx) | |||
log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx) | |||
log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx) | |||
log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx) | |||
log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx) | |||
log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx) | |||
log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx) | |||
loss_sum += loss.data[0] | |||
return loss_sum/(batch_idx+1) | |||
return loss_sum / (batch_idx + 1) | |||
def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time = 1): | |||
def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1): | |||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||
rnn.eval() | |||
output.eval() | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
h = rnn(x_step) | |||
y_pred_step, _, _ = output(h) | |||
@@ -128,7 +129,7 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
G_pred_list = [] | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
# save prediction histograms, plot histogram over each time step | |||
@@ -137,11 +138,10 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
# fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg', | |||
# max_num_node=max_num_node) | |||
return G_pred_list | |||
def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
rnn.eval() | |||
output.eval() | |||
G_pred_list = [] | |||
@@ -153,15 +153,18 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
print('finish node',i) | |||
print('finish node', i) | |||
h = rnn(x_step) | |||
y_pred_step, _, _ = output(h) | |||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
sample_time=sample_time) | |||
y_pred_long[:, i:i + 1, :] = x_step | |||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
@@ -171,12 +174,11 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
# save graphs as pickle | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
return G_pred_list | |||
def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
optimizer_rnn, optimizer_output, | |||
scheduler_rnn, scheduler_output): | |||
@@ -196,10 +198,10 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
# sort input | |||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
y_len = y_len.numpy().tolist() | |||
x = torch.index_select(x_unsorted,0,sort_index) | |||
y = torch.index_select(y_unsorted,0,sort_index) | |||
x = torch.index_select(x_unsorted, 0, sort_index) | |||
y = torch.index_select(y_unsorted, 0, sort_index) | |||
x = Variable(x).cuda() | |||
y = Variable(y).cuda() | |||
@@ -218,28 +220,28 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||
scheduler_output.step() | |||
scheduler_rnn.step() | |||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
# logging | |||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx) | |||
loss_sum += loss.data[0] | |||
return loss_sum/(batch_idx+1) | |||
return loss_sum / (batch_idx + 1) | |||
def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False,sample_time=1): | |||
def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1): | |||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||
rnn.eval() | |||
output.eval() | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
h = rnn(x_step) | |||
y_pred_step = output(h) | |||
@@ -254,10 +256,9 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
G_pred_list = [] | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
# # save prediction histograms, plot histogram over each time step | |||
# if save_histogram: | |||
# save_prediction_histogram(y_pred_data.cpu().numpy(), | |||
@@ -266,8 +267,7 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram= | |||
return G_pred_list | |||
def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
rnn.eval() | |||
output.eval() | |||
G_pred_list = [] | |||
@@ -279,15 +279,18 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
print('finish node',i) | |||
print('finish node', i) | |||
h = rnn(x_step) | |||
y_pred_step = output(h) | |||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
sample_time=sample_time) | |||
y_pred_long[:, i:i + 1, :] = x_step | |||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
@@ -297,12 +300,12 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram | |||
# save graphs as pickle | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
return G_pred_list | |||
def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||
def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1): | |||
rnn.eval() | |||
output.eval() | |||
G_pred_list = [] | |||
@@ -314,15 +317,18 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi | |||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||
y_pred_long = Variable( | |||
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
print('finish node',i) | |||
print('finish node', i) | |||
h = rnn(x_step) | |||
y_pred_step = output(h) | |||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||
x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||
x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len, | |||
sample_time=sample_time) | |||
y_pred_long[:, i:i + 1, :] = x_step | |||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
@@ -332,7 +338,7 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi | |||
# save graphs as pickle | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
return G_pred_list | |||
@@ -354,10 +360,10 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader): | |||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||
# sort input | |||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
y_len = y_len.numpy().tolist() | |||
x = torch.index_select(x_unsorted,0,sort_index) | |||
y = torch.index_select(y_unsorted,0,sort_index) | |||
x = torch.index_select(x_unsorted, 0, sort_index) | |||
y = torch.index_select(y_unsorted, 0, sort_index) | |||
x = Variable(x).cuda() | |||
y = Variable(y).cuda() | |||
@@ -372,22 +378,18 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader): | |||
loss = 0 | |||
for j in range(y.size(1)): | |||
# print('y_pred',y_pred[0,j,:],'y',y[0,j,:]) | |||
end_idx = min(j+1,y.size(2)) | |||
loss += binary_cross_entropy_weight(y_pred[:,j,0:end_idx], y[:,j,0:end_idx])*end_idx | |||
end_idx = min(j + 1, y.size(2)) | |||
loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx | |||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
# logging | |||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx) | |||
loss_sum += loss.data[0] | |||
return loss_sum/(batch_idx+1) | |||
return loss_sum / (batch_idx + 1) | |||
## too complicated, deprecated | |||
@@ -449,28 +451,29 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||
# sort input | |||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
y_len = y_len.numpy().tolist() | |||
x = torch.index_select(x_unsorted,0,sort_index) | |||
y = torch.index_select(y_unsorted,0,sort_index) | |||
x = torch.index_select(x_unsorted, 0, sort_index) | |||
y = torch.index_select(y_unsorted, 0, sort_index) | |||
# input, output for output rnn module | |||
# a smart use of the PyTorch builtin pack_padded_sequence: the packed data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||
y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data | |||
# reverse y_reshape, so that their lengths are sorted, add dimension | |||
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||
idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)] | |||
idx = torch.LongTensor(idx) | |||
y_reshape = y_reshape.index_select(0, idx) | |||
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||
y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1) | |||
output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||
output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1) | |||
output_y = y_reshape | |||
# batch size for output module: sum(y_len) | |||
output_y_len = [] | |||
output_y_len_bin = np.bincount(np.array(y_len)) | |||
for i in range(len(output_y_len_bin)-1,0,-1): | |||
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len values are >= i
output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
for i in range(len(output_y_len_bin) - 1, 0, -1): | |||
count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i
output_y_len.extend( | |||
[min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
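# Worked example (illustrative, not part of the original code): with y_len = [3, 2] after
# sorting, pack_padded_sequence(y, y_len, batch_first=True).data stacks rows as
# [g1_row0, g2_row0, g1_row1, g2_row1, g1_row2]; reversing gives
# [g1_row2, g2_row1, g1_row1, g2_row0, g1_row0]. np.bincount([3, 2]) = [0, 0, 1, 1], so the
# loop above yields output_y_len = [min(3, M), min(2, M), min(2, M), min(1, M), min(1, M)]
# with M = y.size(2): one edge-sequence length per packed row, in matching order.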
# pack into variable | |||
x = Variable(x).cuda() | |||
y = Variable(y).cuda() | |||
@@ -481,23 +484,23 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
# print('y',y.size()) | |||
# print('output_y',output_y.size()) | |||
# if using ground truth to train | |||
h = rnn(x, pack=True, input_len=y_len) | |||
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||
h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector | |||
# reverse h | |||
idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||
idx = Variable(torch.LongTensor(idx)).cuda() | |||
h = h.index_select(0, idx) | |||
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda() | |||
output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), | |||
dim=0) # num_layers, batch_size, hidden_size | |||
y_pred = output(output_x, pack=True, input_len=output_y_len) | |||
y_pred = F.sigmoid(y_pred) | |||
# clean | |||
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||
output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||
output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True) | |||
output_y = pad_packed_sequence(output_y, batch_first=True)[0] | |||
# use cross entropy loss | |||
loss = binary_cross_entropy_weight(y_pred, output_y) | |||
loss.backward() | |||
@@ -507,17 +510,15 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||
scheduler_output.step() | |||
scheduler_rnn.step() | |||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
# logging | |||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
feature_dim = y.size(1)*y.size(2) | |||
loss_sum += loss.data[0]*feature_dim | |||
return loss_sum/(batch_idx+1) | |||
log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx) | |||
feature_dim = y.size(1) * y.size(2) | |||
loss_sum += loss.data * feature_dim | |||
return loss_sum / (batch_idx + 1) | |||
def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
@@ -527,20 +528,20 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
# generate graphs | |||
max_num_node = int(args.max_num_node) | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda() | |||
for i in range(max_num_node): | |||
h = rnn(x_step) | |||
# output.hidden = h.permute(1,0,2) | |||
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda() | |||
output.hidden = torch.cat((h.permute(1,0,2), hidden_null), | |||
output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), | |||
dim=0) # num_layers, batch_size, hidden_size | |||
x_step = Variable(torch.zeros(test_batch_size,1,args.max_prev_node)).cuda() | |||
output_x_step = Variable(torch.ones(test_batch_size,1,1)).cuda() | |||
for j in range(min(args.max_prev_node,i+1)): | |||
x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda() | |||
output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda() | |||
for j in range(min(args.max_prev_node, i + 1)): | |||
output_y_pred_step = output(output_x_step) | |||
output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1) | |||
x_step[:,:,j:j+1] = output_x_step | |||
x_step[:, :, j:j + 1] = output_x_step | |||
output.hidden = Variable(output.hidden.data).cuda() | |||
y_pred_long[:, i:i + 1, :] = x_step | |||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||
@@ -550,14 +551,12 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||
G_pred_list = [] | |||
for i in range(test_batch_size): | |||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||
G_pred_list.append(G_pred) | |||
return G_pred_list | |||
def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
rnn.train() | |||
output.train() | |||
@@ -576,28 +575,29 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||
# sort input | |||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) | |||
y_len = y_len.numpy().tolist() | |||
x = torch.index_select(x_unsorted,0,sort_index) | |||
y = torch.index_select(y_unsorted,0,sort_index) | |||
x = torch.index_select(x_unsorted, 0, sort_index) | |||
y = torch.index_select(y_unsorted, 0, sort_index) | |||
# input, output for output rnn module | |||
# a smart use of the PyTorch builtin pack_padded_sequence: the packed data is ordered b1_l1, b2_l1, ..., b1_l2, b2_l2, ...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||
y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data | |||
# reverse y_reshape, so that their lengths are sorted, add dimension | |||
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||
idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)] | |||
idx = torch.LongTensor(idx) | |||
y_reshape = y_reshape.index_select(0, idx) | |||
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||
y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1) | |||
output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||
output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1) | |||
output_y = y_reshape | |||
# batch size for output module: sum(y_len) | |||
output_y_len = [] | |||
output_y_len_bin = np.bincount(np.array(y_len)) | |||
for i in range(len(output_y_len_bin)-1,0,-1): | |||
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len values are >= i
output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
for i in range(len(output_y_len_bin) - 1, 0, -1): | |||
count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len values are >= i
output_y_len.extend( | |||
[min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2) | |||
# pack into variable | |||
x = Variable(x).cuda() | |||
y = Variable(y).cuda() | |||
@@ -608,37 +608,36 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||
# print('y',y.size()) | |||
# print('output_y',output_y.size()) | |||
# if using ground truth to train | |||
h = rnn(x, pack=True, input_len=y_len) | |||
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||
h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector | |||
# reverse h | |||
idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||
idx = Variable(torch.LongTensor(idx)).cuda() | |||
h = h.index_select(0, idx) | |||
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda() | |||
output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null), | |||
dim=0) # num_layers, batch_size, hidden_size | |||
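# Shape sketch for the hidden-state hand-off above, assuming args.num_layers = 4 and
# B = h.size(0) packed rows of width H = h.size(1): h.view(1, B, H) fills layer 0 of
# the edge-level (output) RNN with one graph-level RNN output per row, hidden_null
# supplies zeros for the remaining 3 layers, and the concatenation gives
# output.hidden of shape (4, B, H).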
y_pred = output(output_x, pack=True, input_len=output_y_len) | |||
y_pred = F.sigmoid(y_pred) | |||
# clean | |||
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||
output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||
output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True) | |||
output_y = pad_packed_sequence(output_y, batch_first=True)[0] | |||
# use cross entropy loss | |||
loss = binary_cross_entropy_weight(y_pred, output_y) | |||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics | |||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||
# logging | |||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||
log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx) | |||
# print(y_pred.size()) | |||
feature_dim = y_pred.size(0)*y_pred.size(1) | |||
loss_sum += loss.data[0]*feature_dim/y.size(0) | |||
return loss_sum/(batch_idx+1) | |||
feature_dim = y_pred.size(0) * y_pred.size(1) | |||
loss_sum += loss.data * feature_dim / y.size(0) | |||
return loss_sum / (batch_idx + 1) | |||
########### train function for LSTM + VAE | |||
@@ -665,7 +664,7 @@ def train(args, dataset_train, rnn, output): | |||
# start main loop | |||
time_all = np.zeros(args.epochs) | |||
while epoch<=args.epochs: | |||
while epoch <= args.epochs: | |||
time_start = tm.time() | |||
# train | |||
if 'GraphRNN_VAE' in args.note: | |||
@@ -683,25 +682,26 @@ def train(args, dataset_train, rnn, output): | |||
time_end = tm.time() | |||
time_all[epoch - 1] = time_end - time_start | |||
# test | |||
if epoch % args.epochs_test == 0 and epoch>=args.epochs_test_start: | |||
for sample_time in range(1,4): | |||
if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start: | |||
for sample_time in range(1, 4): | |||
G_pred = [] | |||
while len(G_pred)<args.test_total_size: | |||
while len(G_pred) < args.test_total_size: | |||
if 'GraphRNN_VAE' in args.note: | |||
G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||
G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, | |||
sample_time=sample_time) | |||
elif 'GraphRNN_MLP' in args.note: | |||
G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||
G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size, | |||
sample_time=sample_time) | |||
elif 'GraphRNN_RNN' in args.note: | |||
G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size) | |||
G_pred.extend(G_pred_step) | |||
# save graphs | |||
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + '.dat' | |||
fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat' | |||
save_graph_list(G_pred, fname) | |||
if 'GraphRNN_RNN' in args.note: | |||
break | |||
print('test done, graphs saved') | |||
# save model checkpoint | |||
if args.save: | |||
if epoch % args.epochs_save == 0: | |||
@@ -710,7 +710,7 @@ def train(args, dataset_train, rnn, output): | |||
fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat' | |||
torch.save(output.state_dict(), fname) | |||
epoch += 1 | |||
np.save(args.timing_save_path+args.fname,time_all) | |||
np.save(args.timing_save_path + args.fname, time_all) | |||
########### for graph completion task | |||
@@ -723,19 +723,19 @@ def train_graph_completion(args, dataset_test, rnn, output): | |||
epoch = args.load_epoch | |||
print('model loaded!, epoch: {}'.format(args.load_epoch)) | |||
for sample_time in range(1,4): | |||
for sample_time in range(1, 4): | |||
if 'GraphRNN_MLP' in args.note: | |||
G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||
G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time) | |||
if 'GraphRNN_VAE' in args.note: | |||
G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||
G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time) | |||
# save graphs | |||
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + 'graph_completion.dat' | |||
fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat' | |||
save_graph_list(G_pred, fname) | |||
print('graph completion done, graphs saved') | |||
########### for NLL evaluation | |||
def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,graph_test_len, max_iter = 1000): | |||
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000): | |||
fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat' | |||
rnn.load_state_dict(torch.load(fname)) | |||
fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat' | |||
@@ -745,7 +745,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len, | |||
print('model loaded!, epoch: {}'.format(args.load_epoch)) | |||
fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv' | |||
with open(fname_output, 'w+') as f: | |||
f.write(str(graph_validate_len)+','+str(graph_test_len)+'\n') | |||
f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n') | |||
f.write('train,test\n') | |||
for iter in range(max_iter): | |||
if 'GraphRNN_MLP' in args.note: | |||
@@ -754,7 +754,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len, | |||
if 'GraphRNN_RNN' in args.note: | |||
nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train) | |||
nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test) | |||
print('train',nll_train,'test',nll_test) | |||
f.write(str(nll_train)+','+str(nll_test)+'\n') | |||
print('train', nll_train, 'test', nll_test) | |||
f.write(str(nll_train) + ',' + str(nll_test) + '\n') | |||
print('NLL evaluation done') |
@@ -16,6 +16,7 @@ import re | |||
import data | |||
def citeseer_ego(): | |||
_, _, G = data.Graph_load(dataset='citeseer') | |||
G = max(nx.connected_component_subgraphs(G), key=len) | |||
@@ -27,12 +28,13 @@ def citeseer_ego(): | |||
graphs.append(G_ego) | |||
return graphs | |||
def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3): | |||
def caveman_special(c=2, k=20, p_path=0.1, p_edge=0.3): | |||
p = p_path | |||
path_count = max(int(np.ceil(p * k)),1) | |||
path_count = max(int(np.ceil(p * k)), 1) | |||
G = nx.caveman_graph(c, k) | |||
# remove each intra-clique edge with probability 1 - p_edge | |||
p = 1-p_edge | |||
p = 1 - p_edge | |||
for (u, v) in list(G.edges()): | |||
if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)): | |||
G.remove_edge(u, v) | |||
@@ -44,6 +46,7 @@ def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3): | |||
G = max(nx.connected_component_subgraphs(G), key=len) | |||
return G | |||
def n_community(c_sizes, p_inter=0.01): | |||
graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))] | |||
G = nx.disjoint_union_all(graphs) | |||
@@ -51,7 +54,7 @@ def n_community(c_sizes, p_inter=0.01): | |||
for i in range(len(communities)): | |||
subG1 = communities[i] | |||
nodes1 = list(subG1.nodes()) | |||
for j in range(i+1, len(communities)): | |||
for j in range(i + 1, len(communities)): | |||
subG2 = communities[j] | |||
nodes2 = list(subG2.nodes()) | |||
has_inter_edge = False | |||
@@ -62,9 +65,10 @@ def n_community(c_sizes, p_inter=0.01): | |||
has_inter_edge = True | |||
if not has_inter_edge: | |||
G.add_edge(nodes1[0], nodes2[0]) | |||
#print('connected comp: ', len(list(nx.connected_component_subgraphs(G)))) | |||
# print('connected comp: ', len(list(nx.connected_component_subgraphs(G)))) | |||
return G | |||
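# Minimal usage sketch for n_community, based only on the lines shown here (the
# inter-community sampling elided by the hunk is assumed to use p_inter):
#   G = n_community([20, 20, 20], p_inter=0.01)
# builds three G(n=20, p=0.7) communities, disjoint-unions them into one 60-node
# graph, and guarantees at least one edge between every pair of communities.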
def perturb(graph_list, p_del, p_add=None): | |||
''' Perturb the list of graphs by adding/removing edges. | |||
Args: | |||
@@ -87,7 +91,7 @@ def perturb(graph_list, p_del, p_add=None): | |||
if p_add is None: | |||
num_nodes = G.number_of_nodes() | |||
p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 - | |||
G.number_of_edges()) | |||
G.number_of_edges()) | |||
else: | |||
p_add_est = p_add | |||
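# Intuition for the estimate above, assuming (as the docstring and p_del suggest)
# that trials marks the edges just deleted: np.sum(trials) is the number of deleted
# edges and num_nodes*(num_nodes-1)/2 - G.number_of_edges() is the number of node
# pairs currently without an edge, so p_add_est makes the expected number of added
# edges match the number deleted and keeps the edge count roughly constant.
# E.g. 10 nodes with 16 remaining edges and 4 deletions: p_add_est = 4 / (45 - 16) ~= 0.138.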
@@ -97,7 +101,7 @@ def perturb(graph_list, p_del, p_add=None): | |||
u = nodes[i] | |||
trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes()) | |||
j = 0 | |||
for j in range(i+1, len(nodes)): | |||
for j in range(i + 1, len(nodes)): | |||
v = nodes[j] | |||
if trials[j] == 1: | |||
tmp += 1 | |||
@@ -108,7 +112,6 @@ def perturb(graph_list, p_del, p_add=None): | |||
return perturbed_graph_list | |||
def perturb_new(graph_list, p): | |||
''' Perturb the list of graphs by adding/removing edges. | |||
Args: | |||
@@ -123,7 +126,7 @@ def perturb_new(graph_list, p): | |||
G = G_original.copy() | |||
edge_remove_count = 0 | |||
for (u, v) in list(G.edges()): | |||
if np.random.rand()<p: | |||
if np.random.rand() < p: | |||
G.remove_edge(u, v) | |||
edge_remove_count += 1 | |||
# randomly add the edges back | |||
@@ -131,16 +134,13 @@ def perturb_new(graph_list, p): | |||
while True: | |||
u = np.random.randint(0, G.number_of_nodes()) | |||
v = np.random.randint(0, G.number_of_nodes()) | |||
if (not G.has_edge(u,v)) and (u!=v): | |||
if (not G.has_edge(u, v)) and (u != v): | |||
break | |||
G.add_edge(u, v) | |||
perturbed_graph_list.append(G) | |||
return perturbed_graph_list | |||
def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None): | |||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |||
from matplotlib.figure import Figure | |||
@@ -162,7 +162,7 @@ def save_prediction_histogram(y_pred_data, fname_pred, max_num_node, bin_n=20): | |||
# draw a single graph G | |||
def draw_graph(G, prefix = 'test'): | |||
def draw_graph(G, prefix='test'): | |||
parts = community.best_partition(G) | |||
values = [parts.get(node) for node in G.nodes()] | |||
colors = [] | |||
@@ -187,8 +187,7 @@ def draw_graph(G, prefix = 'test'): | |||
plt.axis("off") | |||
pos = nx.spring_layout(G) | |||
nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors,pos=pos) | |||
nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors, pos=pos) | |||
# plt.switch_backend('agg') | |||
# options = { | |||
@@ -199,14 +198,14 @@ def draw_graph(G, prefix = 'test'): | |||
# plt.figure() | |||
# plt.subplot() | |||
# nx.draw_networkx(G, **options) | |||
plt.savefig('figures/graph_view_'+prefix+'.png', dpi=200) | |||
plt.savefig('figures/graph_view_' + prefix + '.png', dpi=200) | |||
plt.close() | |||
plt.switch_backend('agg') | |||
G_deg = nx.degree_histogram(G) | |||
G_deg = np.array(G_deg) | |||
# plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2) | |||
plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2) | |||
plt.loglog(np.arange(len(G_deg))[G_deg > 0], G_deg[G_deg > 0], 'r', linewidth=2) | |||
plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||
plt.close() | |||
@@ -225,15 +224,16 @@ def draw_graph(G, prefix = 'test'): | |||
# draw a list of graphs [G] | |||
def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', is_single=False,k=1,node_size=55,alpha=1,width=1.3): | |||
def draw_graph_list(G_list, row, col, fname='figures/test', layout='spring', is_single=False, k=1, node_size=55, | |||
alpha=1, width=1.3): | |||
# # draw graph view | |||
# from pylab import rcParams | |||
# rcParams['figure.figsize'] = 12,3 | |||
plt.switch_backend('agg') | |||
for i,G in enumerate(G_list): | |||
plt.subplot(row,col,i+1) | |||
for i, G in enumerate(G_list): | |||
plt.subplot(row, col, i + 1) | |||
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, | |||
wspace=0, hspace=0) | |||
wspace=0, hspace=0) | |||
# if i%2==0: | |||
# plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4) | |||
# else: | |||
@@ -260,28 +260,29 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
# if values[i] == 6: | |||
# colors.append('black') | |||
plt.axis("off") | |||
if layout=='spring': | |||
pos = nx.spring_layout(G,k=k/np.sqrt(G.number_of_nodes()),iterations=100) | |||
if layout == 'spring': | |||
pos = nx.spring_layout(G, k=k / np.sqrt(G.number_of_nodes()), iterations=100) | |||
# pos = nx.spring_layout(G) | |||
elif layout=='spectral': | |||
elif layout == 'spectral': | |||
pos = nx.spectral_layout(G) | |||
# # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos) | |||
# nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2) | |||
if is_single: | |||
# node_size default 60, edge_width default 1.5 | |||
nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, font_size=0) | |||
nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, | |||
font_size=0) | |||
nx.draw_networkx_edges(G, pos, alpha=alpha, width=width) | |||
else: | |||
nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699',alpha=1, linewidths=0.2, font_size = 1.5) | |||
nx.draw_networkx_edges(G, pos, alpha=0.3,width=0.2) | |||
nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699', alpha=1, linewidths=0.2, font_size=1.5) | |||
nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.2) | |||
# plt.axis('off') | |||
# plt.title('Complete Graph of Odd-degree Nodes') | |||
# plt.show() | |||
plt.tight_layout() | |||
plt.savefig(fname+'.png', dpi=600) | |||
plt.savefig(fname + '.png', dpi=600) | |||
plt.close() | |||
# # draw degree distribution | |||
@@ -376,8 +377,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
# plt.savefig(fname+'_community.png', dpi=600) | |||
# plt.close() | |||
# plt.switch_backend('agg') | |||
# G_deg = nx.degree_histogram(G) | |||
# G_deg = np.array(G_deg) | |||
@@ -395,7 +394,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i | |||
# plt.close() | |||
# directly get graph statistics from adj, obsoleted | |||
def decode_graph(adj, prefix): | |||
adj = np.asmatrix(adj) | |||
@@ -433,6 +431,7 @@ def get_graph(adj): | |||
G = nx.from_numpy_matrix(adj) | |||
return G | |||
# save a list of graphs | |||
def save_graph_list(G_list, fname): | |||
with open(fname, "wb") as f: | |||
@@ -441,28 +440,30 @@ def save_graph_list(G_list, fname): | |||
# pick the first connected component | |||
def pick_connected_component(G): | |||
node_list = nx.node_connected_component(G,0) | |||
node_list = nx.node_connected_component(G, 0) | |||
return G.subgraph(node_list) | |||
def pick_connected_component_new(G): | |||
adj_list = G.adjacency_list() | |||
for id,adj in enumerate(adj_list): | |||
for id, adj in enumerate(adj_list): | |||
id_min = min(adj) | |||
if id<id_min and id>=1: | |||
# if id<id_min and id>=4: | |||
if id < id_min and id >= 1: | |||
# if id<id_min and id>=4: | |||
break | |||
node_list = list(range(id)) # only include node prior than node "id" | |||
node_list = list(range(id)) # only include nodes prior to node "id" | |||
G = G.subgraph(node_list) | |||
G = max(nx.connected_component_subgraphs(G), key=len) | |||
return G | |||
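# Worked example of the truncation rule above, assuming the nodes follow the
# BFS-style ordering used elsewhere in this codebase: with adjacency lists
#   [[1, 2], [0, 2], [0, 1], [4], [3]]
# node 3 is the first node (id >= 1) whose smallest neighbour index, 4, exceeds its
# own id, so node_list = [0, 1, 2] and the disconnected tail {3, 4} is dropped
# before the largest connected component is taken.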
# load a list of graphs | |||
def load_graph_list(fname,is_real=True): | |||
def load_graph_list(fname, is_real=True): | |||
with open(fname, "rb") as f: | |||
graph_list = pickle.load(f) | |||
for i in range(len(graph_list)): | |||
edges_with_selfloops = graph_list[i].selfloop_edges() | |||
if len(edges_with_selfloops)>0: | |||
if len(edges_with_selfloops) > 0: | |||
graph_list[i].remove_edges_from(edges_with_selfloops) | |||
if is_real: | |||
graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len) | |||
@@ -482,6 +483,7 @@ def export_graphs_to_txt(g_list, output_filename_prefix): | |||
f.write(str(idx_u) + '\t' + str(idx_v) + '\n') | |||
i += 1 | |||
def snap_txt_output_to_nx(in_fname): | |||
G = nx.Graph() | |||
with open(in_fname, 'r') as f: | |||
@@ -496,23 +498,23 @@ def snap_txt_output_to_nx(in_fname): | |||
G.add_edge(int(u), int(v)) | |||
return G | |||
def test_perturbed(): | |||
graphs = [] | |||
for i in range(100,101): | |||
for j in range(4,5): | |||
for i in range(100, 101): | |||
for j in range(4, 5): | |||
for k in range(500): | |||
graphs.append(nx.barabasi_albert_graph(i,j)) | |||
graphs.append(nx.barabasi_albert_graph(i, j)) | |||
g_perturbed = perturb(graphs, 0.9) | |||
print([g.number_of_edges() for g in graphs]) | |||
print([g.number_of_edges() for g in g_perturbed]) | |||
if __name__ == '__main__': | |||
#test_perturbed() | |||
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_train_0.dat') | |||
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat') | |||
graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat') | |||
for i in range(0, 160, 16): | |||
draw_graph_list(graphs[i:i+16], 4, 4, fname='figures/community4_' + str(i)) | |||
# test_perturbed() | |||
graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_grid_4_128_train_0.dat') | |||
# graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat') | |||
# graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat') | |||
for i in range(0, 160, 16): | |||
draw_graph_list(graphs[i:i + 16], 4, 4, fname='figures/community4_' + str(i)) |