You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

graphvae_data.py 2.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import networkx as nx
  2. import numpy as np
  3. import torch
  4. class GraphAdjSampler(torch.utils.data.Dataset):
  5. def __init__(self, G_list, max_num_nodes, features='id'):
  6. self.max_num_nodes = max_num_nodes
  7. self.adj_all = []
  8. self.len_all = []
  9. self.feature_all = []
  10. for G in G_list:
  11. adj = nx.to_numpy_matrix(G)
  12. # the diagonal entries are 1 since they denote node probability
  13. self.adj_all.append(
  14. np.asarray(adj) + np.identity(G.number_of_nodes()))
  15. self.len_all.append(G.number_of_nodes())
  16. if features == 'id':
  17. self.feature_all.append(np.identity(max_num_nodes))
  18. elif features == 'deg':
  19. degs = np.sum(np.array(adj), 1)
  20. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], 0),
  21. axis=1)
  22. self.feature_all.append(degs)
  23. elif features == 'struct':
  24. degs = np.sum(np.array(adj), 1)
  25. degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
  26. 'constant'),
  27. axis=1)
  28. clusterings = np.array(list(nx.clustering(G).values()))
  29. clusterings = np.expand_dims(np.pad(clusterings,
  30. [0, max_num_nodes - G.number_of_nodes()],
  31. 'constant'),
  32. axis=1)
  33. self.feature_all.append(np.hstack([degs, clusterings]))
  34. def __len__(self):
  35. return len(self.adj_all)
  36. def __getitem__(self, idx):
  37. adj = self.adj_all[idx]
  38. num_nodes = adj.shape[0]
  39. adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
  40. adj_padded[:num_nodes, :num_nodes] = adj
  41. adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2)
  42. node_idx = 0
  43. adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes,self.max_num_nodes)) ) == 1]
  44. # the following 2 lines recover the upper triangle of the adj matrix
  45. #recovered = np.zeros((self.max_num_nodes, self.max_num_nodes))
  46. #recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes)) ) == 1] = adj_vectorized
  47. #print(recovered)
  48. return {'adj':adj_padded,
  49. 'adj_decoded':adj_vectorized,
  50. 'features':self.feature_all[idx].copy()}