
refactor code to make it more readable

Branch: master
Ali Amiri committed 5 years ago
commit 85ae9e5daa
5 changed files with 504 additions and 555 deletions:

  1. args.py      +2    −2
  2. data.py      +163  −214
  3. evaluate.py  +133  −135
  4. train.py     +154  −154
  5. utils.py     +52   −50

args.py  (+2, −2)

@@ -5,7 +5,7 @@ class Args():
### if clean tensorboard
self.clean_tensorboard = False
### Which CUDA GPU device is used for training
self.cuda = 1
self.cuda = 0

### Which GraphRNN model variant is used.
# The simple version of Graph RNN
@@ -88,7 +88,7 @@ class Args():


self.load = False # if load model, default lr is very low
self.load_epoch = 3000
self.load_epoch = 100
self.save = True
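
These two defaults only matter where they are read back. Below is a minimal sketch of that assumed consumption; the call sites are assumptions, and only the intent stated in the comments above (GPU device selection, loading a saved model at load_epoch) comes from this file:

# Hedged sketch: assumed usage of the changed defaults, not necessarily the repo's main script.
import os
from args import Args

args = Args()
# Pin training to the selected GPU index (device 0 after this change).
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)

if args.load:
    # Resume from the checkpoint saved at epoch args.load_epoch (now 100).
    resume_epoch = args.load_epoch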



data.py  (+163, −214)
File diff suppressed because it is too large


evaluate.py  (+133, −135)

@@ -1,14 +1,15 @@
import argparse
import numpy as np
import logging
import os
import re
from random import shuffle
from time import strftime, gmtime

import eval.stats
import utils
# import main.Args
from args import Args
from baselines.baseline_simple import *


class Args_evaluate():
def __init__(self):
# loop over the settings
@@ -19,19 +20,22 @@ class Args_evaluate():

# list of dataset to evaluate
# use a list of 1 element to evaluate a single dataset
self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD']
# self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD']
self.dataset_name_all = ['caveman_small']
# self.dataset_name_all = ['citeseer_small','caveman_small']
# self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10']
# self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small']

self.epoch_start=100
self.epoch_end=3001
self.epoch_step=100
self.epoch_start = 10
self.epoch_end = 101
self.epoch_step = 10


def find_nearest_idx(array,value):
idx = (np.abs(array-value)).argmin()
def find_nearest_idx(array, value):
idx = (np.abs(array - value)).argmin()
return idx


def extract_result_id_and_epoch(name, prefix, suffix):
'''
Args:
@@ -68,7 +72,7 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every):
if result_id not in pred_graphs_dict:
pred_graphs_dict[result_id] = {}
pred_graphs_dict[result_id][epochs] = fname
for result_id in real_graphs_dict.keys():
for epochs in sorted(real_graphs_dict[result_id]):
real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs])
@@ -77,16 +81,16 @@ def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every):
shuffle(pred_g_list)
perturbed_g_list = perturb(real_g_list, 0.05)

#dist = eval.stats.degree_stats(real_g_list, pred_g_list)
# dist = eval.stats.degree_stats(real_g_list, pred_g_list)
dist = eval.stats.clustering_stats(real_g_list, pred_g_list)
print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist)
#dist = eval.stats.degree_stats(real_g_list, perturbed_g_list)
# dist = eval.stats.degree_stats(real_g_list, perturbed_g_list)
dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list)
print('dist between real and perturbed: ', dist)

mid = len(real_g_list) // 2
#dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:])
# dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:])
dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:])
print('dist among real: ', dist)

@@ -97,7 +101,6 @@ def compute_basic_stats(real_g_list, target_g_list):
return dist_degree, dist_clustering



def clean_graphs(graph_real, graph_pred):
''' Selecting graphs generated that have the similar sizes.
It is usually necessary for GraphRNN-S version, but not the full GraphRNN model.
@@ -120,6 +123,7 @@ def clean_graphs(graph_real, graph_pred):

return graph_real, pred_graph_new


def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
''' Read ground truth graphs.
'''
@@ -127,20 +131,21 @@ def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
hidden = 128
else:
hidden = 64
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R':
if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
hidden) + '_test_' + str(0) + '.dat'
hidden) + '_test_' + str(0) + '.dat'
else:
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
hidden) + '_test_' + str(0) + '.dat'
hidden) + '_test_' + str(0) + '.dat'
try:
graph_test = utils.load_graph_list(fname_test,is_real=True)
graph_test = utils.load_graph_list(fname_test, is_real=True)
except:
print('Not found: ' + fname_test)
logging.warning('Not found: ' + fname_test)
return None
return graph_test


def eval_single_list(graphs, dir_input, dataset_name):
''' Evaluate a list of graphs by comparing with graphs in directory dir_input.
Args:
@@ -149,7 +154,7 @@ def eval_single_list(graphs, dir_input, dataset_name):
'''
graph_test = load_ground_truth(dir_input, dataset_name)
graph_test_len = len(graph_test)
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set
mmd_degree = eval.stats.degree_stats(graph_test, graphs)
mmd_clustering = eval.stats.clustering_stats(graph_test, graphs)
try:
@@ -160,9 +165,12 @@ def eval_single_list(graphs, dir_input, dataset_name):
print('clustering: ', mmd_clustering)
print('orbits: ', mmd_4orbits)

def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,epoch_end=3001,epoch_step=100):

def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,
epoch_end=3001, epoch_step=100):
with open(fname_output, 'w+') as f:
f.write('sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')
f.write(
'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')

# TODO: Maybe refactor into a separate file/function that specifies THE naming convention
# across main and evaluate
@@ -171,7 +179,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
else:
hidden = 64
# read real graph
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R':
if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
hidden) + '_test_' + str(0) + '.dat'
elif 'Baseline' in model_name:
@@ -180,37 +188,37 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
hidden) + '_test_' + str(0) + '.dat'
try:
graph_test = utils.load_graph_list(fname_test,is_real=True)
graph_test = utils.load_graph_list(fname_test, is_real=True)
except:
print('Not found: ' + fname_test)
logging.warning('Not found: ' + fname_test)
return None

graph_test_len = len(graph_test)
graph_train = graph_test[0:int(0.8 * graph_test_len)] # train
graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set
graph_train = graph_test[0:int(0.8 * graph_test_len)] # train
graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set

graph_test_aver = 0
for graph in graph_test:
graph_test_aver+=graph.number_of_nodes()
graph_test_aver += graph.number_of_nodes()
graph_test_aver /= len(graph_test)
print('test average len',graph_test_aver)

print('test average len', graph_test_aver)

# get performance for proposed approaches
if 'GraphRNN' in model_name:
# read test graph
for epoch in range(epoch_start,epoch_end,epoch_step):
for sample_time in range(1,4):
for epoch in range(epoch_start, epoch_end, epoch_step):
for sample_time in range(1, 4):
# get filename
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
# load graphs
try:
graph_pred = utils.load_graph_list(fname_pred,is_real=False) # default False
graph_pred = utils.load_graph_list(fname_pred, is_real=False) # default False
except:
print('Not found: '+ fname_pred)
logging.warning('Not found: '+ fname_pred)
print('Not found: ' + fname_pred)
logging.warning('Not found: ' + fname_pred)
continue
# clean graphs
if is_clean:
@@ -243,15 +251,15 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
except:
mmd_4orbits_validate = -1
# write results
f.write(str(sample_time)+','+
str(epoch)+','+
str(mmd_degree_validate)+','+
str(mmd_clustering_validate)+','+
str(mmd_4orbits_validate)+','+
str(mmd_degree)+','+
str(mmd_clustering)+','+
str(mmd_4orbits)+'\n')
print('degree',mmd_degree,'clustering',mmd_clustering,'orbits',mmd_4orbits)
f.write(str(sample_time) + ',' +
str(epoch) + ',' +
str(mmd_degree_validate) + ',' +
str(mmd_clustering_validate) + ',' +
str(mmd_4orbits_validate) + ',' +
str(mmd_degree) + ',' +
str(mmd_clustering) + ',' +
str(mmd_4orbits) + '\n')
print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

# get internal MMD (MMD between ground truth validation and test sets)
if model_name == 'Internal':
@@ -265,7 +273,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
+ ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')


# get MMD between ground truth and its perturbed graphs
if model_name == 'Noise':
graph_validate_perturbed = perturb(graph_validate, 0.05)
@@ -281,7 +288,7 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is

# get E-R MMD
if model_name == 'E-R':
graph_pred = Graph_generator_baseline(graph_train,generator='Gnp')
graph_pred = Graph_generator_baseline(graph_train, generator='Gnp')
# clean graphs
if is_clean:
graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
@@ -296,7 +303,6 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')


# get B-A MMD
if model_name == 'B-A':
graph_pred = Graph_generator_baseline(graph_train, generator='BA')
@@ -364,33 +370,29 @@ def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n')
print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)



return True

def evaluation(args_evaluate,dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite = True):

def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True):
''' Evaluate the performance of a set of models on a set of datasets.
'''
for model_name in model_name_all:
for dataset_name in dataset_name_all:
# check output exist
fname_output = dir_output+model_name+'_'+dataset_name+'.csv'
print('processing: '+dir_output + model_name + '_' + dataset_name + '.csv')
logging.info('processing: '+dir_output + model_name + '_' + dataset_name + '.csv')
if overwrite==False and os.path.isfile(fname_output):
print(dir_output+model_name+'_'+dataset_name+'.csv exists!')
logging.info(dir_output+model_name+'_'+dataset_name+'.csv exists!')
fname_output = dir_output + model_name + '_' + dataset_name + '.csv'
print('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv')
logging.info('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv')
if overwrite == False and os.path.isfile(fname_output):
print(dir_output + model_name + '_' + dataset_name + '.csv exists!')
logging.info(dir_output + model_name + '_' + dataset_name + '.csv exists!')
continue
evaluation_epoch(dir_input,fname_output,model_name,dataset_name,args,is_clean=True, epoch_start=args_evaluate.epoch_start,epoch_end=args_evaluate.epoch_end,epoch_step=args_evaluate.epoch_step)





evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end,
epoch_step=args_evaluate.epoch_step)


def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
eval_every, epoch_range=None, out_file_prefix=None):
eval_every, epoch_range=None, out_file_prefix=None):
''' Evaluate list of predicted graphs compared to ground truth, stored in files.
Args:
baselines: dict mapping name of the baseline to list of generated graphs.
@@ -398,12 +400,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,

if out_file_prefix is not None:
out_files = {
'train': open(out_file_prefix + '_train.txt', 'w+'),
'compare': open(out_file_prefix + '_compare.txt', 'w+')
'train': open(out_file_prefix + '_train.txt', 'w+'),
'compare': open(out_file_prefix + '_compare.txt', 'w+')
}

out_files['train'].write('degree,clustering,orbits4\n')
line = 'metric,real,ours,perturbed'
for bl in baselines:
line += ',' + bl
@@ -411,34 +413,33 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
out_files['compare'].write(line)

results = {
'deg': {
'real': 0,
'ours': 100, # take min over all training epochs
'perturbed': 0,
'kron': 0},
'clustering': {
'real': 0,
'ours': 100,
'perturbed': 0,
'kron': 0},
'orbits4': {
'real': 0,
'ours': 100,
'perturbed': 0,
'kron': 0}
'deg': {
'real': 0,
'ours': 100, # take min over all training epochs
'perturbed': 0,
'kron': 0},
'clustering': {
'real': 0,
'ours': 100,
'perturbed': 0,
'kron': 0},
'orbits4': {
'real': 0,
'ours': 100,
'perturbed': 0,
'kron': 0}
}


num_evals = len(pred_graphs_filename)
if epoch_range is None:
epoch_range = [i * eval_every for i in range(num_evals)]
epoch_range = [i * eval_every for i in range(num_evals)]
for i in range(num_evals):
real_g_list = utils.load_graph_list(real_graph_filename)
#pred_g_list = utils.load_graph_list(pred_graphs_filename[i])
# pred_g_list = utils.load_graph_list(pred_graphs_filename[i])

# contains all predicted G
pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i])
if len(real_g_list)>200:
if len(real_g_list) > 200:
real_g_list = real_g_list[0:200]

shuffle(real_g_list)
@@ -448,10 +449,9 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))])
pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))])
# get perturb real
#perturbed_g_list_001 = perturb(real_g_list, 0.01)
# perturbed_g_list_001 = perturb(real_g_list, 0.01)
perturbed_g_list_005 = perturb(real_g_list, 0.05)
#perturbed_g_list_010 = perturb(real_g_list, 0.10)

# perturbed_g_list_010 = perturb(real_g_list, 0.10)

# select pred samples
# The number of nodes are sampled from the similar distribution as the training set
@@ -482,22 +482,22 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
# ========================================
mid = len(real_g_list) // 2
dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:])
#dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
# dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:])
print('degree dist among real: ', dist_degree)
print('clustering dist among real: ', dist_clustering)
#print('4 cycle dist among real: ', dist_4cycle)
# print('4 cycle dist among real: ', dist_4cycle)
print('orbits dist among real: ', dist_4orbits)
results['deg']['real'] += dist_degree
results['clustering']['real'] += dist_clustering
results['orbits4']['real'] += dist_4orbits

dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list)
#dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
# dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list)
print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree)
print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering)
#print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
# print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits)
results['deg']['ours'] = min(dist_degree, results['deg']['ours'])
results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours'])
@@ -509,11 +509,11 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
out_files['train'].write(str(dist_4orbits) + ',')

dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005)
#dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
# dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005)
print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree)
print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering)
#print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
# print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits)
results['deg']['perturbed'] += dist_degree
results['clustering']['perturbed'] += dist_clustering
@@ -527,8 +527,8 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
results['deg'][baseline] = dist_degree
results['clustering'][baseline] = dist_clustering
results['orbits4'][baseline] = dist_4orbits
print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
', orbits4=', dist_4orbits)
print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
', orbits4=', dist_4orbits)

out_files['train'].write('\n')

@@ -538,10 +538,10 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,

# Write results
for metric, methods in results.items():
line = metric+','+ \
str(methods['real'])+','+ \
str(methods['ours'])+','+ \
str(methods['perturbed'])
line = metric + ',' + \
str(methods['real']) + ',' + \
str(methods['ours']) + ',' + \
str(methods['perturbed'])
for baseline in baselines:
line += ',' + str(methods[baseline])
line += '\n'
@@ -553,12 +553,12 @@ def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,


def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None,
sample_time = 2, baselines={}):
sample_time=2, baselines={}):
if args is None:
real_graphs_filename = [datadir + f for f in os.listdir(datadir)
if re.match(prefix + '.*real.*\.dat', f)]
if re.match(prefix + '.*real.*\.dat', f)]
pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
if re.match(prefix + '.*pred.*\.dat', f)]
if re.match(prefix + '.*pred.*\.dat', f)]
eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200)

else:
@@ -567,28 +567,30 @@ def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_p
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
# pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
real_graph_filename = datadir+args.graph_save_path + args.fname_test + '0.dat'
real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat'
# for proposed model
end_epoch = 3001
epoch_range = range(eval_every, end_epoch, eval_every)
pred_graphs_filename = [datadir+args.graph_save_path + args.fname_pred+str(epoch)+'_'+str(sample_time)+'.dat'
for epoch in epoch_range]
pred_graphs_filename = [
datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
for epoch in epoch_range]
# for baseline model
#pred_graphs_filename = [datadir+args.fname_baseline+'.dat']
# pred_graphs_filename = [datadir+args.fname_baseline+'.dat']

#real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
#pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]

eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
epoch_range=epoch_range,
epoch_range=epoch_range,
eval_every=eval_every,
out_file_prefix=out_file_prefix)


def process_kron(kron_dir):
txt_files = []
for f in os.listdir(kron_dir):
@@ -602,7 +604,7 @@ def process_kron(kron_dir):
G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename)))

return G_list

if __name__ == '__main__':
args = Args()
@@ -612,18 +614,18 @@ if __name__ == '__main__':
feature_parser = parser.add_mutually_exclusive_group(required=False)
feature_parser.add_argument('--export-real', dest='export', action='store_true')
feature_parser.add_argument('--no-export-real', dest='export', action='store_false')
feature_parser.add_argument('--kron-dir', dest='kron_dir',
help='Directory where graphs generated by kronecker method is stored.')
feature_parser.add_argument('--kron-dir', dest='kron_dir',
help='Directory where graphs generated by kronecker method is stored.')

parser.add_argument('--testfile', dest='test_file',
help='The file that stores list of graphs to be evaluated. Only used when 1 list of '
'graphs is to be evaluated.')
help='The file that stores list of graphs to be evaluated. Only used when 1 list of '
'graphs is to be evaluated.')
parser.add_argument('--dir-prefix', dest='dir_prefix',
help='The file that stores list of graphs to be evaluated. Can be used when evaluating multiple'
'models on multiple datasets.')
help='The file that stores list of graphs to be evaluated. Can be used when evaluating multiple'
'models on multiple datasets.')
parser.add_argument('--graph-type', dest='graph_type',
help='Type of graphs / dataset.')
help='Type of graphs / dataset.')
parser.set_defaults(export=False, kron_dir='', test_file='',
dir_prefix='',
graph_type=args.graph_type)
@@ -633,7 +635,6 @@ if __name__ == '__main__':
# dir_prefix = "/dfs/scratch0/jiaxuany0/"
dir_prefix = args.dir_input


time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
if not os.path.isdir('logs/'):
os.makedirs('logs/')
@@ -652,16 +653,16 @@ if __name__ == '__main__':

if prog_args.graph_type == 'grid':
graphs = []
for i in range(10,20):
for j in range(10,20):
graphs.append(nx.grid_2d_graph(i,j))
for i in range(10, 20):
for j in range(10, 20):
graphs.append(nx.grid_2d_graph(i, j))
utils.export_graphs_to_txt(graphs, output_prefix)
elif prog_args.graph_type == 'caveman':
graphs = []
for i in range(2, 3):
for j in range(30, 81):
for k in range(10):
graphs.append(caveman_special(i,j, p_edge=0.3))
graphs.append(caveman_special(i, j, p_edge=0.3))
utils.export_graphs_to_txt(graphs, output_prefix)
elif prog_args.graph_type == 'citeseer':
graphs = utils.citeseer_ego()
@@ -679,14 +680,11 @@ if __name__ == '__main__':
elif not prog_args.test_file == '':
# evaluate single .dat file containing list of test graphs (networkx format)
graphs = utils.load_graph_list(prog_args.test_file)
eval_single_list(graphs, dir_input=dir_prefix+'graphs/', dataset_name='grid')
eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid')
## if you don't try kronecker, only the following part is needed
else:
if not os.path.isdir(dir_prefix+'eval_results'):
os.makedirs(dir_prefix+'eval_results')
evaluation(args_evaluate,dir_input=dir_prefix+"graphs/", dir_output=dir_prefix+"eval_results/",
model_name_all=args_evaluate.model_name_all,dataset_name_all=args_evaluate.dataset_name_all,args=args,overwrite=True)




if not os.path.isdir(dir_prefix + 'eval_results'):
os.makedirs(dir_prefix + 'eval_results')
evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/",
model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all,
args=args, overwrite=True)
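
For reference, the prediction files that evaluation_epoch walks over follow the naming pattern visible above. The sketch below spells that pattern out under the new Args_evaluate defaults (epochs 10..100 in steps of 10, three sample_time values); the directory, model, dataset, num_layers and hidden values are placeholders, not taken from this commit:

# Placeholder values; only the filename pattern and the epoch/sample ranges come from the diff.
dir_input = 'graphs/'
model_name, dataset_name = 'GraphRNN_RNN', 'caveman_small'
num_layers, hidden = 4, 128

for epoch in range(10, 101, 10):        # epoch_start, epoch_end, epoch_step
    for sample_time in range(1, 4):
        fname_pred = (dir_input + model_name + '_' + dataset_name + '_' +
                      str(num_layers) + '_' + str(hidden) + '_pred_' +
                      str(epoch) + '_' + str(sample_time) + '.dat')
        # each such file is loaded with utils.load_graph_list and scored with MMD stats
        print(fname_pred)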

train.py  (+154, −154)

@@ -21,11 +21,9 @@ from tensorboard_logger import configure, log_value
import scipy.misc
import time as tm

from utils import *
from model import *
from tensorboard_logger import log_value
from data import *
from args import Args
import create_graphs


def train_vae_epoch(epoch, args, rnn, output, data_loader,
@@ -47,16 +45,16 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader,
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

# sort input
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted,0,sort_index)
y = torch.index_select(y_unsorted,0,sort_index)
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)
x = Variable(x).cuda()
y = Variable(y).cuda()

# if using ground truth to train
h = rnn(x, pack=True, input_len=y_len)
y_pred,z_mu,z_lsgms = output(h)
y_pred, z_mu, z_lsgms = output(h)
y_pred = F.sigmoid(y_pred)
# clean
y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
@@ -68,7 +66,7 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader,
# use cross entropy loss
loss_bce = binary_cross_entropy_weight(y_pred, y)
loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
loss_kl /= y.size(0)*y.size(1)*sum(y_len) # normalize
loss_kl /= y.size(0) * y.size(1) * sum(y_len) # normalize
loss = loss_bce + loss_kl
loss.backward()
# update deterministic and lstm
@@ -77,7 +75,6 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader,
scheduler_output.step()
scheduler_rnn.step()


z_mu_mean = torch.mean(z_mu.data)
z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
z_mu_min = torch.min(z_mu.data)
@@ -85,35 +82,39 @@ def train_vae_epoch(epoch, args, rnn, output, data_loader,
z_mu_max = torch.max(z_mu.data)
z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)


if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs,loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics
print(
'Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers,
args.hidden_size_rnn))
print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean,
'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

# logging
log_value('bce_loss_'+args.fname, loss_bce.data[0], epoch*args.batch_ratio+batch_idx)
log_value('kl_loss_' +args.fname, loss_kl.data[0], epoch*args.batch_ratio + batch_idx)
log_value('z_mu_mean_'+args.fname, z_mu_mean, epoch*args.batch_ratio + batch_idx)
log_value('z_mu_min_'+args.fname, z_mu_min, epoch*args.batch_ratio + batch_idx)
log_value('z_mu_max_'+args.fname, z_mu_max, epoch*args.batch_ratio + batch_idx)
log_value('z_sgm_mean_'+args.fname, z_sgm_mean, epoch*args.batch_ratio + batch_idx)
log_value('z_sgm_min_'+args.fname, z_sgm_min, epoch*args.batch_ratio + batch_idx)
log_value('z_sgm_max_'+args.fname, z_sgm_max, epoch*args.batch_ratio + batch_idx)
log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx)
log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx)
log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

loss_sum += loss.data[0]
return loss_sum/(batch_idx+1)
return loss_sum / (batch_idx + 1)

def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time = 1):

def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
rnn.hidden = rnn.init_hidden(test_batch_size)
rnn.eval()
output.eval()

# generate graphs
max_num_node = int(args.max_num_node)
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
h = rnn(x_step)
y_pred_step, _, _ = output(h)
@@ -128,7 +129,7 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=
G_pred_list = []
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)

# save prediction histograms, plot histogram over each time step
@@ -137,11 +138,10 @@ def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=
# fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
# max_num_node=max_num_node)


return G_pred_list


def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1):
def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
rnn.eval()
output.eval()
G_pred_list = []
@@ -153,15 +153,18 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram
rnn.hidden = rnn.init_hidden(test_batch_size)
# generate graphs
max_num_node = int(args.max_num_node)
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
print('finish node',i)
print('finish node', i)
h = rnn(x_step)
y_pred_step, _, _ = output(h)
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time)
x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
sample_time=sample_time)

y_pred_long[:, i:i + 1, :] = x_step
rnn.hidden = Variable(rnn.hidden.data).cuda()
@@ -171,12 +174,11 @@ def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram
# save graphs as pickle
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)
return G_pred_list



def train_mlp_epoch(epoch, args, rnn, output, data_loader,
optimizer_rnn, optimizer_output,
scheduler_rnn, scheduler_output):
@@ -196,10 +198,10 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader,
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

# sort input
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted,0,sort_index)
y = torch.index_select(y_unsorted,0,sort_index)
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)
x = Variable(x).cuda()
y = Variable(y).cuda()

@@ -218,28 +220,28 @@ def train_mlp_epoch(epoch, args, rnn, output, data_loader,
scheduler_output.step()
scheduler_rnn.step()


if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

# logging
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx)
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

loss_sum += loss.data[0]
return loss_sum/(batch_idx+1)
return loss_sum / (batch_idx + 1)


def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False,sample_time=1):
def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
rnn.hidden = rnn.init_hidden(test_batch_size)
rnn.eval()
output.eval()

# generate graphs
max_num_node = int(args.max_num_node)
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
h = rnn(x_step)
y_pred_step = output(h)
@@ -254,10 +256,9 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=
G_pred_list = []
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)


# # save prediction histograms, plot histogram over each time step
# if save_histogram:
# save_prediction_histogram(y_pred_data.cpu().numpy(),
@@ -266,8 +267,7 @@ def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=
return G_pred_list



def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1):
def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
rnn.eval()
output.eval()
G_pred_list = []
@@ -279,15 +279,18 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram
rnn.hidden = rnn.init_hidden(test_batch_size)
# generate graphs
max_num_node = int(args.max_num_node)
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
print('finish node',i)
print('finish node', i)
h = rnn(x_step)
y_pred_step = output(h)
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time)
x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
sample_time=sample_time)

y_pred_long[:, i:i + 1, :] = x_step
rnn.hidden = Variable(rnn.hidden.data).cuda()
@@ -297,12 +300,12 @@ def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram
# save graphs as pickle
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)
return G_pred_list


def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1):
def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
rnn.eval()
output.eval()
G_pred_list = []
@@ -314,15 +317,18 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi
rnn.hidden = rnn.init_hidden(test_batch_size)
# generate graphs
max_num_node = int(args.max_num_node)
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score
y_pred_long = Variable(
torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
print('finish node',i)
print('finish node', i)
h = rnn(x_step)
y_pred_step = output(h)
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time)
x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
sample_time=sample_time)

y_pred_long[:, i:i + 1, :] = x_step
rnn.hidden = Variable(rnn.hidden.data).cuda()
@@ -332,7 +338,7 @@ def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_hi
# save graphs as pickle
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)
return G_pred_list

@@ -354,10 +360,10 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))

# sort input
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted,0,sort_index)
y = torch.index_select(y_unsorted,0,sort_index)
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)
x = Variable(x).cuda()
y = Variable(y).cuda()

@@ -372,22 +378,18 @@ def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader):
loss = 0
for j in range(y.size(1)):
# print('y_pred',y_pred[0,j,:],'y',y[0,j,:])
end_idx = min(j+1,y.size(2))
loss += binary_cross_entropy_weight(y_pred[:,j,0:end_idx], y[:,j,0:end_idx])*end_idx
end_idx = min(j + 1, y.size(2))
loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx


if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

# logging
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx)
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

loss_sum += loss.data[0]
return loss_sum/(batch_idx+1)



return loss_sum / (batch_idx + 1)


## too complicated, deprecated
@@ -449,28 +451,29 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader,
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

# sort input
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted,0,sort_index)
y = torch.index_select(y_unsorted,0,sort_index)
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)

# input, output for output rnn module
# a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data
y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
# reverse y_reshape, so that their lengths are sorted, add dimension
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)]
idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
idx = torch.LongTensor(idx)
y_reshape = y_reshape.index_select(0, idx)
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1)
y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1)
output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
output_y = y_reshape
# batch size for output module: sum(y_len)
output_y_len = []
output_y_len_bin = np.bincount(np.array(y_len))
for i in range(len(output_y_len_bin)-1,0,-1):
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i
output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2)
for i in range(len(output_y_len_bin) - 1, 0, -1):
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i
output_y_len.extend(
[min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2)
# pack into variable
x = Variable(x).cuda()
y = Variable(y).cuda()
@@ -481,23 +484,23 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader,
# print('y',y.size())
# print('output_y',output_y.size())


# if using ground truth to train
h = rnn(x, pack=True, input_len=y_len)
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector
h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector
# reverse h
idx = [i for i in range(h.size(0) - 1, -1, -1)]
idx = Variable(torch.LongTensor(idx)).cuda()
h = h.index_select(0, idx)
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda()
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
dim=0) # num_layers, batch_size, hidden_size
y_pred = output(output_x, pack=True, input_len=output_y_len)
y_pred = F.sigmoid(y_pred)
# clean
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True)
output_y = pad_packed_sequence(output_y,batch_first=True)[0]
output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
output_y = pad_packed_sequence(output_y, batch_first=True)[0]
# use cross entropy loss
loss = binary_cross_entropy_weight(y_pred, output_y)
loss.backward()
@@ -507,17 +510,15 @@ def train_rnn_epoch(epoch, args, rnn, output, data_loader,
scheduler_output.step()
scheduler_rnn.step()


if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

# logging
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx)
feature_dim = y.size(1)*y.size(2)
loss_sum += loss.data[0]*feature_dim
return loss_sum/(batch_idx+1)

log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)
feature_dim = y.size(1) * y.size(2)
loss_sum += loss.data * feature_dim
return loss_sum / (batch_idx + 1)


def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
@@ -527,20 +528,20 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):

# generate graphs
max_num_node = int(args.max_num_node)
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda()
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction
x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
for i in range(max_num_node):
h = rnn(x_step)
# output.hidden = h.permute(1,0,2)
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
output.hidden = torch.cat((h.permute(1,0,2), hidden_null),
output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
dim=0) # num_layers, batch_size, hidden_size
x_step = Variable(torch.zeros(test_batch_size,1,args.max_prev_node)).cuda()
output_x_step = Variable(torch.ones(test_batch_size,1,1)).cuda()
for j in range(min(args.max_prev_node,i+1)):
x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
for j in range(min(args.max_prev_node, i + 1)):
output_y_pred_step = output(output_x_step)
output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
x_step[:,:,j:j+1] = output_x_step
x_step[:, :, j:j + 1] = output_x_step
output.hidden = Variable(output.hidden.data).cuda()
y_pred_long[:, i:i + 1, :] = x_step
rnn.hidden = Variable(rnn.hidden.data).cuda()
@@ -550,14 +551,12 @@ def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
G_pred_list = []
for i in range(test_batch_size):
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
G_pred_list.append(G_pred)

return G_pred_list




def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
rnn.train()
output.train()
@@ -576,28 +575,29 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))

# sort input
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True)
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted,0,sort_index)
y = torch.index_select(y_unsorted,0,sort_index)
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)

# input, output for output rnn module
# a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data
y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
# reverse y_reshape, so that their lengths are sorted, add dimension
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)]
idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
idx = torch.LongTensor(idx)
y_reshape = y_reshape.index_select(0, idx)
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1)
y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)

output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1)
output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
output_y = y_reshape
# batch size for output module: sum(y_len)
output_y_len = []
output_y_len_bin = np.bincount(np.array(y_len))
for i in range(len(output_y_len_bin)-1,0,-1):
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i
output_y_len.extend([min(i,y.size(2))]*count_temp) # put them in output_y_len; max value should not exceed y.size(2)
for i in range(len(output_y_len_bin) - 1, 0, -1):
count_temp = np.sum(output_y_len_bin[i:]) # count how many y_len is above i
output_y_len.extend(
[min(i, y.size(2))] * count_temp) # put them in output_y_len; max value should not exceed y.size(2)
# pack into variable
x = Variable(x).cuda()
y = Variable(y).cuda()
@@ -608,37 +608,36 @@ def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
# print('y',y.size())
# print('output_y',output_y.size())


# if using ground truth to train
h = rnn(x, pack=True, input_len=y_len)
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector
h = pack_padded_sequence(h, y_len, batch_first=True).data # get packed hidden vector
# reverse h
idx = [i for i in range(h.size(0) - 1, -1, -1)]
idx = Variable(torch.LongTensor(idx)).cuda()
h = h.index_select(0, idx)
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda()
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
dim=0) # num_layers, batch_size, hidden_size
y_pred = output(output_x, pack=True, input_len=output_y_len)
y_pred = F.sigmoid(y_pred)
# clean
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True)
output_y = pad_packed_sequence(output_y,batch_first=True)[0]
output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
output_y = pad_packed_sequence(output_y, batch_first=True)[0]
# use cross entropy loss
loss = binary_cross_entropy_weight(y_pred, output_y)


if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics
if epoch % args.epochs_log == 0 and batch_idx == 0: # only output first batch's statistics
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
epoch, args.epochs, loss.data, args.graph_type, args.num_layers, args.hidden_size_rnn))

# logging
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx)
log_value('loss_' + args.fname, loss.data, epoch * args.batch_ratio + batch_idx)
# print(y_pred.size())
feature_dim = y_pred.size(0)*y_pred.size(1)
loss_sum += loss.data[0]*feature_dim/y.size(0)
return loss_sum/(batch_idx+1)
feature_dim = y_pred.size(0) * y_pred.size(1)
loss_sum += loss.data * feature_dim / y.size(0)
return loss_sum / (batch_idx + 1)


########### train function for LSTM + VAE
@@ -665,7 +664,7 @@ def train(args, dataset_train, rnn, output):

# start main loop
time_all = np.zeros(args.epochs)
while epoch<=args.epochs:
while epoch <= args.epochs:
time_start = tm.time()
# train
if 'GraphRNN_VAE' in args.note:
@@ -683,25 +682,26 @@ def train(args, dataset_train, rnn, output):
time_end = tm.time()
time_all[epoch - 1] = time_end - time_start
# test
if epoch % args.epochs_test == 0 and epoch>=args.epochs_test_start:
for sample_time in range(1,4):
if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
for sample_time in range(1, 4):
G_pred = []
while len(G_pred)<args.test_total_size:
while len(G_pred) < args.test_total_size:
if 'GraphRNN_VAE' in args.note:
G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time)
G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
sample_time=sample_time)
elif 'GraphRNN_MLP' in args.note:
G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time)
G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
sample_time=sample_time)
elif 'GraphRNN_RNN' in args.note:
G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
G_pred.extend(G_pred_step)
# save graphs
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + '.dat'
fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
save_graph_list(G_pred, fname)
if 'GraphRNN_RNN' in args.note:
break
print('test done, graphs saved')


# save model checkpoint
if args.save:
if epoch % args.epochs_save == 0:
@@ -710,7 +710,7 @@ def train(args, dataset_train, rnn, output):
fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat'
torch.save(output.state_dict(), fname)
epoch += 1
np.save(args.timing_save_path+args.fname,time_all)
np.save(args.timing_save_path + args.fname, time_all)


########### for graph completion task
@@ -723,19 +723,19 @@ def train_graph_completion(args, dataset_test, rnn, output):
epoch = args.load_epoch
print('model loaded!, epoch: {}'.format(args.load_epoch))

for sample_time in range(1,4):
for sample_time in range(1, 4):
if 'GraphRNN_MLP' in args.note:
G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time)
G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
if 'GraphRNN_VAE' in args.note:
G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time)
G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
# save graphs
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + 'graph_completion.dat'
fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat'
save_graph_list(G_pred, fname)
print('graph completion done, graphs saved')


########### for NLL evaluation
def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,graph_test_len, max_iter = 1000):
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000):
fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
rnn.load_state_dict(torch.load(fname))
fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
@@ -745,7 +745,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,
print('model loaded!, epoch: {}'.format(args.load_epoch))
fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
with open(fname_output, 'w+') as f:
f.write(str(graph_validate_len)+','+str(graph_test_len)+'\n')
f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n')
f.write('train,test\n')
for iter in range(max_iter):
if 'GraphRNN_MLP' in args.note:
@@ -754,7 +754,7 @@ def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,
if 'GraphRNN_RNN' in args.note:
nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train)
nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test)
print('train',nll_train,'test',nll_test)
f.write(str(nll_train)+','+str(nll_test)+'\n')
print('train', nll_train, 'test', nll_test)
f.write(str(nll_train) + ',' + str(nll_test) + '\n')

print('NLL evaluation done')
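The CSV written by train_nll starts with a line holding graph_validate_len and graph_test_len, followed by a 'train,test' header and one NLL pair per iteration. A minimal sketch of reading it back; the file name is an assumption built from the default nll_save_path, note and graph_type values:

# Sketch only: 'nll/GraphRNN_RNN_grid.csv' is a guess at args.nll_save_path +
# args.note + '_' + args.graph_type + '.csv'; adjust to the actual settings.
import csv

with open('nll/GraphRNN_RNN_grid.csv') as f:
    reader = csv.reader(f)
    graph_validate_len, graph_test_len = next(reader)  # first line: the two lengths
    next(reader)                                       # 'train,test' header
    rows = [(float(tr), float(te)) for tr, te in reader]
print('mean train NLL:', sum(r[0] for r in rows) / len(rows))
print('mean test NLL:', sum(r[1] for r in rows) / len(rows))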

+ 52
- 50
utils.py View File

@@ -16,6 +16,7 @@ import re

import data


def citeseer_ego():
_, _, G = data.Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
@@ -27,12 +28,13 @@ def citeseer_ego():
graphs.append(G_ego)
return graphs

def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3):

def caveman_special(c=2, k=20, p_path=0.1, p_edge=0.3):
p = p_path
path_count = max(int(np.ceil(p * k)),1)
path_count = max(int(np.ceil(p * k)), 1)
G = nx.caveman_graph(c, k)
# remove 50% edges
p = 1-p_edge
p = 1 - p_edge
for (u, v) in list(G.edges()):
if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)):
G.remove_edge(u, v)
@@ -44,6 +46,7 @@ def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3):
G = max(nx.connected_component_subgraphs(G), key=len)
return G


def n_community(c_sizes, p_inter=0.01):
graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))]
G = nx.disjoint_union_all(graphs)
@@ -51,7 +54,7 @@ def n_community(c_sizes, p_inter=0.01):
for i in range(len(communities)):
subG1 = communities[i]
nodes1 = list(subG1.nodes())
for j in range(i+1, len(communities)):
for j in range(i + 1, len(communities)):
subG2 = communities[j]
nodes2 = list(subG2.nodes())
has_inter_edge = False
@@ -62,9 +65,10 @@ def n_community(c_sizes, p_inter=0.01):
has_inter_edge = True
if not has_inter_edge:
G.add_edge(nodes1[0], nodes2[0])
#print('connected comp: ', len(list(nx.connected_component_subgraphs(G))))
# print('connected comp: ', len(list(nx.connected_component_subgraphs(G))))
return G


def perturb(graph_list, p_del, p_add=None):
''' Perturb the list of graphs by adding/removing edges.
Args:
@@ -87,7 +91,7 @@ def perturb(graph_list, p_del, p_add=None):
if p_add is None:
num_nodes = G.number_of_nodes()
p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 -
G.number_of_edges())
G.number_of_edges())
else:
p_add_est = p_add

@@ -97,7 +101,7 @@ def perturb(graph_list, p_del, p_add=None):
u = nodes[i]
trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes())
j = 0
for j in range(i+1, len(nodes)):
for j in range(i + 1, len(nodes)):
v = nodes[j]
if trials[j] == 1:
tmp += 1
@@ -108,7 +112,6 @@ def perturb(graph_list, p_del, p_add=None):
return perturbed_graph_list



def perturb_new(graph_list, p):
''' Perturb the list of graphs by adding/removing edges.
Args:
@@ -123,7 +126,7 @@ def perturb_new(graph_list, p):
G = G_original.copy()
edge_remove_count = 0
for (u, v) in list(G.edges()):
if np.random.rand()<p:
if np.random.rand() < p:
G.remove_edge(u, v)
edge_remove_count += 1
# randomly add the edges back
@@ -131,16 +134,13 @@ def perturb_new(graph_list, p):
while True:
u = np.random.randint(0, G.number_of_nodes())
v = np.random.randint(0, G.number_of_nodes())
if (not G.has_edge(u,v)) and (u!=v):
if (not G.has_edge(u, v)) and (u != v):
break
G.add_edge(u, v)
perturbed_graph_list.append(G)
return perturbed_graph_list
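A quick usage sketch of perturb_new as refactored above; the Barabasi-Albert test graphs are illustrative only, chosen because the random add-back loop assumes integer node labels 0..n-1:

# Sketch only: perturb_new deletes each edge with probability p and re-adds the
# same number of edges between randomly chosen, currently unconnected node pairs,
# so each graph keeps its edge count.
import networkx as nx
from utils import perturb_new

graphs = [nx.barabasi_albert_graph(20, 2) for _ in range(4)]
noisy = perturb_new(graphs, p=0.1)
print([g.number_of_edges() for g in graphs])
print([g.number_of_edges() for g in noisy])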





def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None):
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
@@ -162,7 +162,7 @@ def save_prediction_histogram(y_pred_data, fname_pred, max_num_node, bin_n=20):


# draw a single graph G
def draw_graph(G, prefix = 'test'):
def draw_graph(G, prefix='test'):
parts = community.best_partition(G)
values = [parts.get(node) for node in G.nodes()]
colors = []
@@ -187,8 +187,7 @@ def draw_graph(G, prefix = 'test'):
plt.axis("off")

pos = nx.spring_layout(G)
nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors,pos=pos)

nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors, pos=pos)

# plt.switch_backend('agg')
# options = {
@@ -199,14 +198,14 @@ def draw_graph(G, prefix = 'test'):
# plt.figure()
# plt.subplot()
# nx.draw_networkx(G, **options)
plt.savefig('figures/graph_view_'+prefix+'.png', dpi=200)
plt.savefig('figures/graph_view_' + prefix + '.png', dpi=200)
plt.close()

plt.switch_backend('agg')
G_deg = nx.degree_histogram(G)
G_deg = np.array(G_deg)
# plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2)
plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2)
plt.loglog(np.arange(len(G_deg))[G_deg > 0], G_deg[G_deg > 0], 'r', linewidth=2)
plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
plt.close()

@@ -225,15 +224,16 @@ def draw_graph(G, prefix = 'test'):


# draw a list of graphs [G]
def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', is_single=False,k=1,node_size=55,alpha=1,width=1.3):
def draw_graph_list(G_list, row, col, fname='figures/test', layout='spring', is_single=False, k=1, node_size=55,
alpha=1, width=1.3):
# # draw graph view
# from pylab import rcParams
# rcParams['figure.figsize'] = 12,3
plt.switch_backend('agg')
for i,G in enumerate(G_list):
plt.subplot(row,col,i+1)
for i, G in enumerate(G_list):
plt.subplot(row, col, i + 1)
plt.subplots_adjust(left=0, bottom=0, right=1, top=1,
wspace=0, hspace=0)
wspace=0, hspace=0)
# if i%2==0:
# plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4)
# else:
@@ -260,28 +260,29 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i
# if values[i] == 6:
# colors.append('black')
plt.axis("off")
if layout=='spring':
pos = nx.spring_layout(G,k=k/np.sqrt(G.number_of_nodes()),iterations=100)
if layout == 'spring':
pos = nx.spring_layout(G, k=k / np.sqrt(G.number_of_nodes()), iterations=100)
# pos = nx.spring_layout(G)

elif layout=='spectral':
elif layout == 'spectral':
pos = nx.spectral_layout(G)
# # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos)
# nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2)

if is_single:
# node_size default 60, edge_width default 1.5
nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, font_size=0)
nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0,
font_size=0)
nx.draw_networkx_edges(G, pos, alpha=alpha, width=width)
else:
nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699',alpha=1, linewidths=0.2, font_size = 1.5)
nx.draw_networkx_edges(G, pos, alpha=0.3,width=0.2)
nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699', alpha=1, linewidths=0.2, font_size=1.5)
nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.2)

# plt.axis('off')
# plt.title('Complete Graph of Odd-degree Nodes')
# plt.show()
plt.tight_layout()
plt.savefig(fname+'.png', dpi=600)
plt.savefig(fname + '.png', dpi=600)
plt.close()

# # draw degree distribution
@@ -376,8 +377,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i
# plt.savefig(fname+'_community.png', dpi=600)
# plt.close()



# plt.switch_backend('agg')
# G_deg = nx.degree_histogram(G)
# G_deg = np.array(G_deg)
@@ -395,7 +394,6 @@ def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', i
# plt.close()



# directly get graph statistics from adj, obsoleted
def decode_graph(adj, prefix):
adj = np.asmatrix(adj)
@@ -433,6 +431,7 @@ def get_graph(adj):
G = nx.from_numpy_matrix(adj)
return G


# save a list of graphs
def save_graph_list(G_list, fname):
with open(fname, "wb") as f:
@@ -441,28 +440,30 @@ def save_graph_list(G_list, fname):

# pick the first connected component
def pick_connected_component(G):
node_list = nx.node_connected_component(G,0)
node_list = nx.node_connected_component(G, 0)
return G.subgraph(node_list)


def pick_connected_component_new(G):
adj_list = G.adjacency_list()
for id,adj in enumerate(adj_list):
for id, adj in enumerate(adj_list):
id_min = min(adj)
if id<id_min and id>=1:
# if id<id_min and id>=4:
if id < id_min and id >= 1:
# if id<id_min and id>=4:
break
node_list = list(range(id)) # only include nodes that come before node "id"
node_list = list(range(id))  # only include nodes that come before node "id"
G = G.subgraph(node_list)
G = max(nx.connected_component_subgraphs(G), key=len)
return G


# load a list of graphs
def load_graph_list(fname,is_real=True):
def load_graph_list(fname, is_real=True):
with open(fname, "rb") as f:
graph_list = pickle.load(f)
for i in range(len(graph_list)):
edges_with_selfloops = graph_list[i].selfloop_edges()
if len(edges_with_selfloops)>0:
if len(edges_with_selfloops) > 0:
graph_list[i].remove_edges_from(edges_with_selfloops)
if is_real:
graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len)
@@ -482,6 +483,7 @@ def export_graphs_to_txt(g_list, output_filename_prefix):
f.write(str(idx_u) + '\t' + str(idx_v) + '\n')
i += 1


def snap_txt_output_to_nx(in_fname):
G = nx.Graph()
with open(in_fname, 'r') as f:
@@ -496,23 +498,23 @@ def snap_txt_output_to_nx(in_fname):
G.add_edge(int(u), int(v))
return G


def test_perturbed():
graphs = []
for i in range(100,101):
for j in range(4,5):
for i in range(100, 101):
for j in range(4, 5):
for k in range(500):
graphs.append(nx.barabasi_albert_graph(i,j))
graphs.append(nx.barabasi_albert_graph(i, j))
g_perturbed = perturb(graphs, 0.9)
print([g.number_of_edges() for g in graphs])
print([g.number_of_edges() for g in g_perturbed])


if __name__ == '__main__':
#test_perturbed()
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_train_0.dat')
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat')
graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat')
for i in range(0, 160, 16):
draw_graph_list(graphs[i:i+16], 4, 4, fname='figures/community4_' + str(i))
# test_perturbed()
graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_grid_4_128_train_0.dat')
# graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat')
# graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat')

for i in range(0, 160, 16):
draw_graph_list(graphs[i:i + 16], 4, 4, fname='figures/community4_' + str(i))
