
evaluate.py

import argparse
# logging, gmtime/strftime and networkx are used below; imported explicitly here
# rather than relying on the wildcard import at the bottom of this block
import logging
import numpy as np
import os
import re
from random import shuffle
from time import gmtime, strftime

import networkx as nx

import eval.stats
import utils
# import main.Args
from baselines.baseline_simple import *
class Args_evaluate():
    def __init__(self):
        # loop over the settings
        # self.model_name_all = ['GraphRNN_MLP','GraphRNN_RNN','Internal','Noise']
        # self.model_name_all = ['E-R', 'B-A']
        self.model_name_all = ['GraphRNN_RNN']
        # self.model_name_all = ['Baseline_DGMG']

        # list of datasets to evaluate
        # use a list of 1 element to evaluate a single dataset
        self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD']
        # self.dataset_name_all = ['citeseer_small','caveman_small']
        # self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10']
        # self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small']

        self.epoch_start = 100
        self.epoch_end = 3001
        self.epoch_step = 100


def find_nearest_idx(array, value):
    idx = (np.abs(array - value)).argmin()
    return idx
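
# A minimal illustration of find_nearest_idx (values are hypothetical):
# find_nearest_idx(np.array([10, 20, 35]), 22) returns 1, the index of the
# array entry closest to the query value.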


def extract_result_id_and_epoch(name, prefix, suffix):
    '''
    Args:
        name: filename to parse
        prefix: substring directly preceding the epoch number in the filename
        suffix: 'real_' or 'pred_'; the result id follows it
    Returns:
        A tuple of (result id, epoch number) extracted from the filename string
    '''
    pos = name.find(suffix) + len(suffix)
    end_pos = name.find('.dat')
    result_id = name[pos:end_pos]

    pos = name.find(prefix) + len(prefix)
    end_pos = name.find('_', pos)
    epochs = int(name[pos:end_pos])
    return result_id, epochs
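
# Illustration with a hypothetical filename in which the epoch directly follows the
# prefix and the result id directly follows the suffix:
# extract_result_id_and_epoch('GraphRNN_RNN_grid_500_pred_4_128.dat',
#                             prefix='GraphRNN_RNN_grid_', suffix='pred_')
# returns ('4_128', 500).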


def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every):
    real_graphs_dict = {}
    pred_graphs_dict = {}

    for fname in real_graphs_filename:
        result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'real_')
        if not epochs % eval_every == 0:
            continue
        if result_id not in real_graphs_dict:
            real_graphs_dict[result_id] = {}
        real_graphs_dict[result_id][epochs] = fname
    for fname in pred_graphs_filename:
        result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'pred_')
        if not epochs % eval_every == 0:
            continue
        if result_id not in pred_graphs_dict:
            pred_graphs_dict[result_id] = {}
        pred_graphs_dict[result_id][epochs] = fname

    for result_id in real_graphs_dict.keys():
        for epochs in sorted(real_graphs_dict[result_id]):
            real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs])
            pred_g_list = utils.load_graph_list(pred_graphs_dict[result_id][epochs])
            shuffle(real_g_list)
            shuffle(pred_g_list)
            perturbed_g_list = perturb(real_g_list, 0.05)

            # dist = eval.stats.degree_stats(real_g_list, pred_g_list)
            dist = eval.stats.clustering_stats(real_g_list, pred_g_list)
            print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist)

            # dist = eval.stats.degree_stats(real_g_list, perturbed_g_list)
            dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list)
            print('dist between real and perturbed: ', dist)

            mid = len(real_g_list) // 2
            # dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:])
            dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:])
            print('dist among real: ', dist)


def compute_basic_stats(real_g_list, target_g_list):
    dist_degree = eval.stats.degree_stats(real_g_list, target_g_list)
    dist_clustering = eval.stats.clustering_stats(real_g_list, target_g_list)
    return dist_degree, dist_clustering


def clean_graphs(graph_real, graph_pred):
    ''' Select generated graphs whose sizes are similar to those of the real graphs.
    This is usually necessary for the GraphRNN-S variant, but not for the full GraphRNN model.
    '''
    shuffle(graph_real)
    shuffle(graph_pred)

    # get graph sizes
    real_graph_len = np.array([len(graph_real[i]) for i in range(len(graph_real))])
    pred_graph_len = np.array([len(graph_pred[i]) for i in range(len(graph_pred))])

    # select pred samples
    # the number of nodes is sampled from a distribution similar to that of the training set
    pred_graph_new = []
    pred_graph_len_new = []
    for value in real_graph_len:
        pred_idx = find_nearest_idx(pred_graph_len, value)
        pred_graph_new.append(graph_pred[pred_idx])
        pred_graph_len_new.append(pred_graph_len[pred_idx])

    return graph_real, pred_graph_new
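
# Sketch of the size-matching behaviour (hypothetical sizes): if the real graphs have
# 10, 50 and 80 nodes and the predicted graphs have 12, 47, 79 and 200 nodes, the
# returned prediction list holds the 12-, 47- and 79-node graphs, one per real graph.
# Selection is with replacement, so a predicted graph may be picked more than once.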


def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
    ''' Read ground truth graphs.
    '''
    if not 'small' in dataset_name:
        hidden = 128
    else:
        hidden = 64
    # note: relies on the module-level `args` created in the __main__ block
    if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
        fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
            hidden) + '_test_' + str(0) + '.dat'
    else:
        fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
            hidden) + '_test_' + str(0) + '.dat'
    try:
        graph_test = utils.load_graph_list(fname_test, is_real=True)
    except:
        print('Not found: ' + fname_test)
        logging.warning('Not found: ' + fname_test)
        return None
    return graph_test


def eval_single_list(graphs, dir_input, dataset_name):
    ''' Evaluate a list of graphs by comparing them with the ground truth graphs in directory dir_input.
    Args:
        graphs: list of graphs to be evaluated
        dir_input: directory where the ground truth graph list is stored
        dataset_name: name of the dataset (ground truth)
    '''
    graph_test = load_ground_truth(dir_input, dataset_name)
    graph_test_len = len(graph_test)
    graph_test = graph_test[int(0.8 * graph_test_len):]  # test on a hold-out test set
    mmd_degree = eval.stats.degree_stats(graph_test, graphs)
    mmd_clustering = eval.stats.clustering_stats(graph_test, graphs)
    try:
        mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graphs)
    except:
        mmd_4orbits = -1
    print('deg: ', mmd_degree)
    print('clustering: ', mmd_clustering)
    print('orbits: ', mmd_4orbits)


def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
                     epoch_start=1000, epoch_end=3001, epoch_step=100):
    with open(fname_output, 'w+') as f:
        f.write('sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')

        # TODO: maybe refactor into a separate file/function that specifies the naming convention
        # shared by main and evaluate
        if not 'small' in dataset_name:
            hidden = 128
        else:
            hidden = 64
        # read real graphs
        if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
            fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        elif 'Baseline' in model_name:
            fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(64) + '_test_' + str(0) + '.dat'
        else:
            fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
                hidden) + '_test_' + str(0) + '.dat'
        try:
            graph_test = utils.load_graph_list(fname_test, is_real=True)
        except:
            print('Not found: ' + fname_test)
            logging.warning('Not found: ' + fname_test)
            return None

        graph_test_len = len(graph_test)
        graph_train = graph_test[0:int(0.8 * graph_test_len)]      # train split
        graph_validate = graph_test[0:int(0.2 * graph_test_len)]   # validation split
        graph_test = graph_test[int(0.8 * graph_test_len):]        # test on a hold-out test set

        graph_test_aver = 0
        for graph in graph_test:
            graph_test_aver += graph.number_of_nodes()
        graph_test_aver /= len(graph_test)
        print('test average len', graph_test_aver)

        # get performance for proposed approaches
        if 'GraphRNN' in model_name:
            # read test graphs
            for epoch in range(epoch_start, epoch_end, epoch_step):
                for sample_time in range(1, 4):
                    # get filename
                    fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
                    # load graphs
                    try:
                        graph_pred = utils.load_graph_list(fname_pred, is_real=False)  # default False
                    except:
                        print('Not found: ' + fname_pred)
                        logging.warning('Not found: ' + fname_pred)
                        continue
                    # clean graphs
                    if is_clean:
                        graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
                    else:
                        shuffle(graph_pred)
                        graph_pred = graph_pred[0:len(graph_test)]
                    print('len graph_test', len(graph_test))
                    print('len graph_validate', len(graph_validate))
                    print('len graph_pred', len(graph_pred))

                    graph_pred_aver = 0
                    for graph in graph_pred:
                        graph_pred_aver += graph.number_of_nodes()
                    graph_pred_aver /= len(graph_pred)
                    print('pred average len', graph_pred_aver)

                    # evaluate MMD test
                    mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
                    mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
                    try:
                        mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
                    except:
                        mmd_4orbits = -1
                    # evaluate MMD validate
                    mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
                    mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
                    try:
                        mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
                    except:
                        mmd_4orbits_validate = -1
                    # write results
                    f.write(str(sample_time) + ',' +
                            str(epoch) + ',' +
                            str(mmd_degree_validate) + ',' +
                            str(mmd_clustering_validate) + ',' +
                            str(mmd_4orbits_validate) + ',' +
                            str(mmd_degree) + ',' +
                            str(mmd_clustering) + ',' +
                            str(mmd_4orbits) + '\n')
                    print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

        # get internal MMD (MMD between ground truth validation and test sets)
        if model_name == 'Internal':
            mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate)
            mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
                mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                    + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')

        # get MMD between ground truth and its perturbed graphs
        if model_name == 'Noise':
            graph_validate_perturbed = perturb(graph_validate, 0.05)
            mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate_perturbed)
            mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate_perturbed)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate_perturbed)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
                mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                    + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')

        # get E-R MMD
        if model_name == 'E-R':
            graph_pred = Graph_generator_baseline(graph_train, generator='Gnp')
            # clean graphs
            if is_clean:
                graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
            print('len graph_test', len(graph_test))
            print('len graph_pred', len(graph_pred))
            mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
            mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')

        # get B-A MMD
        if model_name == 'B-A':
            graph_pred = Graph_generator_baseline(graph_train, generator='BA')
            # clean graphs
            if is_clean:
                graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
            print('len graph_test', len(graph_test))
            print('len graph_pred', len(graph_pred))
            mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
            mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
            try:
                mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
            except:
                mmd_4orbits_validate = -1
            f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
                    + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')

        # get performance for baseline approaches
        if 'Baseline' in model_name:
            # read test graphs
            for epoch in range(epoch_start, epoch_end, epoch_step):
                # get filename
                fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(
                    64) + '_pred_' + str(epoch) + '.dat'
                # load graphs
                try:
                    graph_pred = utils.load_graph_list(fname_pred, is_real=True)  # default False
                except:
                    print('Not found: ' + fname_pred)
                    logging.warning('Not found: ' + fname_pred)
                    continue
                # clean graphs
                if is_clean:
                    graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
                else:
                    shuffle(graph_pred)
                    graph_pred = graph_pred[0:len(graph_test)]
                print('len graph_test', len(graph_test))
                print('len graph_validate', len(graph_validate))
                print('len graph_pred', len(graph_pred))

                graph_pred_aver = 0
                for graph in graph_pred:
                    graph_pred_aver += graph.number_of_nodes()
                graph_pred_aver /= len(graph_pred)
                print('pred average len', graph_pred_aver)

                # evaluate MMD test
                mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
                mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
                try:
                    mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
                except:
                    mmd_4orbits = -1
                # evaluate MMD validate
                mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
                mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
                try:
                    mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
                except:
                    mmd_4orbits_validate = -1
                # write results
                f.write(str(-1) + ',' + str(epoch) + ',' + str(mmd_degree_validate) + ',' + str(
                    mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
                        + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n')
                print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)

    return True


def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True):
    ''' Evaluate the performance of a set of models on a set of datasets.
    '''
    for model_name in model_name_all:
        for dataset_name in dataset_name_all:
            # check whether the output file already exists
            fname_output = dir_output + model_name + '_' + dataset_name + '.csv'
            print('processing: ' + fname_output)
            logging.info('processing: ' + fname_output)
            if overwrite == False and os.path.isfile(fname_output):
                print(fname_output + ' exists!')
                logging.info(fname_output + ' exists!')
                continue
            evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
                             epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end,
                             epoch_step=args_evaluate.epoch_step)
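
# Hedged usage sketch (paths are placeholders; assumes the module-level Args() instance
# from the __main__ block): evaluate every stored GraphRNN_RNN checkpoint on 'grid' and
# write the per-epoch MMD table to eval_results/GraphRNN_RNN_grid.csv.
# evaluation(Args_evaluate(), dir_input='graphs/', dir_output='eval_results/',
#            model_name_all=['GraphRNN_RNN'], dataset_name_all=['grid'],
#            args=args, overwrite=True)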


def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                    eval_every, epoch_range=None, out_file_prefix=None):
    ''' Evaluate a list of predicted graphs against ground truth graphs stored in files.
    Args:
        baselines: dict mapping the name of a baseline to its list of generated graphs.
    '''
    # note: the writes below assume out_file_prefix is provided, since out_files is only
    # created inside this branch
    if out_file_prefix is not None:
        out_files = {
            'train': open(out_file_prefix + '_train.txt', 'w+'),
            'compare': open(out_file_prefix + '_compare.txt', 'w+')
        }
    out_files['train'].write('degree,clustering,orbits4\n')

    line = 'metric,real,ours,perturbed'
    for bl in baselines:
        line += ',' + bl
    line += '\n'
    out_files['compare'].write(line)

    results = {
        'deg': {
            'real': 0,
            'ours': 100,  # take min over all training epochs
            'perturbed': 0,
            'kron': 0},
        'clustering': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0},
        'orbits4': {
            'real': 0,
            'ours': 100,
            'perturbed': 0,
            'kron': 0}
    }
    num_evals = len(pred_graphs_filename)
    if epoch_range is None:
        epoch_range = [i * eval_every for i in range(num_evals)]
    for i in range(num_evals):
        real_g_list = utils.load_graph_list(real_graph_filename)
        # pred_g_list = utils.load_graph_list(pred_graphs_filename[i])

        # contains all predicted graphs
        pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i])
        if len(real_g_list) > 200:
            real_g_list = real_g_list[0:200]

        shuffle(real_g_list)
        shuffle(pred_g_list_raw)

        # get graph sizes
        real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))])
        pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))])

        # get perturbed real graphs
        # perturbed_g_list_001 = perturb(real_g_list, 0.01)
        perturbed_g_list_005 = perturb(real_g_list, 0.05)
        # perturbed_g_list_010 = perturb(real_g_list, 0.10)

        # select pred samples
        # the number of nodes is sampled from a distribution similar to that of the training set
        pred_g_list = []
        pred_g_len_list = []
        for value in real_g_len_list:
            pred_idx = find_nearest_idx(pred_g_len_list_raw, value)
            pred_g_list.append(pred_g_list_raw[pred_idx])
            pred_g_len_list.append(pred_g_len_list_raw[pred_idx])
            # delete the selected graph so each predicted graph is used at most once
            pred_g_len_list_raw = np.delete(pred_g_len_list_raw, pred_idx)
            del pred_g_list_raw[pred_idx]
            if len(pred_g_list) == len(real_g_list):
                break
        # pred_g_len_list = np.array(pred_g_len_list)

        print('################## epoch {} ##################'.format(epoch_range[i]))

        # info about graph size
        print('real average nodes',
              sum([real_g_list[i].number_of_nodes() for i in range(len(real_g_list))]) / len(real_g_list))
        print('pred average nodes',
              sum([pred_g_list[i].number_of_nodes() for i in range(len(pred_g_list))]) / len(pred_g_list))
        print('num of real graphs', len(real_g_list))
        print('num of pred graphs', len(pred_g_list))

        # ========================================
        # Evaluation
        # ========================================
        mid = len(real_g_list) // 2
        dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:])
        # dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:])
        print('degree dist among real: ', dist_degree)
        print('clustering dist among real: ', dist_clustering)
        # print('4 cycle dist among real: ', dist_4cycle)
        print('orbits dist among real: ', dist_4orbits)
        results['deg']['real'] += dist_degree
        results['clustering']['real'] += dist_clustering
        results['orbits4']['real'] += dist_4orbits

        dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list)
        print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits)
        results['deg']['ours'] = min(dist_degree, results['deg']['ours'])
        results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours'])
        results['orbits4']['ours'] = min(dist_4orbits, results['orbits4']['ours'])

        # performance at training time
        out_files['train'].write(str(dist_degree) + ',')
        out_files['train'].write(str(dist_clustering) + ',')
        out_files['train'].write(str(dist_4orbits) + ',')

        dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005)
        # dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
        dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005)
        print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree)
        print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering)
        # print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
        print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits)
        results['deg']['perturbed'] += dist_degree
        results['clustering']['perturbed'] += dist_clustering
        results['orbits4']['perturbed'] += dist_4orbits

        if i == 0:
            # Baselines
            for baseline in baselines:
                dist_degree, dist_clustering = compute_basic_stats(real_g_list, baselines[baseline])
                dist_4orbits = eval.stats.orbit_stats_all(real_g_list, baselines[baseline])
                results['deg'][baseline] = dist_degree
                results['clustering'][baseline] = dist_clustering
                results['orbits4'][baseline] = dist_4orbits
                print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
                      ', orbits4=', dist_4orbits)

        out_files['train'].write('\n')

    for metric, methods in results.items():
        methods['real'] /= num_evals
        methods['perturbed'] /= num_evals

    # write results
    for metric, methods in results.items():
        line = metric + ',' + \
               str(methods['real']) + ',' + \
               str(methods['ours']) + ',' + \
               str(methods['perturbed'])
        for baseline in baselines:
            line += ',' + str(methods[baseline])
        line += '\n'
        out_files['compare'].write(line)

    for _, out_f in out_files.items():
        out_f.close()


def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None,
                     sample_time=2, baselines={}):
    if args is None:
        real_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + '.*real.*\.dat', f)]
        pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
                                if re.match(prefix + '.*pred.*\.dat', f)]
        eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200)
    else:
        # # for vanilla graphrnn
        # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
        # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
        real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat'
        # for the proposed model
        end_epoch = 3001
        epoch_range = range(eval_every, end_epoch, eval_every)
        pred_graphs_filename = [datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                                for epoch in epoch_range]
        # for the baseline model
        # pred_graphs_filename = [datadir+args.fname_baseline+'.dat']

        # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
        #     args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
        # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
        #     str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
        #     args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
        eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
                        epoch_range=epoch_range,
                        eval_every=eval_every,
                        out_file_prefix=out_file_prefix)
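
# Hedged usage sketch (directory and prefix are placeholders): when called without an args
# object, eval_performance scans datadir for '<prefix>*real*.dat' / '<prefix>*pred*.dat'
# files and compares them via eval_list; with an args object it builds per-epoch prediction
# filenames from args.graph_save_path / args.fname_pred and runs eval_list_fname instead.
# eval_performance('graphs/', prefix='GraphRNN_RNN_grid', eval_every=200)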


def process_kron(kron_dir):
    txt_files = []
    for f in os.listdir(kron_dir):
        filename = os.fsdecode(f)
        if filename.endswith('.txt'):
            txt_files.append(filename)
        elif filename.endswith('.dat'):
            return utils.load_graph_list(os.path.join(kron_dir, filename))
    G_list = []
    for filename in txt_files:
        G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename)))
    return G_list


if __name__ == '__main__':
    args = Args()
    args_evaluate = Args_evaluate()

    parser = argparse.ArgumentParser(description='Evaluation arguments.')
    feature_parser = parser.add_mutually_exclusive_group(required=False)
    feature_parser.add_argument('--export-real', dest='export', action='store_true')
    feature_parser.add_argument('--no-export-real', dest='export', action='store_false')
    feature_parser.add_argument('--kron-dir', dest='kron_dir',
                                help='Directory where graphs generated by the Kronecker method are stored.')
    parser.add_argument('--testfile', dest='test_file',
                        help='The file that stores the list of graphs to be evaluated. Only used when a single '
                             'list of graphs is to be evaluated.')
    parser.add_argument('--dir-prefix', dest='dir_prefix',
                        help='The directory prefix under which the graph lists to be evaluated are stored. '
                             'Can be used when evaluating multiple models on multiple datasets.')
    parser.add_argument('--graph-type', dest='graph_type',
                        help='Type of graphs / dataset.')

    parser.set_defaults(export=False, kron_dir='', test_file='',
                        dir_prefix='',
                        graph_type=args.graph_type)
    prog_args = parser.parse_args()

    # dir_prefix = prog_args.dir_prefix
    # dir_prefix = "/dfs/scratch0/jiaxuany0/"
    dir_prefix = args.dir_input

    time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if not os.path.isdir('logs/'):
        os.makedirs('logs/')
    logging.basicConfig(filename='logs/evaluate' + time_now + '.log', level=logging.INFO)

    if prog_args.export:
        if not os.path.isdir('eval_results'):
            os.makedirs('eval_results')
        if not os.path.isdir('eval_results/ground_truth'):
            os.makedirs('eval_results/ground_truth')
        out_dir = os.path.join('eval_results/ground_truth', prog_args.graph_type)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        output_prefix = os.path.join(out_dir, prog_args.graph_type)
        print('Export ground truth to prefix: ', output_prefix)

        if prog_args.graph_type == 'grid':
            graphs = []
            for i in range(10, 20):
                for j in range(10, 20):
                    graphs.append(nx.grid_2d_graph(i, j))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == 'caveman':
            graphs = []
            for i in range(2, 3):
                for j in range(30, 81):
                    for k in range(10):
                        graphs.append(caveman_special(i, j, p_edge=0.3))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == 'citeseer':
            graphs = utils.citeseer_ego()
            utils.export_graphs_to_txt(graphs, output_prefix)
        else:
            # load from a directory
            # NOTE: real_graph_filename is not defined in this scope; set it to the
            # ground-truth .dat filename before using this branch.
            input_path = dir_prefix + real_graph_filename
            g_list = utils.load_graph_list(input_path)
            utils.export_graphs_to_txt(g_list, output_prefix)
    elif not prog_args.kron_dir == '':
        kron_g_list = process_kron(prog_args.kron_dir)
        fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + '.dat')
        print([g.number_of_nodes() for g in kron_g_list])
        utils.save_graph_list(kron_g_list, fname)
    elif not prog_args.test_file == '':
        # evaluate a single .dat file containing a list of test graphs (networkx format)
        graphs = utils.load_graph_list(prog_args.test_file)
        eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid')
    else:
        # if you are not evaluating the Kronecker baseline, only this branch is needed
        if not os.path.isdir(dir_prefix + 'eval_results'):
            os.makedirs(dir_prefix + 'eval_results')
        evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/",
                   model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all,
                   args=args, overwrite=True)
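
# A hedged summary of the command-line modes defined above (paths are placeholders):
#   python evaluate.py                                                        # evaluate the models/datasets listed in Args_evaluate
#   python evaluate.py --export-real --graph-type grid                        # export ground-truth grid graphs to txt
#   python evaluate.py --kron-dir eval_results/kron/grid --graph-type grid    # convert Kronecker output to a .dat graph list
#   python evaluate.py --testfile graphs/test_list.dat                        # evaluate a single list of graphs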