You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

baseline_simple.py 10.0KB

4 years ago

  1. from main import *
  2. from scipy.linalg import toeplitz
  3. import pyemd
  4. import scipy.optimize as opt
  5. def Graph_generator_baseline_train_rulebased(graphs,generator='BA'):
  6. graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
  7. graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))]
  8. parameter = {}
  9. for i in range(len(graph_nodes)):
  10. nodes = graph_nodes[i]
  11. edges = graph_edges[i]
  12. # based on rule, calculate optimal parameter
  13. if generator=='BA':
  14. # BA optimal: nodes = n; edges = (n-m)*m
  15. n = nodes
  16. m = (n - np.sqrt(n**2-4*edges))/2
  17. parameter_temp = [n,m,1]
  18. if generator=='Gnp':
  19. # Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p
  20. n = nodes
  21. p = float(edges)/((n-1)*n/2)
  22. parameter_temp = [n,p,1]
  23. # update parameter list
  24. if nodes not in parameter.keys():
  25. parameter[nodes] = parameter_temp
  26. else:
  27. count = parameter[nodes][-1]
  28. parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))]
  29. parameter[nodes][-1] = count+1
  30. # print(parameter)
  31. return parameter
  32. def Graph_generator_baseline(graph_train, pred_num=1000, generator='BA'):
  33. graph_nodes = [graph_train[i].number_of_nodes() for i in range(len(graph_train))]
  34. graph_edges = [graph_train[i].number_of_edges() for i in range(len(graph_train))]
  35. repeat = pred_num//len(graph_train)
  36. graph_pred = []
  37. for i in range(len(graph_nodes)):
  38. nodes = graph_nodes[i]
  39. edges = graph_edges[i]
  40. # based on rule, calculate optimal parameter
  41. if generator=='BA':
  42. # BA optimal: nodes = n; edges = (n-m)*m
  43. n = nodes
  44. m = int((n - np.sqrt(n**2-4*edges))/2)
  45. for j in range(repeat):
  46. graph_pred.append(nx.barabasi_albert_graph(n,m))
  47. if generator=='Gnp':
  48. # Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p
  49. n = nodes
  50. p = float(edges)/((n-1)*n/2)
  51. for j in range(repeat):
  52. graph_pred.append(nx.fast_gnp_random_graph(n, p))
  53. return graph_pred
  54. def emd_distance(x, y, distance_scaling=1.0):
  55. support_size = max(len(x), len(y))
  56. d_mat = toeplitz(range(support_size)).astype(np.float)
  57. distance_mat = d_mat / distance_scaling
  58. # convert histogram values x and y to float, and make them equal len
  59. x = x.astype(np.float)
  60. y = y.astype(np.float)
  61. if len(x) < len(y):
  62. x = np.hstack((x, [0.0] * (support_size - len(x))))
  63. elif len(y) < len(x):
  64. y = np.hstack((y, [0.0] * (support_size - len(y))))
  65. emd = pyemd.emd(x, y, distance_mat)
  66. return emd
  67. # def Loss(x,args):
  68. # '''
  69. #
  70. # :param x: 1-D array, parameters to be optimized
  71. # :param args: tuple (n, G, generator, metric).
  72. # n: n for pred graph;
  73. # G: real graph in networkx format;
  74. # generator: 'BA', 'Gnp', 'Powerlaw';
  75. # metric: 'degree', 'clustering'
  76. # :return: Loss: emd distance
  77. # '''
  78. # # get argument
  79. # generator = args[2]
  80. # metric = args[3]
  81. #
  82. # # get real and pred graphs
  83. # G_real = args[1]
  84. # if generator=='BA':
  85. # G_pred = nx.barabasi_albert_graph(args[0],int(np.rint(x)))
  86. # if generator=='Gnp':
  87. # G_pred = nx.fast_gnp_random_graph(args[0],x)
  88. #
  89. # # define metric
  90. # if metric == 'degree':
  91. # G_real_hist = np.array(nx.degree_histogram(G_real))
  92. # G_real_hist = G_real_hist / np.sum(G_real_hist)
  93. # G_pred_hist = np.array(nx.degree_histogram(G_pred))
  94. # G_pred_hist = G_pred_hist/np.sum(G_pred_hist)
  95. # if metric == 'clustering':
  96. # G_real_hist, _ = np.histogram(
  97. # np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False)
  98. # G_real_hist = G_real_hist / np.sum(G_real_hist)
  99. # G_pred_hist, _ = np.histogram(
  100. # np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False)
  101. # G_pred_hist = G_pred_hist / np.sum(G_pred_hist)
  102. #
  103. # loss = emd_distance(G_real_hist,G_pred_hist)
  104. # return loss
  105. def Loss(x,n,G_real,generator,metric):
  106. '''
  107. :param x: 1-D array, parameters to be optimized
  108. :param
  109. n: n for pred graph;
  110. G: real graph in networkx format;
  111. generator: 'BA', 'Gnp', 'Powerlaw';
  112. metric: 'degree', 'clustering'
  113. :return: Loss: emd distance
  114. '''
  115. # get argument
  116. # get real and pred graphs
  117. if generator=='BA':
  118. G_pred = nx.barabasi_albert_graph(n,int(np.rint(x)))
  119. if generator=='Gnp':
  120. G_pred = nx.fast_gnp_random_graph(n,x)
  121. # define metric
  122. if metric == 'degree':
  123. G_real_hist = np.array(nx.degree_histogram(G_real))
  124. G_real_hist = G_real_hist / np.sum(G_real_hist)
  125. G_pred_hist = np.array(nx.degree_histogram(G_pred))
  126. G_pred_hist = G_pred_hist/np.sum(G_pred_hist)
  127. if metric == 'clustering':
  128. G_real_hist, _ = np.histogram(
  129. np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False)
  130. G_real_hist = G_real_hist / np.sum(G_real_hist)
  131. G_pred_hist, _ = np.histogram(
  132. np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False)
  133. G_pred_hist = G_pred_hist / np.sum(G_pred_hist)
  134. loss = emd_distance(G_real_hist,G_pred_hist)
  135. return loss
  136. def optimizer_brute(x_min, x_max, x_step, n, G_real, generator, metric):
  137. loss_all = []
  138. x_list = np.arange(x_min,x_max,x_step)
  139. for x_test in x_list:
  140. loss_all.append(Loss(x_test,n,G_real,generator,metric))
  141. x_optim = x_list[np.argmin(np.array(loss_all))]
  142. return x_optim
  143. def Graph_generator_baseline_train_optimizationbased(graphs,generator='BA',metric='degree'):
  144. graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
  145. parameter = {}
  146. for i in range(len(graph_nodes)):
  147. print('graph ',i)
  148. nodes = graph_nodes[i]
  149. if generator=='BA':
  150. n = nodes
  151. m = optimizer_brute(1,10,1, nodes, graphs[i], generator, metric)
  152. parameter_temp = [n,m,1]
  153. elif generator=='Gnp':
  154. n = nodes
  155. p = optimizer_brute(1e-6,1,0.01, nodes, graphs[i], generator, metric)
  156. ## if use evolution
  157. # result = opt.differential_evolution(Loss,bounds=[(0,1)],args=(nodes, graphs[i], generator, metric),maxiter=1000)
  158. # p = result.x
  159. parameter_temp = [n, p, 1]
  160. # update parameter list
  161. if nodes not in parameter.keys():
  162. parameter[nodes] = parameter_temp
  163. else:
  164. count = parameter[nodes][2]
  165. parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))]
  166. parameter[nodes][2] = count+1
  167. print(parameter)
  168. return parameter
  169. def Graph_generator_baseline_test(graph_nodes, parameter, generator='BA'):
  170. graphs = []
  171. for i in range(len(graph_nodes)):
  172. nodes = graph_nodes[i]
  173. if not nodes in parameter.keys():
  174. nodes = min(parameter.keys(), key=lambda k: abs(k - nodes))
  175. if generator=='BA':
  176. n = int(parameter[nodes][0])
  177. m = int(np.rint(parameter[nodes][1]))
  178. print(n,m)
  179. graph = nx.barabasi_albert_graph(n,m)
  180. if generator=='Gnp':
  181. n = int(parameter[nodes][0])
  182. p = parameter[nodes][1]
  183. print(n,p)
  184. graph = nx.fast_gnp_random_graph(n,p)
  185. graphs.append(graph)
  186. return graphs
  187. if __name__ == '__main__':
  188. args = Args()
  189. print('File name prefix', args.fname)
  190. ### load datasets
  191. graphs = []
  192. # synthetic graphs
  193. if args.graph_type=='ladder':
  194. graphs = []
  195. for i in range(100, 201):
  196. graphs.append(nx.ladder_graph(i))
  197. args.max_prev_node = 10
  198. if args.graph_type=='tree':
  199. graphs = []
  200. for i in range(2,5):
  201. for j in range(3,5):
  202. graphs.append(nx.balanced_tree(i,j))
  203. args.max_prev_node = 256
  204. if args.graph_type=='caveman':
  205. graphs = []
  206. for i in range(5,10):
  207. for j in range(5,25):
  208. graphs.append(nx.connected_caveman_graph(i, j))
  209. args.max_prev_node = 50
  210. if args.graph_type=='grid':
  211. graphs = []
  212. for i in range(10,20):
  213. for j in range(10,20):
  214. graphs.append(nx.grid_2d_graph(i,j))
  215. args.max_prev_node = 40
  216. if args.graph_type=='barabasi':
  217. graphs = []
  218. for i in range(100,200):
  219. graphs.append(nx.barabasi_albert_graph(i,2))
  220. args.max_prev_node = 130
  221. # real graphs
  222. if args.graph_type == 'enzymes':
  223. graphs= Graph_load_batch(min_num_nodes=10, name='ENZYMES')
  224. args.max_prev_node = 25
  225. if args.graph_type == 'protein':
  226. graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full')
  227. args.max_prev_node = 80
  228. if args.graph_type == 'DD':
  229. graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',node_attributes=False,graph_labels=True)
  230. args.max_prev_node = 230
  231. graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
  232. graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))]
  233. args.max_num_node = max(graph_nodes)
  234. # show graphs statistics
  235. print('total graph num: {}'.format(len(graphs)))
  236. print('max number node: {}'.format(args.max_num_node))
  237. print('max previous node: {}'.format(args.max_prev_node))
  238. # start baseline generation method
  239. generator = args.generator_baseline
  240. metric = args.metric_baseline
  241. print(args.fname_baseline + '.dat')
  242. if metric=='general':
  243. parameter = Graph_generator_baseline_train_rulebased(graphs,generator=generator)
  244. else:
  245. parameter = Graph_generator_baseline_train_optimizationbased(graphs,generator=generator,metric=metric)
  246. graphs_generated = Graph_generator_baseline_test(graph_nodes, parameter,generator)
  247. save_graph_list(graphs_generated,args.fname_baseline + '.dat')