- import argparse
- import logging
- import os
- from time import strftime, gmtime
-
- import eval.stats
- import utils
- # import main.Args
- from args import Args
- from baselines.baseline_simple import *
-
-
class Args_evaluate():
    """Configuration for the evaluation loop.

    Holds the lists of models and datasets to score, plus the range of
    training-epoch checkpoints to read.
    """

    def __init__(self):
        # Models to evaluate. Other names used in experiments include
        # 'GraphRNN_MLP', 'Internal', 'Noise', 'E-R', 'B-A', 'Baseline_DGMG'.
        self.model_name_all = ['GraphRNN_RNN']

        # Datasets to evaluate; use a single-element list to evaluate
        # one dataset (e.g. 'caveman', 'grid', 'barabasi', 'citeseer', 'DD').
        self.dataset_name_all = ['caveman_small']

        # Checkpoint epochs evaluated: range(epoch_start, epoch_end, epoch_step).
        self.epoch_start = 10
        self.epoch_end = 101
        self.epoch_step = 10
-
-
def find_nearest_idx(array, value):
    """Return the index of the element of `array` closest to `value`."""
    distances = np.abs(array - value)
    return distances.argmin()
-
-
def extract_result_id_and_epoch(name, prefix, suffix):
    """Parse a result filename into its id and epoch number.

    Args:
        name: filename of the form ``<prefix><epoch>_...<suffix><id>.dat``
        prefix: substring that immediately precedes the epoch number
        suffix: ``real_`` or ``pred_`` (immediately precedes the result id)
    Returns:
        A tuple ``(result_id, epoch)`` extracted from the filename string.
    """
    id_start = name.find(suffix) + len(suffix)
    id_end = name.find('.dat')
    result_id = name[id_start:id_end]

    epoch_pos = name.find(prefix) + len(prefix)
    epoch_end = name.find('_', epoch_pos)
    epochs = int(name[epoch_pos:epoch_end])
    return result_id, epochs
-
-
def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every):
    """Pair real/pred graph files by result id and epoch, then print
    clustering-MMD distances: real vs. pred, real vs. perturbed, and
    real vs. real (split in half) as a reference.
    """
    real_by_id = {}
    pred_by_id = {}

    # Index filenames as {result_id: {epoch: fname}}, keeping only epochs
    # that fall on the evaluation grid.
    for fname in real_graphs_filename:
        result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'real_')
        if epochs % eval_every != 0:
            continue
        real_by_id.setdefault(result_id, {})[epochs] = fname
    for fname in pred_graphs_filename:
        result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'pred_')
        if epochs % eval_every != 0:
            continue
        pred_by_id.setdefault(result_id, {})[epochs] = fname

    for result_id in real_by_id:
        for epochs in sorted(real_by_id[result_id]):
            real_g_list = utils.load_graph_list(real_by_id[result_id][epochs])
            pred_g_list = utils.load_graph_list(pred_by_id[result_id][epochs])
            shuffle(real_g_list)
            shuffle(pred_g_list)
            perturbed_g_list = perturb(real_g_list, 0.05)

            dist = eval.stats.clustering_stats(real_g_list, pred_g_list)
            print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist)

            dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list)
            print('dist between real and perturbed: ', dist)

            # Reference: distance between two halves of the real set.
            mid = len(real_g_list) // 2
            dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:])
            print('dist among real: ', dist)
-
-
def compute_basic_stats(real_g_list, target_g_list):
    """Return (degree MMD, clustering MMD) between two graph lists."""
    return (eval.stats.degree_stats(real_g_list, target_g_list),
            eval.stats.clustering_stats(real_g_list, target_g_list))
-
-
def clean_graphs(graph_real, graph_pred):
    """Select generated graphs whose sizes match the real graphs.

    For each real graph, pick (with replacement) the predicted graph whose
    node count is nearest, so selected predictions follow the size
    distribution of the training set. It is usually necessary for the
    GraphRNN-S version, but not the full GraphRNN model.

    Args:
        graph_real: list of ground-truth graphs (shuffled in place).
        graph_pred: list of generated graphs (shuffled in place).
    Returns:
        Tuple (graph_real, pred_graph_new), both of the same length.
    """
    shuffle(graph_real)
    shuffle(graph_pred)

    # Graph "length" is len(G) == number of nodes.
    real_graph_len = np.array([len(g) for g in graph_real])
    pred_graph_len = np.array([len(g) for g in graph_pred])

    # Fix: the original also accumulated a pred_graph_len_new list that was
    # never read or returned — dead code removed.
    pred_graph_new = []
    for value in real_graph_len:
        pred_idx = find_nearest_idx(pred_graph_len, value)
        pred_graph_new.append(graph_pred[pred_idx])

    return graph_real, pred_graph_new
-
-
def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
    """Read the ground-truth (test) graph list for a dataset.

    NOTE: relies on the module-level ``args`` (set in ``__main__``) for
    ``args.num_layers`` — confirm before reusing this outside the script.

    Args:
        dir_input: directory containing saved graph lists.
        dataset_name: dataset identifier used in filenames.
        model_name: model whose test dump to read; pure baselines reuse
            the GraphRNN_MLP dump.
    Returns:
        The loaded graph list, or None if the file is missing/unreadable.
    """
    # Naming convention mirrors main: 'small' datasets use 64 hidden units.
    hidden = 64 if 'small' in dataset_name else 128
    # Baselines have no ground-truth dump of their own; reuse GraphRNN_MLP's.
    # (Fix: the two branches previously duplicated the whole filename
    # expression, differing only in this prefix.)
    if model_name in ('Internal', 'Noise', 'B-A', 'E-R'):
        name_prefix = 'GraphRNN_MLP'
    else:
        name_prefix = model_name
    fname_test = (dir_input + name_prefix + '_' + dataset_name + '_'
                  + str(args.num_layers) + '_' + str(hidden) + '_test_' + str(0) + '.dat')
    try:
        graph_test = utils.load_graph_list(fname_test, is_real=True)
    except Exception:  # narrowed from a bare except: keep SystemExit/KeyboardInterrupt
        print('Not found: ' + fname_test)
        logging.warning('Not found: ' + fname_test)
        return None
    return graph_test
-
-
def eval_single_list(graphs, dir_input, dataset_name):
    """Evaluate a list of graphs by comparing with graphs in directory dir_input.

    Args:
        graphs: list of graphs to evaluate.
        dir_input: directory where ground truth graph list is stored.
        dataset_name: name of the dataset (ground truth).
    """
    graph_test = load_ground_truth(dir_input, dataset_name)
    # Bug fix: load_ground_truth returns None when the file is missing;
    # previously this crashed on len(None).
    if graph_test is None:
        return
    graph_test_len = len(graph_test)
    graph_test = graph_test[int(0.8 * graph_test_len):]  # test on a hold out test set
    mmd_degree = eval.stats.degree_stats(graph_test, graphs)
    mmd_clustering = eval.stats.clustering_stats(graph_test, graphs)
    try:
        mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graphs)
    except Exception:  # orbit counting may fail (external tool); report -1
        mmd_4orbits = -1
    print('deg: ', mmd_degree)
    print('clustering: ', mmd_clustering)
    print('orbits: ', mmd_4orbits)
-
-
def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,
                     epoch_end=3001, epoch_step=100):
    """Evaluate one (model, dataset) pair across training-epoch checkpoints.

    Writes a CSV to `fname_output` with MMD statistics (degree, clustering,
    4-orbit counts) of generated graphs against the validation and test
    splits of the ground truth. Columns that do not apply to a given mode
    are filled with -1.

    Args:
        dir_input: directory containing saved graph lists (.dat files).
        fname_output: path of the CSV file to write.
        model_name: e.g. 'GraphRNN_RNN'; 'Internal'/'Noise'/'E-R'/'B-A'
            trigger baseline modes; names containing 'Baseline' read
            per-epoch baseline dumps.
        dataset_name: dataset identifier used in filenames.
        args: global experiment args (provides num_layers).
        is_clean: if True, size-match predictions to the test set via
            clean_graphs(); otherwise shuffle and truncate predictions.
        epoch_start, epoch_end, epoch_step: checkpoint epochs to evaluate.

    Returns:
        True on success, or None if the ground-truth file is missing.
    """
    with open(fname_output, 'w+') as f:
        f.write(
            'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')

        # TODO: Maybe refactor into a separate file/function that specifies THE naming convention
        # across main and evaluate
        if not 'small' in dataset_name:
            hidden = 128
        else:
            hidden = 64
        # read real graph: pure baselines have no ground-truth dump of their
        # own, so they reuse GraphRNN_MLP's
        if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
            fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        elif 'Baseline' in model_name:
            fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(64) + '_test_' + str(0) + '.dat'
        else:
            fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        try:
            graph_test = utils.load_graph_list(fname_test, is_real=True)
        except:
            print('Not found: ' + fname_test)
            logging.warning('Not found: ' + fname_test)
            return None

        # Split the ground truth: first 80% train, first 20% validate,
        # last 20% held-out test.
        # NOTE(review): graph_validate overlaps graph_train (both start at
        # index 0) — looks intentional here, but confirm.
        graph_test_len = len(graph_test)
        graph_train = graph_test[0:int(0.8 * graph_test_len)]  # train
        graph_validate = graph_test[0:int(0.2 * graph_test_len)]  # validate
        graph_test = graph_test[int(0.8 * graph_test_len):]  # test on a hold out test set

        # Average test-graph size, for sanity checking against predictions.
        graph_test_aver = 0
        for graph in graph_test:
            graph_test_aver += graph.number_of_nodes()
        graph_test_aver /= len(graph_test)
        print('test average len', graph_test_aver)

        # get performance for proposed approaches
        if 'GraphRNN' in model_name:
            # read test graph: one prediction dump per (epoch, sample_time)
            for epoch in range(epoch_start, epoch_end, epoch_step):
                for sample_time in range(1, 4):
                    # get filename
                    fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                        hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
                    # load graphs; skip checkpoints that were never written
                    try:
                        graph_pred = utils.load_graph_list(fname_pred, is_real=False)  # default False
                    except:
                        print('Not found: ' + fname_pred)
                        logging.warning('Not found: ' + fname_pred)
                        continue
                    # clean graphs: size-match predictions, or subsample randomly
                    if is_clean:
                        graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
                    else:
                        shuffle(graph_pred)
                        graph_pred = graph_pred[0:len(graph_test)]
                    print('len graph_test', len(graph_test))
                    print('len graph_validate', len(graph_validate))
                    print('len graph_pred', len(graph_pred))

                    graph_pred_aver = 0
                    for graph in graph_pred:
                        graph_pred_aver += graph.number_of_nodes()
                    graph_pred_aver /= len(graph_pred)
                    print('pred average len', graph_pred_aver)

                    # evaluate MMD test
                    mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
                    mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
                    try:
                        mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
                    except:
                        # orbit counting may fail; -1 marks "unavailable"
                        mmd_4orbits = -1
                    # evaluate MMD validate
                    mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
                    mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
                    try:
                        mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
                    except:
                        mmd_4orbits_validate = -1
                    # write results
                    f.write(str(sample_time) + ',' +
                            str(epoch) + ',' +
                            str(mmd_degree_validate) + ',' +
                            str(mmd_clustering_validate) + ',' +
                            str(mmd_4orbits_validate) + ',' +
                            str(mmd_degree) + ',' +
                            str(mmd_clustering) + ',' +
                            str(mmd_4orbits) + '\n')
                    print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

        # get internal MMD (MMD between ground truth validation and test sets)
        if model_name == 'Internal':
            mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate)
            mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate)
            except:
                mmd_4orbits_validate = -1
            # -1 fills the columns that do not apply to this mode
            f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
                mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                    + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')

        # get MMD between ground truth and its perturbed graphs
        if model_name == 'Noise':
            graph_validate_perturbed = perturb(graph_validate, 0.05)  # 5% edge noise
            mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate_perturbed)
            mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate_perturbed)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate_perturbed)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
                mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                    + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')

        # get E-R MMD: fit an Erdos-Renyi baseline to the training split
        if model_name == 'E-R':
            graph_pred = Graph_generator_baseline(graph_train, generator='Gnp')
            # clean graphs
            if is_clean:
                graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
            print('len graph_test', len(graph_test))
            print('len graph_pred', len(graph_pred))
            mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
            mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')

        # get B-A MMD: fit a Barabasi-Albert baseline to the training split
        if model_name == 'B-A':
            graph_pred = Graph_generator_baseline(graph_train, generator='BA')
            # clean graphs
            if is_clean:
                graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
            print('len graph_test', len(graph_test))
            print('len graph_pred', len(graph_pred))
            mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
            mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')

        # get performance for baseline approaches
        if 'Baseline' in model_name:
            # read test graph: one dump per epoch (no sample_time dimension)
            for epoch in range(epoch_start, epoch_end, epoch_step):
                # get filename
                fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(
                    64) + '_pred_' + str(epoch) + '.dat'
                # load graphs
                try:
                    graph_pred = utils.load_graph_list(fname_pred, is_real=True)  # default False
                except:
                    print('Not found: ' + fname_pred)
                    logging.warning('Not found: ' + fname_pred)
                    continue
                # clean graphs
                if is_clean:
                    graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
                else:
                    shuffle(graph_pred)
                    graph_pred = graph_pred[0:len(graph_test)]
                print('len graph_test', len(graph_test))
                print('len graph_validate', len(graph_validate))
                print('len graph_pred', len(graph_pred))

                graph_pred_aver = 0
                for graph in graph_pred:
                    graph_pred_aver += graph.number_of_nodes()
                graph_pred_aver /= len(graph_pred)
                print('pred average len', graph_pred_aver)

                # evaluate MMD test
                mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
                mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
                try:
                    mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
                except:
                    mmd_4orbits = -1
                # evaluate MMD validate
                mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
                mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
                try:
                    mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
                except:
                    mmd_4orbits_validate = -1
                # write results; sample_time column is -1 for baselines
                f.write(str(-1) + ',' + str(epoch) + ',' + str(mmd_degree_validate) + ',' + str(
                    mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                        + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n')
                print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

    return True
-
-
def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True):
    """Evaluate the performance of a set of models on a set of datasets.

    One CSV per (model, dataset) pair is written to `dir_output`.

    Args:
        args_evaluate: provides epoch_start / epoch_end / epoch_step.
        dir_input: directory containing saved graph lists.
        dir_output: directory where result CSVs are written.
        model_name_all: list of model names to evaluate.
        dataset_name_all: list of dataset names to evaluate.
        args: global experiment args (forwarded to evaluation_epoch).
        overwrite: when False, skip pairs whose CSV already exists.
    """
    for model_name in model_name_all:
        for dataset_name in dataset_name_all:
            # Consistency fix: build the output filename once and reuse it
            # (previously the same expression was rebuilt in every
            # print/log call); also `overwrite == False` -> `not overwrite`.
            fname_output = dir_output + model_name + '_' + dataset_name + '.csv'
            print('processing: ' + fname_output)
            logging.info('processing: ' + fname_output)
            if not overwrite and os.path.isfile(fname_output):
                print(fname_output + ' exists!')
                logging.info(fname_output + ' exists!')
                continue
            evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
                             epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end,
                             epoch_step=args_evaluate.epoch_step)
-
-
def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                    eval_every, epoch_range=None, out_file_prefix=None):
    ''' Evaluate list of predicted graphs compared to ground truth, stored in files.

    Computes degree / clustering / 4-orbit MMD at each epoch for: real vs.
    real (split in half), real vs. pred, real vs. perturbed, and (once) real
    vs. each baseline. Writes per-epoch training stats to
    '<prefix>_train.txt' and a summary to '<prefix>_compare.txt'.

    Args:
        real_graph_filename: path of the .dat file with the ground-truth list.
        pred_graphs_filename: list of .dat files, one per evaluated epoch.
        baselines: dict mapping name of the baseline to list of generated graphs.
        eval_every: number of epochs between consecutive evaluations.
        epoch_range: explicit epoch labels matching pred_graphs_filename;
            defaults to multiples of eval_every.
        out_file_prefix: prefix for the two output text files.
            NOTE(review): if None, `out_files` is never bound and the writes
            below raise NameError — confirm callers always pass it.
    '''

    if out_file_prefix is not None:
        out_files = {
            'train': open(out_file_prefix + '_train.txt', 'w+'),
            'compare': open(out_file_prefix + '_compare.txt', 'w+')
        }

    out_files['train'].write('degree,clustering,orbits4\n')

    line = 'metric,real,ours,perturbed'
    for bl in baselines:
        line += ',' + bl
    line += '\n'
    out_files['compare'].write(line)

    # Accumulators per metric: 'real'/'perturbed' are summed over epochs and
    # averaged at the end; 'ours' starts high and takes the min over epochs.
    results = {
        'deg': {
            'real': 0,
            'ours': 100,  # take min over all training epochs
            'perturbed': 0,
            'kron': 0},
        'clustering': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0},
        'orbits4': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0}
    }

    num_evals = len(pred_graphs_filename)
    if epoch_range is None:
        epoch_range = [i * eval_every for i in range(num_evals)]
    for i in range(num_evals):
        real_g_list = utils.load_graph_list(real_graph_filename)
        # pred_g_list = utils.load_graph_list(pred_graphs_filename[i])

        # contains all predicted G for this epoch
        pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i])
        # cap the real set to keep MMD computation tractable
        if len(real_g_list) > 200:
            real_g_list = real_g_list[0:200]

        shuffle(real_g_list)
        shuffle(pred_g_list_raw)

        # get length (node count) of every graph
        real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))])
        pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))])
        # get perturb real: 5% edge noise as a reference point
        # perturbed_g_list_001 = perturb(real_g_list, 0.01)
        perturbed_g_list_005 = perturb(real_g_list, 0.05)
        # perturbed_g_list_010 = perturb(real_g_list, 0.10)

        # select pred samples without replacement so their node counts are
        # sampled from a distribution similar to the training set
        pred_g_list = []
        pred_g_len_list = []
        for value in real_g_len_list:
            pred_idx = find_nearest_idx(pred_g_len_list_raw, value)
            pred_g_list.append(pred_g_list_raw[pred_idx])
            pred_g_len_list.append(pred_g_len_list_raw[pred_idx])
            # delete the chosen graph so it cannot be picked again
            pred_g_len_list_raw = np.delete(pred_g_len_list_raw, pred_idx)
            del pred_g_list_raw[pred_idx]
            if len(pred_g_list) == len(real_g_list):
                break
        # pred_g_len_list = np.array(pred_g_len_list)
        print('################## epoch {} ##################'.format(epoch_range[i]))

        # info about graph size
        print('real average nodes',
              sum([real_g_list[i].number_of_nodes() for i in range(len(real_g_list))]) / len(real_g_list))
        print('pred average nodes',
              sum([pred_g_list[i].number_of_nodes() for i in range(len(pred_g_list))]) / len(pred_g_list))
        print('num of real graphs', len(real_g_list))
        print('num of pred graphs', len(pred_g_list))

        # ========================================
        # Evaluation
        # ========================================
        # baseline: distance between the two halves of the real set
        mid = len(real_g_list) // 2
        dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:])
        # dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:])
        print('degree dist among real: ', dist_degree)
        print('clustering dist among real: ', dist_clustering)
        # print('4 cycle dist among real: ', dist_4cycle)
        print('orbits dist among real: ', dist_4orbits)
        results['deg']['real'] += dist_degree
        results['clustering']['real'] += dist_clustering
        results['orbits4']['real'] += dist_4orbits

        dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list)
        print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits)
        # keep the best (minimum) distance seen over all training epochs
        results['deg']['ours'] = min(dist_degree, results['deg']['ours'])
        results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours'])
        results['orbits4']['ours'] = min(dist_4orbits, results['orbits4']['ours'])

        # performance at training time
        out_files['train'].write(str(dist_degree) + ',')
        out_files['train'].write(str(dist_clustering) + ',')
        out_files['train'].write(str(dist_4orbits) + ',')

        dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005)
        print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits)
        results['deg']['perturbed'] += dist_degree
        results['clustering']['perturbed'] += dist_clustering
        results['orbits4']['perturbed'] += dist_4orbits

        if i == 0:
            # Baselines are epoch-independent: evaluate once, on first pass
            for baseline in baselines:
                dist_degree, dist_clustering = compute_basic_stats(real_g_list, baselines[baseline])
                dist_4orbits = eval.stats.orbit_stats_all(real_g_list, baselines[baseline])
                results['deg'][baseline] = dist_degree
                results['clustering'][baseline] = dist_clustering
                results['orbits4'][baseline] = dist_4orbits
                print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
                      ', orbits4=', dist_4orbits)

        out_files['train'].write('\n')

    # average the summed metrics over all evaluated epochs
    for metric, methods in results.items():
        methods['real'] /= num_evals
        methods['perturbed'] /= num_evals

    # Write results
    for metric, methods in results.items():
        line = metric + ',' + \
               str(methods['real']) + ',' + \
               str(methods['ours']) + ',' + \
               str(methods['perturbed'])
        for baseline in baselines:
            line += ',' + str(methods[baseline])
        line += '\n'

        out_files['compare'].write(line)

    for _, out_f in out_files.items():
        out_f.close()
-
-
def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None,
                     sample_time=2, baselines=None):
    """Evaluate generated graphs stored under `datadir`.

    Two modes:
      * args is None: locate real/pred .dat files by regex on `prefix`
        and run eval_list.
      * args given: build filenames from args.graph_save_path /
        args.fname_test / args.fname_pred and run eval_list_fname over
        epochs range(eval_every, 3001, eval_every).

    Args:
        datadir: directory containing saved graph lists.
        prefix: filename prefix used to locate files (args-is-None mode).
        args: experiment args providing the save-path/filename attributes.
        eval_every: epoch spacing between evaluated checkpoints.
        out_file_prefix: prefix for result text files (args mode).
        sample_time: which sampling pass of the predictions to read.
        baselines: dict mapping baseline name -> list of generated graphs.
    """
    # Bug fix: `baselines={}` was a mutable default argument, shared
    # across calls; use the None sentinel instead.
    if baselines is None:
        baselines = {}
    if args is None:
        # Raw strings so '\.' is a regex escape (literal dot), not an
        # (invalid) string escape.
        real_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + r'.*real.*\.dat', f)]
        pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + r'.*pred.*\.dat', f)]
        eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200)

    else:
        # ground truth for the proposed model
        real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat'
        end_epoch = 3001
        epoch_range = range(eval_every, end_epoch, eval_every)
        # one prediction dump per evaluated epoch
        pred_graphs_filename = [
            datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
            for epoch in epoch_range]

        eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                        epoch_range=epoch_range,
                        eval_every=eval_every,
                        out_file_prefix=out_file_prefix)
-
-
def process_kron(kron_dir):
    """Load Kronecker-generated graphs from a directory.

    If a '.dat' file is encountered, its graph list is returned
    immediately; otherwise every '.txt' (SNAP output) file found is
    converted to a networkx graph and the list is returned.
    """
    txt_files = []
    for entry in os.listdir(kron_dir):
        filename = os.fsdecode(entry)
        if filename.endswith('.txt'):
            txt_files.append(filename)
        elif filename.endswith('.dat'):
            # a .dat dump supersedes any txt output
            return utils.load_graph_list(os.path.join(kron_dir, filename))
    return [utils.snap_txt_output_to_nx(os.path.join(kron_dir, name))
            for name in txt_files]
-
-
if __name__ == '__main__':
    # Global experiment configuration (also read by load_ground_truth /
    # evaluation_epoch via the module-level `args`).
    args = Args()
    args_evaluate = Args_evaluate()

    parser = argparse.ArgumentParser(description='Evaluation arguments.')
    # --export-real / --no-export-real / --kron-dir are mutually exclusive modes.
    feature_parser = parser.add_mutually_exclusive_group(required=False)
    feature_parser.add_argument('--export-real', dest='export', action='store_true')
    feature_parser.add_argument('--no-export-real', dest='export', action='store_false')
    feature_parser.add_argument('--kron-dir', dest='kron_dir',
                                help='Directory where graphs generated by kronecker method is stored.')

    parser.add_argument('--testfile', dest='test_file',
                        help='The file that stores list of graphs to be evaluated. Only used when 1 list of '
                             'graphs is to be evaluated.')
    parser.add_argument('--dir-prefix', dest='dir_prefix',
                        help='The file that stores list of graphs to be evaluated. Can be used when evaluating multiple'
                             'models on multiple datasets.')
    parser.add_argument('--graph-type', dest='graph_type',
                        help='Type of graphs / dataset.')

    parser.set_defaults(export=False, kron_dir='', test_file='',
                        dir_prefix='',
                        graph_type=args.graph_type)
    prog_args = parser.parse_args()

    # NOTE(review): prog_args.dir_prefix is parsed but unused — the input
    # directory comes from args.dir_input instead; confirm this is intended.
    # dir_prefix = prog_args.dir_prefix
    # dir_prefix = "/dfs/scratch0/jiaxuany0/"
    dir_prefix = args.dir_input

    # One log file per run, timestamped.
    time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if not os.path.isdir('logs/'):
        os.makedirs('logs/')
    logging.basicConfig(filename='logs/evaluate' + time_now + '.log', level=logging.INFO)

    if prog_args.export:
        # Mode 1: export ground-truth graphs as txt for external tools.
        if not os.path.isdir('eval_results'):
            os.makedirs('eval_results')
        if not os.path.isdir('eval_results/ground_truth'):
            os.makedirs('eval_results/ground_truth')
        out_dir = os.path.join('eval_results/ground_truth', prog_args.graph_type)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        output_prefix = os.path.join(out_dir, prog_args.graph_type)
        print('Export ground truth to prefix: ', output_prefix)

        if prog_args.graph_type == 'grid':
            # 10x10 .. 19x19 grid graphs
            graphs = []
            for i in range(10, 20):
                for j in range(10, 20):
                    graphs.append(nx.grid_2d_graph(i, j))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == 'caveman':
            # caveman graphs: 2 communities, 30..80 nodes each, 10 samples per size
            graphs = []
            for i in range(2, 3):
                for j in range(30, 81):
                    for k in range(10):
                        graphs.append(caveman_special(i, j, p_edge=0.3))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == 'citeseer':
            graphs = utils.citeseer_ego()
            utils.export_graphs_to_txt(graphs, output_prefix)
        else:
            # load from directory
            # NOTE(review): `real_graph_filename` is not defined anywhere in
            # this file; reaching this branch raises NameError — confirm the
            # intended filename before using this path.
            input_path = dir_prefix + real_graph_filename
            g_list = utils.load_graph_list(input_path)
            utils.export_graphs_to_txt(g_list, output_prefix)
    elif not prog_args.kron_dir == '':
        # Mode 2: convert Kronecker outputs into a .dat graph list.
        kron_g_list = process_kron(prog_args.kron_dir)
        fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + '.dat')
        print([g.number_of_nodes() for g in kron_g_list])
        utils.save_graph_list(kron_g_list, fname)
    elif not prog_args.test_file == '':
        # evaluate single .dat file containing list of test graphs (networkx format)
        graphs = utils.load_graph_list(prog_args.test_file)
        eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid')
    ## if you don't try kronecker, only the following part is needed
    else:
        # Mode 3 (default): evaluate all configured models on all datasets.
        if not os.path.isdir(dir_prefix + 'eval_results'):
            os.makedirs(dir_prefix + 'eval_results')
        evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/",
                   model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all,
                   args=args, overwrite=True)
|