stats.py
import concurrent.futures
from datetime import datetime
from functools import partial
import numpy as np
import networkx as nx
import os
import pickle as pkl
import subprocess as sp
import time

import eval.mmd as mmd

PRINT_TIME = False
def degree_worker(G):
    return np.array(nx.degree_histogram(G))

def add_tensor(x, y):
    # pad the shorter histogram with zeros so both share the same support
    support_size = max(len(x), len(y))
    if len(x) < len(y):
        x = np.hstack((x, [0.0] * (support_size - len(x))))
    elif len(y) < len(x):
        y = np.hstack((y, [0.0] * (support_size - len(y))))
    return x + y
def degree_stats(graph_ref_list, graph_pred_list, is_parallel=False):
    '''Compute the MMD distance between the degree distributions of two unordered sets of graphs.
    Args:
        graph_ref_list, graph_pred_list: two lists of networkx graphs to be evaluated
    '''
    sample_ref = []
    sample_pred = []
    # in case an empty graph is generated
    graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            for deg_hist in executor.map(degree_worker, graph_ref_list):
                sample_ref.append(deg_hist)
        with concurrent.futures.ProcessPoolExecutor() as executor:
            for deg_hist in executor.map(degree_worker, graph_pred_list_remove_empty):
                sample_pred.append(deg_hist)
    else:
        for i in range(len(graph_ref_list)):
            degree_temp = np.array(nx.degree_histogram(graph_ref_list[i]))
            sample_ref.append(degree_temp)
        for i in range(len(graph_pred_list_remove_empty)):
            degree_temp = np.array(nx.degree_histogram(graph_pred_list_remove_empty[i]))
            sample_pred.append(degree_temp)
    print(len(sample_ref), len(sample_pred))
    mmd_dist = mmd.compute_mmd(sample_ref, sample_pred, kernel=mmd.gaussian_emd)
    elapsed = datetime.now() - prev
    if PRINT_TIME:
        print('Time computing degree mmd: ', elapsed)
    return mmd_dist
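# A hedged usage sketch of degree_stats (assumptions: the Barabasi-Albert and
# Erdos-Renyi generators below are only stand-ins for a real reference set and a
# generated set; any lists of networkx graphs work the same way). Wrapped in a helper
# so importing this module does not trigger the computation.
def _example_degree_stats():
    graphs_ref = [nx.barabasi_albert_graph(50, 2, seed=i) for i in range(8)]
    graphs_pred = [nx.erdos_renyi_graph(50, 0.08, seed=i) for i in range(8)]
    # a smaller MMD value means the two degree distributions are closer
    return degree_stats(graphs_ref, graphs_pred, is_parallel=False)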
def clustering_worker(param):
    G, bins = param
    clustering_coeffs_list = list(nx.clustering(G).values())
    hist, _ = np.histogram(
        clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
    return hist
def clustering_stats(graph_ref_list, graph_pred_list, bins=100, is_parallel=True):
    sample_ref = []
    sample_pred = []
    graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            for clustering_hist in executor.map(clustering_worker,
                                                [(G, bins) for G in graph_ref_list]):
                sample_ref.append(clustering_hist)
        with concurrent.futures.ProcessPoolExecutor() as executor:
            for clustering_hist in executor.map(clustering_worker,
                                                [(G, bins) for G in graph_pred_list_remove_empty]):
                sample_pred.append(clustering_hist)
        # check non-zero elements in hist
        # total = 0
        # for i in range(len(sample_pred)):
        #     nz = np.nonzero(sample_pred[i])[0].shape[0]
        #     total += nz
        # print(total)
    else:
        for i in range(len(graph_ref_list)):
            clustering_coeffs_list = list(nx.clustering(graph_ref_list[i]).values())
            hist, _ = np.histogram(
                clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
            sample_ref.append(hist)
        for i in range(len(graph_pred_list_remove_empty)):
            clustering_coeffs_list = list(nx.clustering(graph_pred_list_remove_empty[i]).values())
            hist, _ = np.histogram(
                clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
            sample_pred.append(hist)
    mmd_dist = mmd.compute_mmd(sample_ref, sample_pred, kernel=mmd.gaussian_emd,
                               sigma=1.0 / 10, distance_scaling=bins)
    elapsed = datetime.now() - prev
    if PRINT_TIME:
        print('Time computing clustering mmd: ', elapsed)
    return mmd_dist
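# A hedged illustration of the per-graph summary used above (assumption: the complete
# graph is just a convenient toy input). clustering_worker reduces a graph to a
# histogram of its node clustering coefficients over `bins` equal-width bins on [0, 1].
def _example_clustering_hist():
    G = nx.complete_graph(5)           # every node has clustering coefficient 1.0
    hist = clustering_worker((G, 10))  # all five counts land in the last bin
    return hist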
# maps motif/orbit name string to its corresponding list of indices from orca output
motif_to_indices = {
    '3path': [1, 2],
    '4cycle': [8],
}
COUNT_START_STR = 'orbit counts: \n'
def edge_list_reindexed(G):
    # relabel nodes as consecutive integers starting from 0, as expected by orca
    idx = 0
    id2idx = dict()
    for u in G.nodes():
        id2idx[str(u)] = idx
        idx += 1

    edges = []
    for (u, v) in G.edges():
        edges.append((id2idx[str(u)], id2idx[str(v)]))
    return edges
def orca(graph):
    # write the graph in orca's plain edge-list format, run the orca binary for
    # graphlets up to size 4, then parse the per-node orbit counts it prints
    tmp_fname = 'eval/orca/tmp.txt'
    f = open(tmp_fname, 'w')
    f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\n')
    for (u, v) in edge_list_reindexed(graph):
        f.write(str(u) + ' ' + str(v) + '\n')
    f.close()

    output = sp.check_output(['./eval/orca/orca', 'node', '4', 'eval/orca/tmp.txt', 'std'])
    output = output.decode('utf8').strip()

    idx = output.find(COUNT_START_STR) + len(COUNT_START_STR)
    output = output[idx:]
    node_orbit_counts = np.array([list(map(int, node_cnts.strip().split(' ')))
                                  for node_cnts in output.strip('\n').split('\n')])

    try:
        os.remove(tmp_fname)
    except OSError:
        pass

    return node_orbit_counts
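# A hedged sketch of consuming orca() directly (assumptions: the orca binary has been
# compiled at ./eval/orca/orca, and the returned matrix has one row per node and one
# column per orbit, matching how motif_stats and orbit_stats_all index it below).
def _example_orbit_counts(G):
    if not os.path.isfile('./eval/orca/orca'):
        raise FileNotFoundError('compile the orca binary under eval/orca/ first')
    counts = orca(G)                                  # shape: (num_nodes, num_orbits)
    return counts.sum(axis=0) / G.number_of_nodes()   # per-graph, size-normalized orbit counts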
def motif_stats(graph_ref_list, graph_pred_list, motif_type='4cycle', ground_truth_match=None, bins=100):
    # graph motif counts (int for each graph)
    # normalized by graph size
    total_counts_ref = []
    total_counts_pred = []

    num_matches_ref = []
    num_matches_pred = []

    graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
    indices = motif_to_indices[motif_type]

    for G in graph_ref_list:
        orbit_counts = orca(G)
        motif_counts = np.sum(orbit_counts[:, indices], axis=1)

        if ground_truth_match is not None:
            match_cnt = 0
            for elem in motif_counts:
                if elem == ground_truth_match:
                    match_cnt += 1
            num_matches_ref.append(match_cnt / G.number_of_nodes())

        # hist, _ = np.histogram(
        #     motif_counts, bins=bins, density=False)
        motif_temp = np.sum(motif_counts) / G.number_of_nodes()
        total_counts_ref.append(motif_temp)

    for G in graph_pred_list_remove_empty:
        orbit_counts = orca(G)
        motif_counts = np.sum(orbit_counts[:, indices], axis=1)

        if ground_truth_match is not None:
            match_cnt = 0
            for elem in motif_counts:
                if elem == ground_truth_match:
                    match_cnt += 1
            num_matches_pred.append(match_cnt / G.number_of_nodes())

        motif_temp = np.sum(motif_counts) / G.number_of_nodes()
        total_counts_pred.append(motif_temp)

    mmd_dist = mmd.compute_mmd(total_counts_ref, total_counts_pred, kernel=mmd.gaussian,
                               is_hist=False)

    # print('-------------------------')
    # print(np.sum(total_counts_ref) / len(total_counts_ref))
    # print('...')
    # print(np.sum(total_counts_pred) / len(total_counts_pred))
    # print('-------------------------')
    return mmd_dist
def orbit_stats_all(graph_ref_list, graph_pred_list):
    total_counts_ref = []
    total_counts_pred = []

    # drop empty generated graphs so the per-node normalization below cannot divide by zero
    graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]

    for G in graph_ref_list:
        try:
            orbit_counts = orca(G)
        except Exception:
            # orca can fail on some inputs; skip those graphs
            continue
        orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes()
        total_counts_ref.append(orbit_counts_graph)

    for G in graph_pred_list_remove_empty:
        try:
            orbit_counts = orca(G)
        except Exception:
            continue
        orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes()
        total_counts_pred.append(orbit_counts_graph)

    total_counts_ref = np.array(total_counts_ref)
    total_counts_pred = np.array(total_counts_pred)
    mmd_dist = mmd.compute_mmd(total_counts_ref, total_counts_pred, kernel=mmd.gaussian,
                               is_hist=False, sigma=30.0)

    print('-------------------------')
    print(np.sum(total_counts_ref, axis=0) / len(total_counts_ref))
    print('...')
    print(np.sum(total_counts_pred, axis=0) / len(total_counts_pred))
    print('-------------------------')
    return mmd_dist
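# A minimal smoke test (assumptions: the Erdos-Renyi graphs below are placeholders for
# a real reference/generated split, and the orbit/motif statistics are skipped because
# they need the compiled orca binary). Run from the repository root, e.g. as
# `python -m eval.stats`, so the `eval.mmd` import resolves.
if __name__ == '__main__':
    graphs_ref = [nx.erdos_renyi_graph(40, 0.10, seed=i) for i in range(10)]
    graphs_pred = [nx.erdos_renyi_graph(40, 0.15, seed=100 + i) for i in range(10)]
    print('degree MMD:', degree_stats(graphs_ref, graphs_pred, is_parallel=False))
    print('clustering MMD:', clustering_stats(graphs_ref, graphs_pred, bins=100, is_parallel=False))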