You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

evaluate.py 32KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. import argparse
  2. import logging
  3. import os
  4. from time import strftime, gmtime
  5. import eval.stats
  6. import utils
  7. # import main.Args
  8. from args import Args
  9. from baselines.baseline_simple import *
  10. class Args_evaluate():
  11. def __init__(self):
  12. # loop over the settings
  13. # self.model_name_all = ['GraphRNN_MLP','GraphRNN_RNN','Internal','Noise']
  14. # self.model_name_all = ['E-R', 'B-A']
  15. self.model_name_all = ['GraphRNN_RNN']
  16. # self.model_name_all = ['Baseline_DGMG']
  17. # list of dataset to evaluate
  18. # use a list of 1 element to evaluate a single dataset
  19. # self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD']
  20. self.dataset_name_all = ['caveman_small']
  21. # self.dataset_name_all = ['citeseer_small','caveman_small']
  22. # self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10']
  23. # self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small']
  24. self.epoch_start = 10
  25. self.epoch_end = 101
  26. self.epoch_step = 10
  27. def find_nearest_idx(array, value):
  28. idx = (np.abs(array - value)).argmin()
  29. return idx
  30. def extract_result_id_and_epoch(name, prefix, suffix):
  31. '''
  32. Args:
  33. eval_every: the number of epochs between consecutive evaluations
  34. suffix: real_ or pred_
  35. Returns:
  36. A tuple of (id, epoch number) extracted from the filename string
  37. '''
  38. pos = name.find(suffix) + len(suffix)
  39. end_pos = name.find('.dat')
  40. result_id = name[pos:end_pos]
  41. pos = name.find(prefix) + len(prefix)
  42. end_pos = name.find('_', pos)
  43. epochs = int(name[pos:end_pos])
  44. return result_id, epochs
  45. def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every):
  46. real_graphs_dict = {}
  47. pred_graphs_dict = {}
  48. for fname in real_graphs_filename:
  49. result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'real_')
  50. if not epochs % eval_every == 0:
  51. continue
  52. if result_id not in real_graphs_dict:
  53. real_graphs_dict[result_id] = {}
  54. real_graphs_dict[result_id][epochs] = fname
  55. for fname in pred_graphs_filename:
  56. result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'pred_')
  57. if not epochs % eval_every == 0:
  58. continue
  59. if result_id not in pred_graphs_dict:
  60. pred_graphs_dict[result_id] = {}
  61. pred_graphs_dict[result_id][epochs] = fname
  62. for result_id in real_graphs_dict.keys():
  63. for epochs in sorted(real_graphs_dict[result_id]):
  64. real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs])
  65. pred_g_list = utils.load_graph_list(pred_graphs_dict[result_id][epochs])
  66. shuffle(real_g_list)
  67. shuffle(pred_g_list)
  68. perturbed_g_list = perturb(real_g_list, 0.05)
  69. # dist = eval.stats.degree_stats(real_g_list, pred_g_list)
  70. dist = eval.stats.clustering_stats(real_g_list, pred_g_list)
  71. print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist)
  72. # dist = eval.stats.degree_stats(real_g_list, perturbed_g_list)
  73. dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list)
  74. print('dist between real and perturbed: ', dist)
  75. mid = len(real_g_list) // 2
  76. # dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:])
  77. dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:])
  78. print('dist among real: ', dist)
  79. def compute_basic_stats(real_g_list, target_g_list):
  80. dist_degree = eval.stats.degree_stats(real_g_list, target_g_list)
  81. dist_clustering = eval.stats.clustering_stats(real_g_list, target_g_list)
  82. return dist_degree, dist_clustering
  83. def clean_graphs(graph_real, graph_pred):
  84. ''' Selecting graphs generated that have the similar sizes.
  85. It is usually necessary for GraphRNN-S version, but not the full GraphRNN model.
  86. '''
  87. shuffle(graph_real)
  88. shuffle(graph_pred)
  89. # get length
  90. real_graph_len = np.array([len(graph_real[i]) for i in range(len(graph_real))])
  91. pred_graph_len = np.array([len(graph_pred[i]) for i in range(len(graph_pred))])
  92. # select pred samples
  93. # The number of nodes are sampled from the similar distribution as the training set
  94. pred_graph_new = []
  95. pred_graph_len_new = []
  96. for value in real_graph_len:
  97. pred_idx = find_nearest_idx(pred_graph_len, value)
  98. pred_graph_new.append(graph_pred[pred_idx])
  99. pred_graph_len_new.append(pred_graph_len[pred_idx])
  100. return graph_real, pred_graph_new
  101. def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'):
  102. ''' Read ground truth graphs.
  103. '''
  104. if not 'small' in dataset_name:
  105. hidden = 128
  106. else:
  107. hidden = 64
  108. if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
  109. fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
  110. hidden) + '_test_' + str(0) + '.dat'
  111. else:
  112. fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
  113. hidden) + '_test_' + str(0) + '.dat'
  114. try:
  115. graph_test = utils.load_graph_list(fname_test, is_real=True)
  116. except:
  117. print('Not found: ' + fname_test)
  118. logging.warning('Not found: ' + fname_test)
  119. return None
  120. return graph_test
  121. def eval_single_list(graphs, dir_input, dataset_name):
  122. ''' Evaluate a list of graphs by comparing with graphs in directory dir_input.
  123. Args:
  124. dir_input: directory where ground truth graph list is stored
  125. dataset_name: name of the dataset (ground truth)
  126. '''
  127. graph_test = load_ground_truth(dir_input, dataset_name)
  128. graph_test_len = len(graph_test)
  129. graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set
  130. mmd_degree = eval.stats.degree_stats(graph_test, graphs)
  131. mmd_clustering = eval.stats.clustering_stats(graph_test, graphs)
  132. try:
  133. mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graphs)
  134. except:
  135. mmd_4orbits = -1
  136. print('deg: ', mmd_degree)
  137. print('clustering: ', mmd_clustering)
  138. print('orbits: ', mmd_4orbits)
  139. def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,
  140. epoch_end=3001, epoch_step=100):
  141. with open(fname_output, 'w+') as f:
  142. f.write(
  143. 'sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n')
  144. # TODO: Maybe refactor into a separate file/function that specifies THE naming convention
  145. # across main and evaluate
  146. if not 'small' in dataset_name:
  147. hidden = 128
  148. else:
  149. hidden = 64
  150. # read real graph
  151. if model_name == 'Internal' or model_name == 'Noise' or model_name == 'B-A' or model_name == 'E-R':
  152. fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
  153. hidden) + '_test_' + str(0) + '.dat'
  154. elif 'Baseline' in model_name:
  155. fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(64) + '_test_' + str(0) + '.dat'
  156. else:
  157. fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
  158. hidden) + '_test_' + str(0) + '.dat'
  159. try:
  160. graph_test = utils.load_graph_list(fname_test, is_real=True)
  161. except:
  162. print('Not found: ' + fname_test)
  163. logging.warning('Not found: ' + fname_test)
  164. return None
  165. graph_test_len = len(graph_test)
  166. graph_train = graph_test[0:int(0.8 * graph_test_len)] # train
  167. graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate
  168. graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set
  169. graph_test_aver = 0
  170. for graph in graph_test:
  171. graph_test_aver += graph.number_of_nodes()
  172. graph_test_aver /= len(graph_test)
  173. print('test average len', graph_test_aver)
  174. # get performance for proposed approaches
  175. if 'GraphRNN' in model_name:
  176. # read test graph
  177. for epoch in range(epoch_start, epoch_end, epoch_step):
  178. for sample_time in range(1, 4):
  179. # get filename
  180. fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(
  181. hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat'
  182. # load graphs
  183. try:
  184. graph_pred = utils.load_graph_list(fname_pred, is_real=False) # default False
  185. except:
  186. print('Not found: ' + fname_pred)
  187. logging.warning('Not found: ' + fname_pred)
  188. continue
  189. # clean graphs
  190. if is_clean:
  191. graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
  192. else:
  193. shuffle(graph_pred)
  194. graph_pred = graph_pred[0:len(graph_test)]
  195. print('len graph_test', len(graph_test))
  196. print('len graph_validate', len(graph_validate))
  197. print('len graph_pred', len(graph_pred))
  198. graph_pred_aver = 0
  199. for graph in graph_pred:
  200. graph_pred_aver += graph.number_of_nodes()
  201. graph_pred_aver /= len(graph_pred)
  202. print('pred average len', graph_pred_aver)
  203. # evaluate MMD test
  204. mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
  205. mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
  206. try:
  207. mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
  208. except:
  209. mmd_4orbits = -1
  210. # evaluate MMD validate
  211. mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
  212. mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
  213. try:
  214. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
  215. except:
  216. mmd_4orbits_validate = -1
  217. # write results
  218. f.write(str(sample_time) + ',' +
  219. str(epoch) + ',' +
  220. str(mmd_degree_validate) + ',' +
  221. str(mmd_clustering_validate) + ',' +
  222. str(mmd_4orbits_validate) + ',' +
  223. str(mmd_degree) + ',' +
  224. str(mmd_clustering) + ',' +
  225. str(mmd_4orbits) + '\n')
  226. print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)
  227. # get internal MMD (MMD between ground truth validation and test sets)
  228. if model_name == 'Internal':
  229. mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate)
  230. mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate)
  231. try:
  232. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate)
  233. except:
  234. mmd_4orbits_validate = -1
  235. f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
  236. mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
  237. + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')
  238. # get MMD between ground truth and its perturbed graphs
  239. if model_name == 'Noise':
  240. graph_validate_perturbed = perturb(graph_validate, 0.05)
  241. mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate_perturbed)
  242. mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate_perturbed)
  243. try:
  244. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate_perturbed)
  245. except:
  246. mmd_4orbits_validate = -1
  247. f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str(
  248. mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
  249. + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n')
  250. # get E-R MMD
  251. if model_name == 'E-R':
  252. graph_pred = Graph_generator_baseline(graph_train, generator='Gnp')
  253. # clean graphs
  254. if is_clean:
  255. graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
  256. print('len graph_test', len(graph_test))
  257. print('len graph_pred', len(graph_pred))
  258. mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
  259. mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
  260. try:
  261. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
  262. except:
  263. mmd_4orbits_validate = -1
  264. f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
  265. + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')
  266. # get B-A MMD
  267. if model_name == 'B-A':
  268. graph_pred = Graph_generator_baseline(graph_train, generator='BA')
  269. # clean graphs
  270. if is_clean:
  271. graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
  272. print('len graph_test', len(graph_test))
  273. print('len graph_pred', len(graph_pred))
  274. mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
  275. mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
  276. try:
  277. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred)
  278. except:
  279. mmd_4orbits_validate = -1
  280. f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1)
  281. + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n')
  282. # get performance for baseline approaches
  283. if 'Baseline' in model_name:
  284. # read test graph
  285. for epoch in range(epoch_start, epoch_end, epoch_step):
  286. # get filename
  287. fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(
  288. 64) + '_pred_' + str(epoch) + '.dat'
  289. # load graphs
  290. try:
  291. graph_pred = utils.load_graph_list(fname_pred, is_real=True) # default False
  292. except:
  293. print('Not found: ' + fname_pred)
  294. logging.warning('Not found: ' + fname_pred)
  295. continue
  296. # clean graphs
  297. if is_clean:
  298. graph_test, graph_pred = clean_graphs(graph_test, graph_pred)
  299. else:
  300. shuffle(graph_pred)
  301. graph_pred = graph_pred[0:len(graph_test)]
  302. print('len graph_test', len(graph_test))
  303. print('len graph_validate', len(graph_validate))
  304. print('len graph_pred', len(graph_pred))
  305. graph_pred_aver = 0
  306. for graph in graph_pred:
  307. graph_pred_aver += graph.number_of_nodes()
  308. graph_pred_aver /= len(graph_pred)
  309. print('pred average len', graph_pred_aver)
  310. # evaluate MMD test
  311. mmd_degree = eval.stats.degree_stats(graph_test, graph_pred)
  312. mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred)
  313. try:
  314. mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred)
  315. except:
  316. mmd_4orbits = -1
  317. # evaluate MMD validate
  318. mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred)
  319. mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred)
  320. try:
  321. mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred)
  322. except:
  323. mmd_4orbits_validate = -1
  324. # write results
  325. f.write(str(-1) + ',' + str(epoch) + ',' + str(mmd_degree_validate) + ',' + str(
  326. mmd_clustering_validate) + ',' + str(mmd_4orbits_validate)
  327. + ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n')
  328. print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits)
  329. return True
  330. def evaluation(args_evaluate, dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite=True):
  331. ''' Evaluate the performance of a set of models on a set of datasets.
  332. '''
  333. for model_name in model_name_all:
  334. for dataset_name in dataset_name_all:
  335. # check output exist
  336. fname_output = dir_output + model_name + '_' + dataset_name + '.csv'
  337. print('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv')
  338. logging.info('processing: ' + dir_output + model_name + '_' + dataset_name + '.csv')
  339. if overwrite == False and os.path.isfile(fname_output):
  340. print(dir_output + model_name + '_' + dataset_name + '.csv exists!')
  341. logging.info(dir_output + model_name + '_' + dataset_name + '.csv exists!')
  342. continue
  343. evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True,
  344. epoch_start=args_evaluate.epoch_start, epoch_end=args_evaluate.epoch_end,
  345. epoch_step=args_evaluate.epoch_step)
  346. def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
  347. eval_every, epoch_range=None, out_file_prefix=None):
  348. ''' Evaluate list of predicted graphs compared to ground truth, stored in files.
  349. Args:
  350. baselines: dict mapping name of the baseline to list of generated graphs.
  351. '''
  352. if out_file_prefix is not None:
  353. out_files = {
  354. 'train': open(out_file_prefix + '_train.txt', 'w+'),
  355. 'compare': open(out_file_prefix + '_compare.txt', 'w+')
  356. }
  357. out_files['train'].write('degree,clustering,orbits4\n')
  358. line = 'metric,real,ours,perturbed'
  359. for bl in baselines:
  360. line += ',' + bl
  361. line += '\n'
  362. out_files['compare'].write(line)
  363. results = {
  364. 'deg': {
  365. 'real': 0,
  366. 'ours': 100, # take min over all training epochs
  367. 'perturbed': 0,
  368. 'kron': 0},
  369. 'clustering': {
  370. 'real': 0,
  371. 'ours': 100,
  372. 'perturbed': 0,
  373. 'kron': 0},
  374. 'orbits4': {
  375. 'real': 0,
  376. 'ours': 100,
  377. 'perturbed': 0,
  378. 'kron': 0}
  379. }
  380. num_evals = len(pred_graphs_filename)
  381. if epoch_range is None:
  382. epoch_range = [i * eval_every for i in range(num_evals)]
  383. for i in range(num_evals):
  384. real_g_list = utils.load_graph_list(real_graph_filename)
  385. # pred_g_list = utils.load_graph_list(pred_graphs_filename[i])
  386. # contains all predicted G
  387. pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i])
  388. if len(real_g_list) > 200:
  389. real_g_list = real_g_list[0:200]
  390. shuffle(real_g_list)
  391. shuffle(pred_g_list_raw)
  392. # get length
  393. real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))])
  394. pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))])
  395. # get perturb real
  396. # perturbed_g_list_001 = perturb(real_g_list, 0.01)
  397. perturbed_g_list_005 = perturb(real_g_list, 0.05)
  398. # perturbed_g_list_010 = perturb(real_g_list, 0.10)
  399. # select pred samples
  400. # The number of nodes are sampled from the similar distribution as the training set
  401. pred_g_list = []
  402. pred_g_len_list = []
  403. for value in real_g_len_list:
  404. pred_idx = find_nearest_idx(pred_g_len_list_raw, value)
  405. pred_g_list.append(pred_g_list_raw[pred_idx])
  406. pred_g_len_list.append(pred_g_len_list_raw[pred_idx])
  407. # delete
  408. pred_g_len_list_raw = np.delete(pred_g_len_list_raw, pred_idx)
  409. del pred_g_list_raw[pred_idx]
  410. if len(pred_g_list) == len(real_g_list):
  411. break
  412. # pred_g_len_list = np.array(pred_g_len_list)
  413. print('################## epoch {} ##################'.format(epoch_range[i]))
  414. # info about graph size
  415. print('real average nodes',
  416. sum([real_g_list[i].number_of_nodes() for i in range(len(real_g_list))]) / len(real_g_list))
  417. print('pred average nodes',
  418. sum([pred_g_list[i].number_of_nodes() for i in range(len(pred_g_list))]) / len(pred_g_list))
  419. print('num of real graphs', len(real_g_list))
  420. print('num of pred graphs', len(pred_g_list))
  421. # ========================================
  422. # Evaluation
  423. # ========================================
  424. mid = len(real_g_list) // 2
  425. dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:])
  426. # dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:])
  427. dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:])
  428. print('degree dist among real: ', dist_degree)
  429. print('clustering dist among real: ', dist_clustering)
  430. # print('4 cycle dist among real: ', dist_4cycle)
  431. print('orbits dist among real: ', dist_4orbits)
  432. results['deg']['real'] += dist_degree
  433. results['clustering']['real'] += dist_clustering
  434. results['orbits4']['real'] += dist_4orbits
  435. dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list)
  436. # dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list)
  437. dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list)
  438. print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree)
  439. print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering)
  440. # print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle)
  441. print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits)
  442. results['deg']['ours'] = min(dist_degree, results['deg']['ours'])
  443. results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours'])
  444. results['orbits4']['ours'] = min(dist_4orbits, results['orbits4']['ours'])
  445. # performance at training time
  446. out_files['train'].write(str(dist_degree) + ',')
  447. out_files['train'].write(str(dist_clustering) + ',')
  448. out_files['train'].write(str(dist_4orbits) + ',')
  449. dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005)
  450. # dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005)
  451. dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005)
  452. print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree)
  453. print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering)
  454. # print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle)
  455. print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits)
  456. results['deg']['perturbed'] += dist_degree
  457. results['clustering']['perturbed'] += dist_clustering
  458. results['orbits4']['perturbed'] += dist_4orbits
  459. if i == 0:
  460. # Baselines
  461. for baseline in baselines:
  462. dist_degree, dist_clustering = compute_basic_stats(real_g_list, baselines[baseline])
  463. dist_4orbits = eval.stats.orbit_stats_all(real_g_list, baselines[baseline])
  464. results['deg'][baseline] = dist_degree
  465. results['clustering'][baseline] = dist_clustering
  466. results['orbits4'][baseline] = dist_4orbits
  467. print('Kron: deg=', dist_degree, ', clustering=', dist_clustering,
  468. ', orbits4=', dist_4orbits)
  469. out_files['train'].write('\n')
  470. for metric, methods in results.items():
  471. methods['real'] /= num_evals
  472. methods['perturbed'] /= num_evals
  473. # Write results
  474. for metric, methods in results.items():
  475. line = metric + ',' + \
  476. str(methods['real']) + ',' + \
  477. str(methods['ours']) + ',' + \
  478. str(methods['perturbed'])
  479. for baseline in baselines:
  480. line += ',' + str(methods[baseline])
  481. line += '\n'
  482. out_files['compare'].write(line)
  483. for _, out_f in out_files.items():
  484. out_f.close()
  485. def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None,
  486. sample_time=2, baselines={}):
  487. if args is None:
  488. real_graphs_filename = [datadir + f for f in os.listdir(datadir)
  489. if re.match(prefix + '.*real.*\.dat', f)]
  490. pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
  491. if re.match(prefix + '.*pred.*\.dat', f)]
  492. eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200)
  493. else:
  494. # # for vanilla graphrnn
  495. # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  496. # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
  497. # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  498. # str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)]
  499. real_graph_filename = datadir + args.graph_save_path + args.fname_test + '0.dat'
  500. # for proposed model
  501. end_epoch = 3001
  502. epoch_range = range(eval_every, end_epoch, eval_every)
  503. pred_graphs_filename = [
  504. datadir + args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
  505. for epoch in epoch_range]
  506. # for baseline model
  507. # pred_graphs_filename = [datadir+args.fname_baseline+'.dat']
  508. # real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  509. # str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
  510. # args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
  511. # pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  512. # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(
  513. # args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)]
  514. eval_list_fname(real_graph_filename, pred_graphs_filename, baselines,
  515. epoch_range=epoch_range,
  516. eval_every=eval_every,
  517. out_file_prefix=out_file_prefix)
  518. def process_kron(kron_dir):
  519. txt_files = []
  520. for f in os.listdir(kron_dir):
  521. filename = os.fsdecode(f)
  522. if filename.endswith('.txt'):
  523. txt_files.append(filename)
  524. elif filename.endswith('.dat'):
  525. return utils.load_graph_list(os.path.join(kron_dir, filename))
  526. G_list = []
  527. for filename in txt_files:
  528. G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename)))
  529. return G_list
  530. if __name__ == '__main__':
  531. args = Args()
  532. args_evaluate = Args_evaluate()
  533. parser = argparse.ArgumentParser(description='Evaluation arguments.')
  534. feature_parser = parser.add_mutually_exclusive_group(required=False)
  535. feature_parser.add_argument('--export-real', dest='export', action='store_true')
  536. feature_parser.add_argument('--no-export-real', dest='export', action='store_false')
  537. feature_parser.add_argument('--kron-dir', dest='kron_dir',
  538. help='Directory where graphs generated by kronecker method is stored.')
  539. parser.add_argument('--testfile', dest='test_file',
  540. help='The file that stores list of graphs to be evaluated. Only used when 1 list of '
  541. 'graphs is to be evaluated.')
  542. parser.add_argument('--dir-prefix', dest='dir_prefix',
  543. help='The file that stores list of graphs to be evaluated. Can be used when evaluating multiple'
  544. 'models on multiple datasets.')
  545. parser.add_argument('--graph-type', dest='graph_type',
  546. help='Type of graphs / dataset.')
  547. parser.set_defaults(export=False, kron_dir='', test_file='',
  548. dir_prefix='',
  549. graph_type=args.graph_type)
  550. prog_args = parser.parse_args()
  551. # dir_prefix = prog_args.dir_prefix
  552. # dir_prefix = "/dfs/scratch0/jiaxuany0/"
  553. dir_prefix = args.dir_input
  554. time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
  555. if not os.path.isdir('logs/'):
  556. os.makedirs('logs/')
  557. logging.basicConfig(filename='logs/evaluate' + time_now + '.log', level=logging.INFO)
  558. if prog_args.export:
  559. if not os.path.isdir('eval_results'):
  560. os.makedirs('eval_results')
  561. if not os.path.isdir('eval_results/ground_truth'):
  562. os.makedirs('eval_results/ground_truth')
  563. out_dir = os.path.join('eval_results/ground_truth', prog_args.graph_type)
  564. if not os.path.isdir(out_dir):
  565. os.makedirs(out_dir)
  566. output_prefix = os.path.join(out_dir, prog_args.graph_type)
  567. print('Export ground truth to prefix: ', output_prefix)
  568. if prog_args.graph_type == 'grid':
  569. graphs = []
  570. for i in range(10, 20):
  571. for j in range(10, 20):
  572. graphs.append(nx.grid_2d_graph(i, j))
  573. utils.export_graphs_to_txt(graphs, output_prefix)
  574. elif prog_args.graph_type == 'caveman':
  575. graphs = []
  576. for i in range(2, 3):
  577. for j in range(30, 81):
  578. for k in range(10):
  579. graphs.append(caveman_special(i, j, p_edge=0.3))
  580. utils.export_graphs_to_txt(graphs, output_prefix)
  581. elif prog_args.graph_type == 'citeseer':
  582. graphs = utils.citeseer_ego()
  583. utils.export_graphs_to_txt(graphs, output_prefix)
  584. else:
  585. # load from directory
  586. input_path = dir_prefix + real_graph_filename
  587. g_list = utils.load_graph_list(input_path)
  588. utils.export_graphs_to_txt(g_list, output_prefix)
  589. elif not prog_args.kron_dir == '':
  590. kron_g_list = process_kron(prog_args.kron_dir)
  591. fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + '.dat')
  592. print([g.number_of_nodes() for g in kron_g_list])
  593. utils.save_graph_list(kron_g_list, fname)
  594. elif not prog_args.test_file == '':
  595. # evaluate single .dat file containing list of test graphs (networkx format)
  596. graphs = utils.load_graph_list(prog_args.test_file)
  597. eval_single_list(graphs, dir_input=dir_prefix + 'graphs/', dataset_name='grid')
  598. ## if you don't try kronecker, only the following part is needed
  599. else:
  600. if not os.path.isdir(dir_prefix + 'eval_results'):
  601. os.makedirs(dir_prefix + 'eval_results')
  602. evaluation(args_evaluate, dir_input=dir_prefix + "graphs/", dir_output=dir_prefix + "eval_results/",
  603. model_name_all=args_evaluate.model_name_all, dataset_name_all=args_evaluate.dataset_name_all,
  604. args=args, overwrite=True)