        ### if clean tensorboard
        self.clean_tensorboard = False
        ### Which CUDA GPU device is used for training
        self.cuda = 0

        ### Which GraphRNN model variant is used.
        # The simple version of Graph RNN
        self.load = False  # if loading a model, the default lr is very low
        self.load_epoch = 100
        self.save = True
import argparse
import numpy as np
import logging
import os
import re
from random import shuffle
from time import strftime, gmtime

import eval.stats
import utils
# import main.Args
from args import Args
from baselines.baseline_simple import *
class Args_evaluate():
    def __init__(self):
        # loop over the settings
        # list of datasets to evaluate
        # use a list of 1 element to evaluate a single dataset
        # self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD']
        self.dataset_name_all = ['caveman_small']
        # self.dataset_name_all = ['citeseer_small','caveman_small']
        # self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10']
        # self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small']

        self.epoch_start = 10
        self.epoch_end = 101
        self.epoch_step = 10
def find_nearest_idx(array, value):
    idx = (np.abs(array - value)).argmin()
    return idx
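# A doctest-style illustration with made-up values: the function returns the
# index of the array entry closest to `value`.
# >>> find_nearest_idx(np.array([10, 20, 30]), 24)
# 1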
def extract_result_id_and_epoch(name, prefix, suffix):
    '''
    Args:
        if result_id not in pred_graphs_dict:
            pred_graphs_dict[result_id] = {}
        pred_graphs_dict[result_id][epochs] = fname
    for result_id in real_graphs_dict.keys():
        for epochs in sorted(real_graphs_dict[result_id]):
            real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs])
            shuffle(pred_g_list)
            perturbed_g_list = perturb(real_g_list, 0.05)

            # dist = eval.stats.degree_stats(real_g_list, pred_g_list)
            dist = eval.stats.clustering_stats(real_g_list, pred_g_list)
            print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist)

            # dist = eval.stats.degree_stats(real_g_list, perturbed_g_list)
            dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list)
            print('dist between real and perturbed: ', dist)

            mid = len(real_g_list) // 2
            # dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:])
            dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:])
            print('dist among real: ', dist)
    return dist_degree, dist_clustering
def clean_graphs(graph_real, graph_pred):
    ''' Select generated graphs that have sizes similar to the real graphs.
    This is usually necessary for the GraphRNN-S variant, but not for the full GraphRNN model.
    '''
    return graph_real, pred_graph_new
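# The selection body is elided in this hunk. A minimal sketch of the idea (an
# assumption, not the original implementation): for each real graph, keep the
# predicted graph whose node count is nearest.
# pred_sizes = np.array([g.number_of_nodes() for g in graph_pred])
# pred_graph_new = [graph_pred[find_nearest_idx(pred_sizes, g.number_of_nodes())]
#                   for g in graph_real]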
def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
    ''' Read ground truth graphs.
    '''
        hidden = 128
    else:
        hidden = 64
    if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
        fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
            hidden) + '_test_' + str(0) + '.dat'
    else:
        fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
            hidden) + '_test_' + str(0) + '.dat'
    try:
        graph_test = utils.load_graph_list(fname_test, is_real=True)
    except:
        print('Not found: ' + fname_test)
        logging.warning('Not found: ' + fname_test)
        return None
    return graph_test
def eval_single_list(graphs, dir_input, dataset_name):
    ''' Evaluate a list of graphs by comparing with graphs in directory dir_input.
    Args:
    '''
    graph_test = load_ground_truth(dir_input, dataset_name)
    graph_test_len = len(graph_test)
    graph_test = graph_test[int(0.8 * graph_test_len):]  # test on a held-out test set
    mmd_degree = eval.stats.degree_stats(graph_test, graphs)
    mmd_clustering = eval.stats.clustering_stats(graph_test, graphs)
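    # degree_stats / clustering_stats compute an MMD (maximum mean discrepancy)
    # between the degree and clustering-coefficient distributions of the two graph sets.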
    try:
    print('clustering: ', mmd_clustering)
    print('orbits: ', mmd_4orbits)
def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,
                     epoch_end=3001, epoch_step=100):
    with open(fname_output, 'w+') as f:
        f.write(
            'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')
        # TODO: Maybe refactor into a separate file/function that specifies THE naming convention
        # across main and evaluate
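        # For reference, the convention used below expands to, e.g. (illustrative
        # values): GraphRNN_RNN_caveman_small_4_128_pred_100_1.dat, i.e.
        # <model>_<dataset>_<num_layers>_<hidden>_pred_<epoch>_<sample_time>.dat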
        else:
            hidden = 64

        # read real graph
        if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
            fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        elif 'Baseline' in model_name:
            fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        try:
            graph_test = utils.load_graph_list(fname_test, is_real=True)
        except:
            print('Not found: ' + fname_test)
            logging.warning('Not found: ' + fname_test)
            return None
        graph_test_len = len(graph_test)
        graph_train = graph_test[0:int(0.8 * graph_test_len)]  # train
        graph_validate = graph_test[0:int(0.2 * graph_test_len)]  # validate
        graph_test = graph_test[int(0.8 * graph_test_len):]  # test on a held-out test set
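        # Example with graph_test_len = 100: graph_train = graphs[0:80],
        # graph_validate = graphs[0:20] (drawn from the same front slice as the
        # train split, as in the original code), and the new graph_test = graphs[80:100].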
        graph_test_aver = 0
        for graph in graph_test:
            graph_test_aver += graph.number_of_nodes()
        graph_test_aver /= len(graph_test)
        print('test average len', graph_test_aver)
        # get performance for proposed approaches
        if 'GraphRNN' in model_name:
            # read test graph
            for epoch in range(epoch_start, epoch_end, epoch_step):
                for sample_time in range(1, 4):
                    # get filename
                    fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                        hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
                    # load graphs
                    try:
                        graph_pred = utils.load_graph_list(fname_pred, is_real=False)  # default False
                    except:
                        print('Not found: ' + fname_pred)
                        logging.warning('Not found: ' + fname_pred)
                        continue
                    # clean graphs
                    if is_clean:
                    except:
                        mmd_4orbits_validate = -1
                    # write results
                    f.write(str(sample_time) + ',' +
                            str(epoch) + ',' +
                            str(mmd_degree_validate) + ',' +
                            str(mmd_clustering_validate) + ',' +
                            str(mmd_4orbits_validate) + ',' +
                            str(mmd_degree) + ',' +
                            str(mmd_clustering) + ',' +
                            str(mmd_4orbits) + '\n')
                    print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)
        # get internal MMD (MMD between ground truth validation and test sets)
        if model_name == 'Internal':
                    mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                    + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')

        # get MMD between ground truth and its perturbed graphs
        if model_name == 'Noise':
            graph_validate_perturbed = perturb(graph_validate, 0.05)

        # get E-R MMD
        if model_name == 'E-R':
            graph_pred = Graph_generator_baseline(graph_train, generator='Gnp')
            # clean graphs
            if is_clean:
                graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
            f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')

        # get B-A MMD
        if model_name == 'B-A':
            graph_pred = Graph_generator_baseline(graph_train, generator='BA')
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n')
            print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

    return True
def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True):
    ''' Evaluate the performance of a set of models on a set of datasets.
    '''
    for model_name in model_name_all:
        for dataset_name in dataset_name_all:
            # check whether the output already exists
            fname_output = dir_output + model_name + '_' + dataset_name + '.csv'
            print('processing: ' + fname_output)
            logging.info('processing: ' + fname_output)
            if not overwrite and os.path.isfile(fname_output):
                print(fname_output + ' exists!')
                logging.info(fname_output + ' exists!')
                continue
            evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
                             epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end,
                             epoch_step=args_evaluate.epoch_step)
def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                    eval_every, epoch_range=None, out_file_prefix=None):
    ''' Evaluate list of predicted graphs compared to ground truth, stored in files.
    Args:
        baselines: dict mapping the name of a baseline to its list of generated graphs.
    '''
    if out_file_prefix is not None:
        out_files = {
            'train': open(out_file_prefix + '_train.txt', 'w+'),
            'compare': open(out_file_prefix + '_compare.txt', 'w+')
        }
    out_files['train'].write('degree,clustering,orbits4\n')

    line = 'metric,real,ours,perturbed'
    for bl in baselines:
        line += ',' + bl
    out_files['compare'].write(line)
    results = {
        'deg': {
            'real': 0,
            'ours': 100,  # take min over all training epochs
            'perturbed': 0,
            'kron': 0},
        'clustering': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0},
        'orbits4': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0}
    }
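    # 'ours' starts at 100, a large sentinel so that the min() updates below record
    # the best (smallest) distance over training epochs; 'real' and 'perturbed'
    # accumulate with += across evaluations.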
    num_evals = len(pred_graphs_filename)
    if epoch_range is None:
        epoch_range = [i * eval_every for i in range(num_evals)]
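        # e.g. eval_every=200, num_evals=3 -> epoch_range = [0, 200, 400]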
    for i in range(num_evals):
        real_g_list = utils.load_graph_list(real_graph_filename)
        # pred_g_list = utils.load_graph_list(pred_graphs_filename[i])
        # contains all predicted G
        pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i])
        if len(real_g_list) > 200:
            real_g_list = real_g_list[0:200]
        shuffle(real_g_list)
        real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))])
        pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))])
        # get perturbed real
        # perturbed_g_list_001 = perturb(real_g_list, 0.01)
        perturbed_g_list_005 = perturb(real_g_list, 0.05)
        # perturbed_g_list_010 = perturb(real_g_list, 0.10)

        # select pred samples
        # The number of nodes is sampled from a distribution similar to that of the training set
        # ========================================
        mid = len(real_g_list) // 2
        dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:])
        # dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:])
        print('degree dist among real: ', dist_degree)
        print('clustering dist among real: ', dist_clustering)
        # print('4 cycle dist among real: ', dist_4cycle)
        print('orbits dist among real: ', dist_4orbits)
        results['deg']['real'] += dist_degree
        results['clustering']['real'] += dist_clustering
        results['orbits4']['real'] += dist_4orbits
        dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list)
        print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits)
        results['deg']['ours'] = min(dist_degree, results['deg']['ours'])
        results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours'])
        out_files['train'].write(str(dist_4orbits) + ',')

        dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005)
        print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits)
        results['deg']['perturbed'] += dist_degree
        results['clustering']['perturbed'] += dist_clustering
            results['deg'][baseline] = dist_degree
            results['clustering'][baseline] = dist_clustering
            results['orbits4'][baseline] = dist_4orbits
            print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
                  ', orbits4=', dist_4orbits)
        out_files['train'].write('\n')
    # Write results
    for metric, methods in results.items():
        line = metric + ',' + \
               str(methods['real']) + ',' + \
               str(methods['ours']) + ',' + \
               str(methods['perturbed'])
        for baseline in baselines:
            line += ',' + str(methods[baseline])
        line += '\n'
def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None,
                     sample_time=2, baselines={}):
    if args is None:
        real_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + '.*real.*\.dat', f)]
        pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + '.*pred.*\.dat', f)]
        eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200)
    else:
        # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
        # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
        real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat'
        # for proposed model
        end_epoch = 3001
        epoch_range = range(eval_every, end_epoch, eval_every)
        pred_graphs_filename = [
            datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
            for epoch in epoch_range]
        # for baseline model
        # pred_graphs_filename = [datadir+args.fname_baseline+'.dat']

        # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
        #     args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
        # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
        #     args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]

        eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                        epoch_range=epoch_range,
                        eval_every=eval_every,
                        out_file_prefix=out_file_prefix)
def process_kron(kron_dir):
    txt_files = []
    for f in os.listdir(kron_dir):
        G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename)))
    return G_list
if __name__ == '__main__':
    args = Args()

    feature_parser = parser.add_mutually_exclusive_group(required=False)
    feature_parser.add_argument('--export-real', dest='export', action='store_true')
    feature_parser.add_argument('--no-export-real', dest='export', action='store_false')
    feature_parser.add_argument('--kron-dir', dest='kron_dir',
                                help='Directory where graphs generated by the Kronecker method are stored.')
    parser.add_argument('--testfile', dest='test_file',
                        help='The file that stores the list of graphs to be evaluated. Only used when a single '
                             'list of graphs is to be evaluated.')
    parser.add_argument('--dir-prefix', dest='dir_prefix',
                        help='The file that stores the list of graphs to be evaluated. Can be used when evaluating '
                             'multiple models on multiple datasets.')
    parser.add_argument('--graph-type', dest='graph_type',
                        help='Type of graphs / dataset.')
    parser.set_defaults(export=False, kron_dir='', test_file='',
                        dir_prefix='',
                        graph_type=args.graph_type)
    # dir_prefix = "/dfs/scratch0/jiaxuany0/"
    dir_prefix = args.dir_input

    time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if not os.path.isdir('logs/'):
        os.makedirs('logs/')
    if prog_args.graph_type == 'grid':
        graphs = []
        for i in range(10, 20):
            for j in range(10, 20):
                graphs.append(nx.grid_2d_graph(i, j))
        utils.export_graphs_to_txt(graphs, output_prefix)
    elif prog_args.graph_type == 'caveman':
        graphs = []
        for i in range(2, 3):
            for j in range(30, 81):
                for k in range(10):
                    graphs.append(caveman_special(i, j, p_edge=0.3))
        utils.export_graphs_to_txt(graphs, output_prefix)
    elif prog_args.graph_type == 'citeseer':
        graphs = utils.citeseer_ego()
    elif not prog_args.test_file == '':
        # evaluate a single .dat file containing a list of test graphs (networkx format)
        graphs = utils.load_graph_list(prog_args.test_file)
        eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid')
    ## if you don't use the Kronecker baseline, only the following part is needed
    else:
        if not os.path.isdir(dir_prefix + 'eval_results'):
            os.makedirs(dir_prefix + 'eval_results')
        evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/",
                   model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all,
                   args=args, overwrite=True)
import scipy.misc
import time as tm

from utils import *
from model import *
from tensorboard_logger import log_value
from data import *
from args import Args
import create_graphs
def train_vae_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
    # sort input
    y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
    y_len = y_len.numpy().tolist()
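    # e.g. y_len_unsorted = [5, 8, 3] -> y_len = [8, 5, 3], sort_index = [1, 0, 2];
    # pack_padded_sequence below requires sequences sorted by decreasing length.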
    x = torch.index_select(x_unsorted, 0, sort_index)
    y = torch.index_select(y_unsorted, 0, sort_index)
    x = Variable(x).cuda()
    y = Variable(y).cuda()

    # if using ground truth to train
    h = rnn(x, pack=True, input_len=y_len)
    y_pred, z_mu, z_lsgms = output(h)
    y_pred = F.sigmoid(y_pred)
    # clean
    y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True)
    # use cross entropy loss
    loss_bce = binary_cross_entropy_weight(y_pred, y)
    loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
    loss_kl /= y.size(0) * y.size(1) * sum(y_len)  # normalize
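    # z_lsgms holds log(sigma^2), so loss_kl above is the closed-form
    # KL( N(mu, sigma^2) || N(0, 1) ) = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # divided by a normalizing constant.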
    loss = loss_bce + loss_kl
    loss.backward()
    # update deterministic and lstm
    scheduler_output.step()
    scheduler_rnn.step()

    z_mu_mean = torch.mean(z_mu.data)
    z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data)
    z_mu_min = torch.min(z_mu.data)
    z_mu_max = torch.max(z_mu.data)
    z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data)
    if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
        print(
            'Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers,
                args.hidden_size_rnn))
        print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean,
              'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max)

    # logging
    log_value('bce_loss_' + args.fname, loss_bce.data[0], epoch * args.batch_ratio + batch_idx)
    log_value('kl_loss_' + args.fname, loss_kl.data[0], epoch * args.batch_ratio + batch_idx)
    log_value('z_mu_mean_' + args.fname, z_mu_mean, epoch * args.batch_ratio + batch_idx)
    log_value('z_mu_min_' + args.fname, z_mu_min, epoch * args.batch_ratio + batch_idx)
    log_value('z_mu_max_' + args.fname, z_mu_max, epoch * args.batch_ratio + batch_idx)
    log_value('z_sgm_mean_' + args.fname, z_sgm_mean, epoch * args.batch_ratio + batch_idx)
    log_value('z_sgm_min_' + args.fname, z_sgm_min, epoch * args.batch_ratio + batch_idx)
    log_value('z_sgm_max_' + args.fname, z_sgm_max, epoch * args.batch_ratio + batch_idx)

    loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)
def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
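    # x_step starts as an all-ones row, the start-of-sequence token for the
    # autoregressive generation loop below.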
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)

    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # save prediction histograms, plot histogram over each time step
    #     fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg',
    #     max_num_node=max_num_node)
    return G_pred_list
def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    rnn.hidden = rnn.init_hidden(test_batch_size)
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        print('finish node', i)
        h = rnn(x_step)
        y_pred_step, _, _ = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                           sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    # save graphs as pickle
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list
def train_mlp_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
    # sort input
    y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
    y_len = y_len.numpy().tolist()
    x = torch.index_select(x_unsorted, 0, sort_index)
    y = torch.index_select(y_unsorted, 0, sort_index)
    x = Variable(x).cuda()
    y = Variable(y).cuda()

    scheduler_output.step()
    scheduler_rnn.step()

    if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
        print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
            epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))

    # logging
    log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)

    loss_sum += loss.data[0]
    return loss_sum / (batch_idx + 1)
def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time=1):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()

    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        y_pred_step = output(h)

    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)

    # # save prediction histograms, plot histogram over each time step
    # if save_histogram:
    #     save_prediction_histogram(y_pred_data.cpu().numpy(),
    return G_pred_list
def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    rnn.hidden = rnn.init_hidden(test_batch_size)
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        print('finish node', i)
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid_supervised(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                           sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    # save graphs as pickle
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list
def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False, sample_time=1):
    rnn.eval()
    output.eval()
    G_pred_list = []
    rnn.hidden = rnn.init_hidden(test_batch_size)
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # normalized prediction score
    y_pred_long = Variable(
        torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        print('finish node', i)
        h = rnn(x_step)
        y_pred_step = output(h)
        y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step)
        x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:, i:i + 1, :].cuda(), current=i, y_len=y_len,
                                                  sample_time=sample_time)
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    # save graphs as pickle
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list


rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0))
# sort input
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)
x = Variable(x).cuda()
y = Variable(y).cuda()
loss = 0
for j in range(y.size(1)):
    # print('y_pred',y_pred[0,j,:],'y',y[0,j,:])
    end_idx = min(j + 1, y.size(2))
    loss += binary_cross_entropy_weight(y_pred[:, j, 0:end_idx], y[:, j, 0:end_idx]) * end_idx
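    # Worked example (illustrative, assuming max_prev_node = y.size(2) = 3):
    # rows j = 0, 1, 2, 3 give end_idx = 1, 2, 3, 3, since node j+1 can connect
    # to at most min(j+1, 3) earlier nodes, so the BCE is evaluated (and
    # weighted) only on that valid prefix of each row.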
if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
    print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
        epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
# logging
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
loss_sum += loss.data[0]
return loss_sum / (batch_idx + 1)


## too complicated, deprecated
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))
# sort input
y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
y_len = y_len.numpy().tolist()
x = torch.index_select(x_unsorted, 0, sort_index)
y = torch.index_select(y_unsorted, 0, sort_index)
# input, output for output rnn module
# a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
# reverse y_reshape, so that their lengths are sorted, add dimension
idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
idx = torch.LongTensor(idx)
y_reshape = y_reshape.index_select(0, idx)
y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)
output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
output_y = y_reshape
# batch size for output module: sum(y_len)
output_y_len = []
output_y_len_bin = np.bincount(np.array(y_len))
for i in range(len(output_y_len_bin) - 1, 0, -1):
    count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len is above i
    output_y_len.extend(
        [min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
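# Worked example (illustrative, assuming y.size(2) >= 3): y_len = [3, 3, 2]
# gives output_y_len_bin = [0, 0, 1, 2]; scanning i = 3, 2, 1 appends [3]*2,
# [2]*3, [1]*3, so output_y_len = [3, 3, 2, 2, 2, 1, 1, 1], one entry per
# packed row, summing to sum(y_len).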
# pack into variable
x = Variable(x).cuda()
y = Variable(y).cuda()
# print('y',y.size())
# print('output_y',output_y.size())
# if using ground truth to train
h = rnn(x, pack=True, input_len=y_len)
h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
# reverse h
idx = [i for i in range(h.size(0) - 1, -1, -1)]
idx = Variable(torch.LongTensor(idx)).cuda()
h = h.index_select(0, idx)
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
                          dim=0)  # num_layers, batch_size, hidden_size
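# The node-level RNN's output h seeds layer 0 of the edge-level RNN's hidden
# state; the remaining num_layers - 1 layers start from zeros.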
y_pred = output(output_x, pack=True, input_len=output_y_len)
y_pred = F.sigmoid(y_pred)
# clean
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
output_y = pad_packed_sequence(output_y, batch_first=True)[0]
# use cross entropy loss
loss = binary_cross_entropy_weight(y_pred, output_y)
loss.backward()
# update deterministic and lstm
optimizer_output.step()
optimizer_rnn.step()
scheduler_output.step()
scheduler_rnn.step()
if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
    print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
        epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
# logging
log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
feature_dim = y.size(1) * y.size(2)
loss_sum += loss.data[0] * feature_dim
return loss_sum / (batch_idx + 1)


def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16):
    rnn.hidden = rnn.init_hidden(test_batch_size)
    rnn.eval()
    output.eval()
    # generate graphs
    max_num_node = int(args.max_num_node)
    y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda()  # discrete prediction
    x_step = Variable(torch.ones(test_batch_size, 1, args.max_prev_node)).cuda()
    for i in range(max_num_node):
        h = rnn(x_step)
        # output.hidden = h.permute(1,0,2)
        hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
        output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
                                  dim=0)  # num_layers, batch_size, hidden_size
        x_step = Variable(torch.zeros(test_batch_size, 1, args.max_prev_node)).cuda()
        output_x_step = Variable(torch.ones(test_batch_size, 1, 1)).cuda()
        for j in range(min(args.max_prev_node, i + 1)):
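            # Edge-level autoregression: row i encodes node i+1's links to its
            # min(max_prev_node, i + 1) predecessors, and each sampled bit is
            # fed back as the next edge-RNN input.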
            output_y_pred_step = output(output_x_step)
            output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
            x_step[:, :, j:j + 1] = output_x_step
            output.hidden = Variable(output.hidden.data).cuda()
        y_pred_long[:, i:i + 1, :] = x_step
        rnn.hidden = Variable(rnn.hidden.data).cuda()
    y_pred_long_data = y_pred_long.data.long()
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred = get_graph(adj_pred)  # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
    return G_pred_list
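
# Example usage (illustrative; assumes rnn and output are the trained models
# used elsewhere in this file):
#   G_samples = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
#   save_graph_list(G_samples, args.graph_save_path + args.fname_pred + str(epoch) + '.dat')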


def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader):
    rnn.train()
    output.train()
    # output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1))
    # sort input
    y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True)
    y_len = y_len.numpy().tolist()
    x = torch.index_select(x_unsorted, 0, sort_index)
    y = torch.index_select(y_unsorted, 0, sort_index)
    # input, output for output rnn module
    # a smart use of pytorch builtin function: pack variable--b1_l1,b2_l1,...,b1_l2,b2_l2,...
    y_reshape = pack_padded_sequence(y, y_len, batch_first=True).data
    # reverse y_reshape, so that their lengths are sorted, add dimension
    idx = [i for i in range(y_reshape.size(0) - 1, -1, -1)]
    idx = torch.LongTensor(idx)
    y_reshape = y_reshape.index_select(0, idx)
    y_reshape = y_reshape.view(y_reshape.size(0), y_reshape.size(1), 1)
    output_x = torch.cat((torch.ones(y_reshape.size(0), 1, 1), y_reshape[:, 0:-1, 0:1]), dim=1)
    output_y = y_reshape
    # batch size for output module: sum(y_len)
    output_y_len = []
    output_y_len_bin = np.bincount(np.array(y_len))
    for i in range(len(output_y_len_bin) - 1, 0, -1):
        count_temp = np.sum(output_y_len_bin[i:])  # count how many y_len is above i
        output_y_len.extend(
            [min(i, y.size(2))] * count_temp)  # put them in output_y_len; max value should not exceed y.size(2)
    # pack into variable
    x = Variable(x).cuda()
    y = Variable(y).cuda()
    # print('y',y.size())
    # print('output_y',output_y.size())
    # if using ground truth to train
    h = rnn(x, pack=True, input_len=y_len)
    h = pack_padded_sequence(h, y_len, batch_first=True).data  # get packed hidden vector
    # reverse h
    idx = [i for i in range(h.size(0) - 1, -1, -1)]
    idx = Variable(torch.LongTensor(idx)).cuda()
    h = h.index_select(0, idx)
    hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(1))).cuda()
    output.hidden = torch.cat((h.view(1, h.size(0), h.size(1)), hidden_null),
                              dim=0)  # num_layers, batch_size, hidden_size
    y_pred = output(output_x, pack=True, input_len=output_y_len)
    y_pred = F.sigmoid(y_pred)
    # clean
    y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
    y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
    output_y = pack_padded_sequence(output_y, output_y_len, batch_first=True)
    output_y = pad_packed_sequence(output_y, batch_first=True)[0]
    # use cross entropy loss
    loss = binary_cross_entropy_weight(y_pred, output_y)
    if epoch % args.epochs_log == 0 and batch_idx == 0:  # only output first batch's statistics
        print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
            epoch, args.epochs, loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn))
    # logging
    log_value('loss_' + args.fname, loss.data[0], epoch * args.batch_ratio + batch_idx)
    # print(y_pred.size())
    feature_dim = y_pred.size(0) * y_pred.size(1)
    loss_sum += loss.data[0] * feature_dim / y.size(0)
    return loss_sum / (batch_idx + 1)
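
# Unlike train_rnn_epoch, this forward pass never calls loss.backward() or an
# optimizer step; it reuses the teacher-forced forward computation only to
# score graphs, which is how train_nll below uses it.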


########### train function for LSTM + VAE

# start main loop
time_all = np.zeros(args.epochs)
while epoch <= args.epochs:
    time_start = tm.time()
    # train
    if 'GraphRNN_VAE' in args.note:
    time_end = tm.time()
    time_all[epoch - 1] = time_end - time_start
    # test
    if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
        for sample_time in range(1, 4):
            G_pred = []
            while len(G_pred) < args.test_total_size:
                if 'GraphRNN_VAE' in args.note:
                    G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                 sample_time=sample_time)
                elif 'GraphRNN_MLP' in args.note:
                    G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,
                                                 sample_time=sample_time)
                elif 'GraphRNN_RNN' in args.note:
                    G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size)
                G_pred.extend(G_pred_step)
            # save graphs
            fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
            save_graph_list(G_pred, fname)
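            # The RNN variant's sampler does not depend on sample_time, so one
            # pass suffices and the loop exits below.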
            if 'GraphRNN_RNN' in args.note:
                break
        print('test done, graphs saved')
    # save model checkpoint
    if args.save:
        if epoch % args.epochs_save == 0:
            fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat'
            torch.save(output.state_dict(), fname)
    epoch += 1
np.save(args.timing_save_path + args.fname, time_all)


########### for graph completion task
epoch = args.load_epoch
print('model loaded!, epoch: {}'.format(args.load_epoch))
for sample_time in range(1, 4):
    if 'GraphRNN_MLP' in args.note:
        G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
    if 'GraphRNN_VAE' in args.note:
        G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test, sample_time=sample_time)
    # save graphs
    fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + 'graph_completion.dat'
    save_graph_list(G_pred, fname)
print('graph completion done, graphs saved')


########### for NLL evaluation
def train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len, max_iter=1000):
    fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
    rnn.load_state_dict(torch.load(fname))
    fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
    output.load_state_dict(torch.load(fname))
    epoch = args.load_epoch
    print('model loaded!, epoch: {}'.format(args.load_epoch))
    fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
    with open(fname_output, 'w+') as f:
        f.write(str(graph_validate_len) + ',' + str(graph_test_len) + '\n')
        f.write('train,test\n')
        for iter in range(max_iter):
            if 'GraphRNN_MLP' in args.note:
                nll_train = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_test)
            if 'GraphRNN_RNN' in args.note:
                nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train)
                nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test)
            print('train', nll_train, 'test', nll_test)
            f.write(str(nll_train) + ',' + str(nll_test) + '\n')
    print('NLL evaluation done')


import data


def citeseer_ego():
    _, _, G = data.Graph_load(dataset='citeseer')
    G = max(nx.connected_component_subgraphs(G), key=len)
    graphs.append(G_ego)
    return graphs


def caveman_special(c=2, k=20, p_path=0.1, p_edge=0.3):
    p = p_path
    path_count = max(int(np.ceil(p * k)), 1)
    G = nx.caveman_graph(c, k)
    # remove each within-community edge with probability 1 - p_edge
    # (with the default p_edge=0.3, about 70% of clique edges are dropped)
    p = 1 - p_edge
    for (u, v) in list(G.edges()):
        if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)):
            G.remove_edge(u, v)
    G = max(nx.connected_component_subgraphs(G), key=len)
    return G
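
# Quick sanity check (illustrative): with the defaults only about p_edge = 30%
# of each within-community clique's edges survive, e.g.
#   G = caveman_special(c=2, k=20, p_edge=0.3)
#   print(G.number_of_nodes(), G.number_of_edges())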


def n_community(c_sizes, p_inter=0.01):
    graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))]
    G = nx.disjoint_union_all(graphs)
    communities = list(nx.connected_component_subgraphs(G))
    for i in range(len(communities)):
        subG1 = communities[i]
        nodes1 = list(subG1.nodes())
        for j in range(i + 1, len(communities)):
            subG2 = communities[j]
            nodes2 = list(subG2.nodes())
            has_inter_edge = False
            for u in nodes1:
                for v in nodes2:
                    if np.random.rand() < p_inter:
                        G.add_edge(u, v)
                        has_inter_edge = True
            # guarantee at least one edge between every pair of communities
            if not has_inter_edge:
                G.add_edge(nodes1[0], nodes2[0])
    # print('connected comp: ', len(list(nx.connected_component_subgraphs(G))))
    return G


def perturb(graph_list, p_del, p_add=None):
    ''' Perturb the list of graphs by adding/removing edges.
    Args:
    '''
    if p_add is None:
        num_nodes = G.number_of_nodes()
        p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 -
                                      G.number_of_edges())
    else:
        p_add_est = p_add
    u = nodes[i]
    trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes())
    j = 0
    for j in range(i + 1, len(nodes)):
        v = nodes[j]
        if trials[j] == 1:
            tmp += 1
    return perturbed_graph_list


def perturb_new(graph_list, p):
    ''' Perturb the list of graphs by adding/removing edges.
    Args:
    '''
    G = G_original.copy()
    edge_remove_count = 0
    for (u, v) in list(G.edges()):
        if np.random.rand() < p:
            G.remove_edge(u, v)
            edge_remove_count += 1
    # randomly add the edges back
    for i in range(edge_remove_count):
        while True:
            u = np.random.randint(0, G.number_of_nodes())
            v = np.random.randint(0, G.number_of_nodes())
            if (not G.has_edge(u, v)) and (u != v):
                break
        G.add_edge(u, v)
    perturbed_graph_list.append(G)
    return perturbed_graph_list
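
# perturb_new preserves each graph's edge count exactly (every removed edge is
# matched by one random non-duplicate, non-self-loop edge added back), so p
# tunes how much local structure is randomized at fixed density.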


def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None):
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure


# draw a single graph G
def draw_graph(G, prefix='test'):
    parts = community.best_partition(G)
    values = [parts.get(node) for node in G.nodes()]
    colors = []
    plt.axis("off")
    pos = nx.spring_layout(G)
    nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors, pos=pos)
    # plt.switch_backend('agg')
    # options = {
    # plt.figure()
    # plt.subplot()
    # nx.draw_networkx(G, **options)
    plt.savefig('figures/graph_view_' + prefix + '.png', dpi=200)
    plt.close()

    plt.switch_backend('agg')
    G_deg = nx.degree_histogram(G)
    G_deg = np.array(G_deg)
    # plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2)
    plt.loglog(np.arange(len(G_deg))[G_deg > 0], G_deg[G_deg > 0], 'r', linewidth=2)
    plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
    plt.close()


# draw a list of graphs [G]
def draw_graph_list(G_list, row, col, fname='figures/test', layout='spring', is_single=False, k=1, node_size=55,
                    alpha=1, width=1.3):
    # # draw graph view
    # from pylab import rcParams
    # rcParams['figure.figsize'] = 12,3
    plt.switch_backend('agg')
    for i, G in enumerate(G_list):
        plt.subplot(row, col, i + 1)
        plt.subplots_adjust(left=0, bottom=0, right=1, top=1,
                            wspace=0, hspace=0)
        # if i%2==0:
        #     plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4)
        # else:
        # if values[i] == 6:
        #     colors.append('black')
        plt.axis("off")
        if layout == 'spring':
            pos = nx.spring_layout(G, k=k / np.sqrt(G.number_of_nodes()), iterations=100)
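            # spring_layout's default spacing is 1/sqrt(n); passing k/sqrt(n)
            # simply rescales that default by this function's k argument.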
            # pos = nx.spring_layout(G)
        elif layout == 'spectral':
            pos = nx.spectral_layout(G)
        # # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos)
        # nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2)
        if is_single:
            # node_size default 60, edge_width default 1.5
            nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0,
                                   font_size=0)
            nx.draw_networkx_edges(G, pos, alpha=alpha, width=width)
        else:
            nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699', alpha=1, linewidths=0.2, font_size=1.5)
            nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.2)
    # plt.axis('off')
    # plt.title('Complete Graph of Odd-degree Nodes')
    # plt.show()
    plt.tight_layout()
    plt.savefig(fname + '.png', dpi=600)
    plt.close()

# # draw degree distribution
# plt.savefig(fname+'_community.png', dpi=600)
# plt.close()
# plt.switch_backend('agg')
# G_deg = nx.degree_histogram(G)
# G_deg = np.array(G_deg)
# plt.close()


# directly get graph statistics from adj, obsoleted
def decode_graph(adj, prefix):
    adj = np.asmatrix(adj)
    G = nx.from_numpy_matrix(adj)
    return G


# save a list of graphs
def save_graph_list(G_list, fname):
    with open(fname, "wb") as f:
        pickle.dump(G_list, f)


# pick the first connected component
def pick_connected_component(G):
    node_list = nx.node_connected_component(G, 0)
    return G.subgraph(node_list)


def pick_connected_component_new(G):
    adj_list = G.adjacency_list()
    for id, adj in enumerate(adj_list):
        id_min = min(adj)
        if id < id_min and id >= 1:
            # if id<id_min and id>=4:
            break
    node_list = list(range(id))  # only include nodes that come before node "id"
    G = G.subgraph(node_list)
    G = max(nx.connected_component_subgraphs(G), key=len)
    return G
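
# Illustrative note: the loop above scans a BFS-ordered node sequence for the
# first node (id >= 1) whose earliest neighbor comes after it, i.e. where the
# generated adjacency stops being contiguous, and keeps only the nodes before
# that break point.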


# load a list of graphs
def load_graph_list(fname, is_real=True):
    with open(fname, "rb") as f:
        graph_list = pickle.load(f)
    for i in range(len(graph_list)):
        edges_with_selfloops = graph_list[i].selfloop_edges()
        if len(edges_with_selfloops) > 0:
            graph_list[i].remove_edges_from(edges_with_selfloops)
        if is_real:
            graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len)

            f.write(str(idx_u) + '\t' + str(idx_v) + '\n')
        i += 1


def snap_txt_output_to_nx(in_fname):
    G = nx.Graph()
    with open(in_fname, 'r') as f:
        G.add_edge(int(u), int(v))
    return G


def test_perturbed():
    graphs = []
    for i in range(100, 101):
        for j in range(4, 5):
            for k in range(500):
                graphs.append(nx.barabasi_albert_graph(i, j))
    g_perturbed = perturb(graphs, 0.9)
    print([g.number_of_edges() for g in graphs])
    print([g.number_of_edges() for g in g_perturbed])


if __name__ == '__main__':
    # test_perturbed()
    graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_grid_4_128_train_0.dat')
    # graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat')
    # graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat')
    for i in range(0, 160, 16):
        draw_graph_list(graphs[i:i + 16], 4, 4, fname='figures/grid_' + str(i))