You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_data.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. from data import bfs_seq
  5. class GraphAdjSampler(torch.utils.data.Dataset):
  6. def __init__(self, G_list, max_num_nodes, permutation_mode, bfs_mode, features='id'):
  7. self.max_num_nodes = max_num_nodes
  8. self.adj_all = []
  9. self.len_all = []
  10. self.feature_all = []
  11. self.count = 0
  12. self.permutation_mode = permutation_mode
  13. self.bfs_mode = bfs_mode
  14. for G in G_list:
  15. adj = nx.to_numpy_matrix(G)
  16. # the diagonal entries are 1 since they denote node probability
  17. self.adj_all.append(
  18. np.asarray(adj) + np.identity(G.number_of_nodes()))
  19. self.len_all.append(G.number_of_nodes())
  20. if features == 'id':
  21. self.feature_all.append(np.identity(max_num_nodes))
  22. elif features == 'deg':
  23. degs = np.sum(np.array(adj), 1)
  24. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], 0),
  25. axis=1)
  26. self.feature_all.append(degs)
  27. elif features == 'struct':
  28. degs = np.sum(np.array(adj), 1)
  29. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
  30. 'constant'),
  31. axis=1)
  32. clusterings = np.array(list(nx.clustering(G).values()))
  33. clusterings = np.expand_dims(np.pad(clusterings,
  34. [0, max_num_nodes - G.number_of_nodes()],
  35. 'constant'),
  36. axis=1)
  37. self.feature_all.append(np.hstack([degs, clusterings]))
  38. def __len__(self):
  39. return len(self.adj_all)
  40. def __getitem__(self, idx):
  41. adj = self.adj_all[idx]
  42. if self.permutation_mode:
  43. x_idx = np.random.permutation(adj.shape[0])
  44. # print("*** count = " + str(self.count))
  45. # print(x_idx)
  46. self.count += 1
  47. adj = adj[np.ix_(x_idx, x_idx)]
  48. adj = np.asmatrix(adj)
  49. if self.bfs_mode:
  50. G = nx.from_numpy_matrix(adj)
  51. # then do bfs in the permuted G
  52. start_idx = np.random.randint(adj.shape[0])
  53. x_idx = np.array(bfs_seq(G, start_idx))
  54. adj = adj[np.ix_(x_idx, x_idx)]
  55. num_nodes = adj.shape[0]
  56. adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
  57. adj_padded[:num_nodes, :num_nodes] = adj
  58. adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2)
  59. node_idx = 0
  60. adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes,self.max_num_nodes)) ) == 1]
  61. # the following 2 lines recover the upper triangle of the adj matrix
  62. #recovered = np.zeros((self.max_num_nodes, self.max_num_nodes))
  63. #recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes)) ) == 1] = adj_vectorized
  64. #print(recovered)
  65. return {'adj':adj_padded,
  66. 'adj_decoded':adj_vectorized,
  67. 'features':self.feature_all[idx].copy()}