You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.init as init
  6. from torch.autograd import Variable
  7. import matplotlib.pyplot as plt
  8. import torch.nn.functional as F
  9. from torch import optim
  10. from torch.optim.lr_scheduler import MultiStepLR
  11. # import node2vec.src.main as nv
  12. from sklearn.decomposition import PCA
  13. import community
  14. import pickle
  15. import re
  16. import data
  17. def citeseer_ego():
  18. _, _, G = data.Graph_load(dataset='citeseer')
  19. G = max(nx.connected_component_subgraphs(G), key=len)
  20. G = nx.convert_node_labels_to_integers(G)
  21. graphs = []
  22. for i in range(G.number_of_nodes()):
  23. G_ego = nx.ego_graph(G, i, radius=3)
  24. if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400):
  25. graphs.append(G_ego)
  26. return graphs
  27. def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3):
  28. p = p_path
  29. path_count = max(int(np.ceil(p * k)),1)
  30. G = nx.caveman_graph(c, k)
  31. # remove 50% edges
  32. p = 1-p_edge
  33. for (u, v) in list(G.edges()):
  34. if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)):
  35. G.remove_edge(u, v)
  36. # add path_count links
  37. for i in range(path_count):
  38. u = np.random.randint(0, k)
  39. v = np.random.randint(k, k * 2)
  40. G.add_edge(u, v)
  41. G = max(nx.connected_component_subgraphs(G), key=len)
  42. return G
  43. def n_community(c_sizes, p_inter=0.01):
  44. graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))]
  45. G = nx.disjoint_union_all(graphs)
  46. communities = list(nx.connected_component_subgraphs(G))
  47. for i in range(len(communities)):
  48. subG1 = communities[i]
  49. nodes1 = list(subG1.nodes())
  50. for j in range(i+1, len(communities)):
  51. subG2 = communities[j]
  52. nodes2 = list(subG2.nodes())
  53. has_inter_edge = False
  54. for n1 in nodes1:
  55. for n2 in nodes2:
  56. if np.random.rand() < p_inter:
  57. G.add_edge(n1, n2)
  58. has_inter_edge = True
  59. if not has_inter_edge:
  60. G.add_edge(nodes1[0], nodes2[0])
  61. #print('connected comp: ', len(list(nx.connected_component_subgraphs(G))))
  62. return G
  63. def perturb(graph_list, p_del, p_add=None):
  64. ''' Perturb the list of graphs by adding/removing edges.
  65. Args:
  66. p_add: probability of adding edges. If None, estimate it according to graph density,
  67. such that the expected number of added edges is equal to that of deleted edges.
  68. p_del: probability of removing edges
  69. Returns:
  70. A list of graphs that are perturbed from the original graphs
  71. '''
  72. perturbed_graph_list = []
  73. for G_original in graph_list:
  74. G = G_original.copy()
  75. trials = np.random.binomial(1, p_del, size=G.number_of_edges())
  76. edges = list(G.edges())
  77. i = 0
  78. for (u, v) in edges:
  79. if trials[i] == 1:
  80. G.remove_edge(u, v)
  81. i += 1
  82. if p_add is None:
  83. num_nodes = G.number_of_nodes()
  84. p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 -
  85. G.number_of_edges())
  86. else:
  87. p_add_est = p_add
  88. nodes = list(G.nodes())
  89. tmp = 0
  90. for i in range(len(nodes)):
  91. u = nodes[i]
  92. trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes())
  93. j = 0
  94. for j in range(i+1, len(nodes)):
  95. v = nodes[j]
  96. if trials[j] == 1:
  97. tmp += 1
  98. G.add_edge(u, v)
  99. j += 1
  100. perturbed_graph_list.append(G)
  101. return perturbed_graph_list
  102. def perturb_new(graph_list, p):
  103. ''' Perturb the list of graphs by adding/removing edges.
  104. Args:
  105. p_add: probability of adding edges. If None, estimate it according to graph density,
  106. such that the expected number of added edges is equal to that of deleted edges.
  107. p_del: probability of removing edges
  108. Returns:
  109. A list of graphs that are perturbed from the original graphs
  110. '''
  111. perturbed_graph_list = []
  112. for G_original in graph_list:
  113. G = G_original.copy()
  114. edge_remove_count = 0
  115. for (u, v) in list(G.edges()):
  116. if np.random.rand()<p:
  117. G.remove_edge(u, v)
  118. edge_remove_count += 1
  119. # randomly add the edges back
  120. for i in range(edge_remove_count):
  121. while True:
  122. u = np.random.randint(0, G.number_of_nodes())
  123. v = np.random.randint(0, G.number_of_nodes())
  124. if (not G.has_edge(u,v)) and (u!=v):
  125. break
  126. G.add_edge(u, v)
  127. perturbed_graph_list.append(G)
  128. return perturbed_graph_list
  129. def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None):
  130. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  131. from matplotlib.figure import Figure
  132. fig = Figure(figsize=arr.shape[::-1], dpi=1, frameon=False)
  133. canvas = FigureCanvas(fig)
  134. fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin)
  135. fig.savefig(fname, dpi=1, format=format)
  136. def save_prediction_histogram(y_pred_data, fname_pred, max_num_node, bin_n=20):
  137. bin_edge = np.linspace(1e-6, 1, bin_n + 1)
  138. output_pred = np.zeros((bin_n, max_num_node))
  139. for i in range(max_num_node):
  140. output_pred[:, i], _ = np.histogram(y_pred_data[:, i, :], bins=bin_edge, density=False)
  141. # normalize
  142. output_pred[:, i] /= np.sum(output_pred[:, i])
  143. imsave(fname=fname_pred, arr=output_pred, origin='upper', cmap='Greys_r', vmin=0.0, vmax=3.0 / bin_n)
  144. # draw a single graph G
  145. def draw_graph(G, prefix = 'test'):
  146. parts = community.best_partition(G)
  147. values = [parts.get(node) for node in G.nodes()]
  148. colors = []
  149. for i in range(len(values)):
  150. if values[i] == 0:
  151. colors.append('red')
  152. if values[i] == 1:
  153. colors.append('green')
  154. if values[i] == 2:
  155. colors.append('blue')
  156. if values[i] == 3:
  157. colors.append('yellow')
  158. if values[i] == 4:
  159. colors.append('orange')
  160. if values[i] == 5:
  161. colors.append('pink')
  162. if values[i] == 6:
  163. colors.append('black')
  164. # spring_pos = nx.spring_layout(G)
  165. plt.switch_backend('agg')
  166. plt.axis("off")
  167. pos = nx.spring_layout(G)
  168. nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors,pos=pos)
  169. # plt.switch_backend('agg')
  170. # options = {
  171. # 'node_color': 'black',
  172. # 'node_size': 10,
  173. # 'width': 1
  174. # }
  175. # plt.figure()
  176. # plt.subplot()
  177. # nx.draw_networkx(G, **options)
  178. plt.savefig('figures/graph_view_'+prefix+'.png', dpi=200)
  179. plt.close()
  180. plt.switch_backend('agg')
  181. G_deg = nx.degree_histogram(G)
  182. G_deg = np.array(G_deg)
  183. # plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2)
  184. plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2)
  185. plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
  186. plt.close()
  187. # degree_sequence = sorted(nx.degree(G).values(), reverse=True) # degree sequence
  188. # plt.loglog(degree_sequence, 'b-', marker='o')
  189. # plt.title("Degree rank plot")
  190. # plt.ylabel("degree")
  191. # plt.xlabel("rank")
  192. # plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
  193. # plt.close()
  194. # G = nx.grid_2d_graph(8,8)
  195. # G = nx.karate_club_graph()
  196. # draw_graph(G)
  197. # draw a list of graphs [G]
  198. def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', is_single=False,k=1,node_size=55,alpha=1,width=1.3):
  199. # # draw graph view
  200. # from pylab import rcParams
  201. # rcParams['figure.figsize'] = 12,3
  202. plt.switch_backend('agg')
  203. for i,G in enumerate(G_list):
  204. plt.subplot(row,col,i+1)
  205. plt.subplots_adjust(left=0, bottom=0, right=1, top=1,
  206. wspace=0, hspace=0)
  207. # if i%2==0:
  208. # plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4)
  209. # else:
  210. # plt.title('pred nodes: '+str(G.number_of_nodes()), fontsize = 4)
  211. # plt.title('num of nodes: '+str(G.number_of_nodes()), fontsize = 4)
  212. # parts = community.best_partition(G)
  213. # values = [parts.get(node) for node in G.nodes()]
  214. # colors = []
  215. # for i in range(len(values)):
  216. # if values[i] == 0:
  217. # colors.append('red')
  218. # if values[i] == 1:
  219. # colors.append('green')
  220. # if values[i] == 2:
  221. # colors.append('blue')
  222. # if values[i] == 3:
  223. # colors.append('yellow')
  224. # if values[i] == 4:
  225. # colors.append('orange')
  226. # if values[i] == 5:
  227. # colors.append('pink')
  228. # if values[i] == 6:
  229. # colors.append('black')
  230. plt.axis("off")
  231. if layout=='spring':
  232. pos = nx.spring_layout(G,k=k/np.sqrt(G.number_of_nodes()),iterations=100)
  233. # pos = nx.spring_layout(G)
  234. elif layout=='spectral':
  235. pos = nx.spectral_layout(G)
  236. # # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos)
  237. # nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2)
  238. if is_single:
  239. # node_size default 60, edge_width default 1.5
  240. nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, font_size=0)
  241. nx.draw_networkx_edges(G, pos, alpha=alpha, width=width)
  242. else:
  243. nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699',alpha=1, linewidths=0.2, font_size = 1.5)
  244. nx.draw_networkx_edges(G, pos, alpha=0.3,width=0.2)
  245. # plt.axis('off')
  246. # plt.title('Complete Graph of Odd-degree Nodes')
  247. # plt.show()
  248. plt.tight_layout()
  249. plt.savefig(fname+'.png', dpi=600)
  250. plt.close()
  251. # # draw degree distribution
  252. # plt.switch_backend('agg')
  253. # for i, G in enumerate(G_list):
  254. # plt.subplot(row, col, i + 1)
  255. # G_deg = np.array(list(G.degree(G.nodes()).values()))
  256. # bins = np.arange(20)
  257. # plt.hist(np.array(G_deg), bins=bins, align='left')
  258. # plt.xlabel('degree', fontsize = 3)
  259. # plt.ylabel('count', fontsize = 3)
  260. # G_deg_mean = 2*G.number_of_edges()/float(G.number_of_nodes())
  261. # # if i % 2 == 0:
  262. # # plt.title('real average degree: {:.2f}'.format(G_deg_mean), fontsize=4)
  263. # # else:
  264. # # plt.title('pred average degree: {:.2f}'.format(G_deg_mean), fontsize=4)
  265. # plt.title('average degree: {:.2f}'.format(G_deg_mean), fontsize=4)
  266. # plt.tick_params(axis='both', which='major', labelsize=3)
  267. # plt.tick_params(axis='both', which='minor', labelsize=3)
  268. # plt.tight_layout()
  269. # plt.savefig(fname+'_degree.png', dpi=600)
  270. # plt.close()
  271. #
  272. # # draw clustering distribution
  273. # plt.switch_backend('agg')
  274. # for i, G in enumerate(G_list):
  275. # plt.subplot(row, col, i + 1)
  276. # G_cluster = list(nx.clustering(G).values())
  277. # bins = np.linspace(0,1,20)
  278. # plt.hist(np.array(G_cluster), bins=bins, align='left')
  279. # plt.xlabel('clustering coefficient', fontsize=3)
  280. # plt.ylabel('count', fontsize=3)
  281. # G_cluster_mean = sum(G_cluster) / len(G_cluster)
  282. # # if i % 2 == 0:
  283. # # plt.title('real average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4)
  284. # # else:
  285. # # plt.title('pred average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4)
  286. # plt.title('average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4)
  287. # plt.tick_params(axis='both', which='major', labelsize=3)
  288. # plt.tick_params(axis='both', which='minor', labelsize=3)
  289. # plt.tight_layout()
  290. # plt.savefig(fname+'_clustering.png', dpi=600)
  291. # plt.close()
  292. #
  293. # # draw circle distribution
  294. # plt.switch_backend('agg')
  295. # for i, G in enumerate(G_list):
  296. # plt.subplot(row, col, i + 1)
  297. # cycle_len = []
  298. # cycle_all = nx.cycle_basis(G)
  299. # for item in cycle_all:
  300. # cycle_len.append(len(item))
  301. #
  302. # bins = np.arange(20)
  303. # plt.hist(np.array(cycle_len), bins=bins, align='left')
  304. # plt.xlabel('cycle length', fontsize=3)
  305. # plt.ylabel('count', fontsize=3)
  306. # G_cycle_mean = 0
  307. # if len(cycle_len)>0:
  308. # G_cycle_mean = sum(cycle_len) / len(cycle_len)
  309. # # if i % 2 == 0:
  310. # # plt.title('real average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4)
  311. # # else:
  312. # # plt.title('pred average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4)
  313. # plt.title('average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4)
  314. # plt.tick_params(axis='both', which='major', labelsize=3)
  315. # plt.tick_params(axis='both', which='minor', labelsize=3)
  316. # plt.tight_layout()
  317. # plt.savefig(fname+'_cycle.png', dpi=600)
  318. # plt.close()
  319. #
  320. # # draw community distribution
  321. # plt.switch_backend('agg')
  322. # for i, G in enumerate(G_list):
  323. # plt.subplot(row, col, i + 1)
  324. # parts = community.best_partition(G)
  325. # values = np.array([parts.get(node) for node in G.nodes()])
  326. # counts = np.sort(np.bincount(values)[::-1])
  327. # pos = np.arange(len(counts))
  328. # plt.bar(pos,counts,align = 'edge')
  329. # plt.xlabel('community ID', fontsize=3)
  330. # plt.ylabel('count', fontsize=3)
  331. # G_community_count = len(counts)
  332. # # if i % 2 == 0:
  333. # # plt.title('real average clustering: {}'.format(G_community_count), fontsize=4)
  334. # # else:
  335. # # plt.title('pred average clustering: {}'.format(G_community_count), fontsize=4)
  336. # plt.title('average clustering: {}'.format(G_community_count), fontsize=4)
  337. # plt.tick_params(axis='both', which='major', labelsize=3)
  338. # plt.tick_params(axis='both', which='minor', labelsize=3)
  339. # plt.tight_layout()
  340. # plt.savefig(fname+'_community.png', dpi=600)
  341. # plt.close()
  342. # plt.switch_backend('agg')
  343. # G_deg = nx.degree_histogram(G)
  344. # G_deg = np.array(G_deg)
  345. # # plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2)
  346. # plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2)
  347. # plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
  348. # plt.close()
  349. # degree_sequence = sorted(nx.degree(G).values(), reverse=True) # degree sequence
  350. # plt.loglog(degree_sequence, 'b-', marker='o')
  351. # plt.title("Degree rank plot")
  352. # plt.ylabel("degree")
  353. # plt.xlabel("rank")
  354. # plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200)
  355. # plt.close()
  356. # directly get graph statistics from adj, obsoleted
  357. def decode_graph(adj, prefix):
  358. adj = np.asmatrix(adj)
  359. G = nx.from_numpy_matrix(adj)
  360. # G.remove_nodes_from(nx.isolates(G))
  361. print('num of nodes: {}'.format(G.number_of_nodes()))
  362. print('num of edges: {}'.format(G.number_of_edges()))
  363. G_deg = nx.degree_histogram(G)
  364. G_deg_sum = [a * b for a, b in zip(G_deg, range(0, len(G_deg)))]
  365. print('average degree: {}'.format(sum(G_deg_sum) / G.number_of_nodes()))
  366. if nx.is_connected(G):
  367. print('average path length: {}'.format(nx.average_shortest_path_length(G)))
  368. print('average diameter: {}'.format(nx.diameter(G)))
  369. G_cluster = sorted(list(nx.clustering(G).values()))
  370. print('average clustering coefficient: {}'.format(sum(G_cluster) / len(G_cluster)))
  371. cycle_len = []
  372. cycle_all = nx.cycle_basis(G, 0)
  373. for item in cycle_all:
  374. cycle_len.append(len(item))
  375. print('cycles', cycle_len)
  376. print('cycle count', len(cycle_len))
  377. draw_graph(G, prefix=prefix)
  378. def get_graph(adj):
  379. '''
  380. get a graph from zero-padded adj
  381. :param adj:
  382. :return:
  383. '''
  384. # remove all zeros rows and columns
  385. adj = adj[~np.all(adj == 0, axis=1)]
  386. adj = adj[:, ~np.all(adj == 0, axis=0)]
  387. adj = np.asmatrix(adj)
  388. G = nx.from_numpy_matrix(adj)
  389. return G
  390. # save a list of graphs
  391. def save_graph_list(G_list, fname):
  392. with open(fname, "wb") as f:
  393. pickle.dump(G_list, f)
  394. # pick the first connected component
  395. def pick_connected_component(G):
  396. node_list = nx.node_connected_component(G,0)
  397. return G.subgraph(node_list)
  398. def pick_connected_component_new(G):
  399. adj_list = G.adjacency_list()
  400. for id,adj in enumerate(adj_list):
  401. id_min = min(adj)
  402. if id<id_min and id>=1:
  403. # if id<id_min and id>=4:
  404. break
  405. node_list = list(range(id)) # only include node prior than node "id"
  406. G = G.subgraph(node_list)
  407. G = max(nx.connected_component_subgraphs(G), key=len)
  408. return G
  409. # load a list of graphs
  410. def load_graph_list(fname,is_real=True):
  411. with open(fname, "rb") as f:
  412. graph_list = pickle.load(f)
  413. for i in range(len(graph_list)):
  414. edges_with_selfloops = graph_list[i].selfloop_edges()
  415. if len(edges_with_selfloops)>0:
  416. graph_list[i].remove_edges_from(edges_with_selfloops)
  417. if is_real:
  418. graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len)
  419. graph_list[i] = nx.convert_node_labels_to_integers(graph_list[i])
  420. else:
  421. graph_list[i] = pick_connected_component_new(graph_list[i])
  422. return graph_list
  423. def export_graphs_to_txt(g_list, output_filename_prefix):
  424. i = 0
  425. for G in g_list:
  426. f = open(output_filename_prefix + '_' + str(i) + '.txt', 'w+')
  427. for (u, v) in G.edges():
  428. idx_u = G.nodes().index(u)
  429. idx_v = G.nodes().index(v)
  430. f.write(str(idx_u) + '\t' + str(idx_v) + '\n')
  431. i += 1
  432. def snap_txt_output_to_nx(in_fname):
  433. G = nx.Graph()
  434. with open(in_fname, 'r') as f:
  435. for line in f:
  436. if not line[0] == '#':
  437. splitted = re.split('[ \t]', line)
  438. # self loop might be generated, but should be removed
  439. u = int(splitted[0])
  440. v = int(splitted[1])
  441. if not u == v:
  442. G.add_edge(int(u), int(v))
  443. return G
  444. def test_perturbed():
  445. graphs = []
  446. for i in range(100,101):
  447. for j in range(4,5):
  448. for k in range(500):
  449. graphs.append(nx.barabasi_albert_graph(i,j))
  450. g_perturbed = perturb(graphs, 0.9)
  451. print([g.number_of_edges() for g in graphs])
  452. print([g.number_of_edges() for g in g_perturbed])
  453. if __name__ == '__main__':
  454. #test_perturbed()
  455. #graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_train_0.dat')
  456. #graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat')
  457. graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat')
  458. for i in range(0, 160, 16):
  459. draw_graph_list(graphs[i:i+16], 4, 4, fname='figures/community4_' + str(i))