You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

create_graphs.py 6.0KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import networkx as nx
  2. import numpy as np
  3. from baselines.graphvae.util import load_data
  4. from utils import *
  5. from data import *
  6. def create(args):
  7. ### load datasets
  8. graphs = []
  9. # synthetic graphs
  10. if args.graph_type == 'ladder':
  11. graphs = []
  12. for i in range(100, 201):
  13. graphs.append(nx.ladder_graph(i))
  14. args.max_prev_node = 10
  15. elif args.graph_type == 'ladder_small':
  16. graphs = []
  17. for i in range(2, 11):
  18. graphs.append(nx.ladder_graph(i))
  19. args.max_prev_node = 10
  20. elif args.graph_type == 'tree':
  21. graphs = []
  22. for i in range(2, 5):
  23. for j in range(3, 5):
  24. graphs.append(nx.balanced_tree(i, j))
  25. args.max_prev_node = 256
  26. elif args.graph_type == 'caveman':
  27. # graphs = []
  28. # for i in range(5,10):
  29. # for j in range(5,25):
  30. # for k in range(5):
  31. # graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
  32. graphs = []
  33. for i in range(2, 3):
  34. for j in range(30, 81):
  35. for k in range(10):
  36. graphs.append(caveman_special(i, j, p_edge=0.3))
  37. args.max_prev_node = 100
  38. elif args.graph_type == 'caveman_small':
  39. # graphs = []
  40. # for i in range(2,5):
  41. # for j in range(2,6):
  42. # for k in range(10):
  43. # graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
  44. graphs = []
  45. for i in range(2, 3):
  46. for j in range(6, 11):
  47. for k in range(20):
  48. graphs.append(caveman_special(i, j, p_edge=0.8)) # default 0.8
  49. args.max_prev_node = 20
  50. elif args.graph_type == 'caveman_small_single':
  51. # graphs = []
  52. # for i in range(2,5):
  53. # for j in range(2,6):
  54. # for k in range(10):
  55. # graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
  56. graphs = []
  57. for i in range(2, 3):
  58. for j in range(8, 9):
  59. for k in range(100):
  60. graphs.append(caveman_special(i, j, p_edge=0.5))
  61. args.max_prev_node = 20
  62. elif args.graph_type.startswith('community'):
  63. num_communities = int(args.graph_type[-1])
  64. print('Creating dataset with ', num_communities, ' communities')
  65. c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities)
  66. # c_sizes = [15] * num_communities
  67. for k in range(3000):
  68. graphs.append(n_community(c_sizes, p_inter=0.01))
  69. args.max_prev_node = 80
  70. elif args.graph_type == 'grid':
  71. graphs = []
  72. for i in range(10, 20):
  73. for j in range(10, 20):
  74. graphs.append(nx.grid_2d_graph(i, j))
  75. args.max_prev_node = 40
  76. elif args.graph_type == 'grid_small':
  77. graphs = []
  78. for i in range(2, 5):
  79. for j in range(2, 5):
  80. graphs.append(nx.grid_2d_graph(i, j))
  81. args.max_prev_node = 15
  82. elif args.graph_type == 'barabasi':
  83. graphs = []
  84. for i in range(100, 200):
  85. for j in range(4, 5):
  86. for k in range(5):
  87. graphs.append(nx.barabasi_albert_graph(i, j))
  88. args.max_prev_node = 130
  89. elif args.graph_type == 'barabasi_small':
  90. graphs = []
  91. for i in range(4, 21):
  92. for j in range(3, 4):
  93. for k in range(10):
  94. graphs.append(nx.barabasi_albert_graph(i, j))
  95. args.max_prev_node = 20
  96. elif args.graph_type == 'grid_big':
  97. graphs = []
  98. for i in range(36, 46):
  99. for j in range(36, 46):
  100. graphs.append(nx.grid_2d_graph(i, j))
  101. args.max_prev_node = 90
  102. elif 'barabasi_noise' in args.graph_type:
  103. graphs = []
  104. for i in range(100, 101):
  105. for j in range(4, 5):
  106. for k in range(500):
  107. graphs.append(nx.barabasi_albert_graph(i, j))
  108. graphs = perturb_new(graphs, p=args.noise / 10.0)
  109. args.max_prev_node = 99
  110. # real graphs
  111. elif args.graph_type == 'enzymes':
  112. graphs = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
  113. args.max_prev_node = 25
  114. elif args.graph_type == 'enzymes_small':
  115. graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
  116. graphs = []
  117. for G in graphs_raw:
  118. if G.number_of_nodes() <= 20:
  119. graphs.append(G)
  120. args.max_prev_node = 15
  121. elif args.graph_type == 'protein':
  122. graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full')
  123. args.max_prev_node = 80
  124. elif args.graph_type == 'DD':
  125. graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD', node_attributes=False,
  126. graph_labels=True)
  127. args.max_prev_node = 230
  128. elif args.graph_type == 'citeseer':
  129. _, _, G = Graph_load(dataset='citeseer')
  130. G = max(nx.connected_component_subgraphs(G), key=len)
  131. G = nx.convert_node_labels_to_integers(G)
  132. graphs = []
  133. for i in range(G.number_of_nodes()):
  134. G_ego = nx.ego_graph(G, i, radius=3)
  135. if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400):
  136. graphs.append(G_ego)
  137. args.max_prev_node = 250
  138. elif args.graph_type == 'citeseer_small':
  139. _, _, G = Graph_load(dataset='citeseer')
  140. G = max(nx.connected_component_subgraphs(G), key=len)
  141. G = nx.convert_node_labels_to_integers(G)
  142. graphs = []
  143. for i in range(G.number_of_nodes()):
  144. G_ego = nx.ego_graph(G, i, radius=1)
  145. if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
  146. graphs.append(G_ego)
  147. shuffle(graphs)
  148. graphs = graphs[0:200]
  149. args.max_prev_node = 15
  150. elif args.graph_type == 'COLLAB' or args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
  151. graphs, num_classes = load_data(args.graph_type, True)
  152. args.max_prev_node = 40
  153. return graphs