You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_data.py 4.9KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. from data import bfs_seq
  5. class GraphAdjSampler(torch.utils.data.Dataset):
  6. def __init__(self, G_list, max_num_nodes, permutation_mode, bfs_mode, bfs_mode_with_arbitrary_node_deleted,
  7. features='id'):
  8. self.max_num_nodes = max_num_nodes
  9. self.adj_all = []
  10. self.len_all = []
  11. self.feature_all = []
  12. self.count = 0
  13. self.permutation_mode = permutation_mode
  14. self.bfs_mode = bfs_mode
  15. self.bfs_mode_with_arbitrary_node_deleted = bfs_mode_with_arbitrary_node_deleted
  16. for G in G_list:
  17. adj = nx.to_numpy_matrix(G)
  18. # the diagonal entries are 1 since they denote node probability
  19. self.adj_all.append(
  20. np.asarray(adj) + np.identity(G.number_of_nodes()))
  21. self.len_all.append(G.number_of_nodes())
  22. if features == 'id':
  23. self.feature_all.append(np.identity(max_num_nodes))
  24. elif features == 'deg':
  25. degs = np.sum(np.array(adj), 1)
  26. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], 0),
  27. axis=1)
  28. self.feature_all.append(degs)
  29. elif features == 'struct':
  30. degs = np.sum(np.array(adj), 1)
  31. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
  32. 'constant'),
  33. axis=1)
  34. clusterings = np.array(list(nx.clustering(G).values()))
  35. clusterings = np.expand_dims(np.pad(clusterings,
  36. [0, max_num_nodes - G.number_of_nodes()],
  37. 'constant'),
  38. axis=1)
  39. self.feature_all.append(np.hstack([degs, clusterings]))
  40. def __len__(self):
  41. return len(self.adj_all)
  42. def __getitem__(self, idx):
  43. adj = self.adj_all[idx]
  44. if self.permutation_mode:
  45. x_idx = np.random.permutation(adj.shape[0])
  46. # self.count += 1
  47. adj = adj[np.ix_(x_idx, x_idx)]
  48. adj = np.asmatrix(adj)
  49. if self.bfs_mode:
  50. if self.bfs_mode_with_arbitrary_node_deleted:
  51. random_idx_for_delete = np.random.randint(adj.shape[0])
  52. deleted_node = adj[:, random_idx_for_delete].copy()
  53. for i in range(deleted_node.__len__()):
  54. if i >= random_idx_for_delete and i < deleted_node.__len__() - 1:
  55. deleted_node[i] = deleted_node[i + 1]
  56. elif i == deleted_node.__len__() - 1:
  57. deleted_node[i] = 0
  58. adj[:, random_idx_for_delete:adj.shape[0] - 1] = adj[:, random_idx_for_delete + 1:adj.shape[0]]
  59. adj[random_idx_for_delete:adj.shape[0] - 1, :] = adj[random_idx_for_delete + 1:adj.shape[0], :]
  60. adj = np.delete(adj, -1, axis=1)
  61. adj = np.delete(adj, -1, axis=0)
  62. G = nx.from_numpy_matrix(adj)
  63. # then do bfs in the permuted G
  64. degree_arr = np.sum(adj, axis=0)
  65. start_idx = np.argmax(degree_arr)
  66. # start_idx = np.random.randint(adj.shape[0])
  67. x_idx = np.array(bfs_seq(G, start_idx))
  68. adj = adj[np.ix_(x_idx, x_idx)]
  69. x_idx = np.insert(x_idx, x_idx.size, x_idx.size)
  70. deleted_node = deleted_node[np.ix_(x_idx)]
  71. adj = np.append(adj, deleted_node[:-1], axis=1)
  72. deleted_node = deleted_node.reshape(1, -1)
  73. adj = np.vstack([adj, deleted_node])
  74. else:
  75. G = nx.from_numpy_matrix(adj)
  76. # then do bfs in the permuted G
  77. start_idx = np.random.randint(adj.shape[0])
  78. x_idx = np.array(bfs_seq(G, start_idx))
  79. adj = adj[np.ix_(x_idx, x_idx)]
  80. num_nodes = adj.shape[0]
  81. adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
  82. adj_padded[:num_nodes, :num_nodes] = adj
  83. adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2)
  84. node_idx = 0
  85. adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes))) == 1]
  86. # the following 2 lines recover the upper triangle of the adj matrix
  87. # recovered = np.zeros((self.max_num_nodes, self.max_num_nodes))
  88. # recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes)) ) == 1] = adj_vectorized
  89. # print(recovered)
  90. return {'adj': adj_padded,
  91. 'adj_decoded': adj_vectorized,
  92. 'num_nodes': num_nodes,
  93. 'features': self.feature_all[idx].copy()}