|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import networkx as nx
- import numpy as np
- import torch
-
class GraphAdjSampler(torch.utils.data.Dataset):
    """Dataset of zero-padded adjacency matrices and node features for a list of graphs.

    Each item is a dict with:
      * ``'adj'``: the graph's adjacency matrix, with its diagonal set to 1,
        zero-padded to shape ``(max_num_nodes, max_num_nodes)``;
      * ``'adj_decoded'``: the upper triangle (diagonal included) of ``'adj'``
        flattened in row-major order, length
        ``max_num_nodes * (max_num_nodes + 1) // 2``;
      * ``'features'``: a copy of the graph's node feature matrix.

    Args:
        G_list: iterable of networkx graphs.
        max_num_nodes: size every graph is padded to; no graph may exceed it.
        features: node feature type:
            ``'id'``     -> identity matrix, shape ``(max_num_nodes, max_num_nodes)``;
            ``'deg'``    -> adjacency row sums, shape ``(max_num_nodes, 1)``;
            ``'struct'`` -> row sums and clustering coefficients,
                            shape ``(max_num_nodes, 2)``.

    Raises:
        ValueError: if ``features`` is not one of the values above, or a graph
            has more than ``max_num_nodes`` nodes.
    """

    _FEATURE_TYPES = ('id', 'deg', 'struct')

    def __init__(self, G_list, max_num_nodes, features='id'):
        if features not in self._FEATURE_TYPES:
            # Fail fast: otherwise no features are stored and every later
            # __getitem__ fails with an IndexError.
            raise ValueError('features must be one of {}, got {!r}'.format(
                self._FEATURE_TYPES, features))

        self.max_num_nodes = max_num_nodes
        self.adj_all = []
        self.len_all = []
        self.feature_all = []

        for G in G_list:
            num_nodes = G.number_of_nodes()
            if num_nodes > max_num_nodes:
                raise ValueError('graph has {} nodes, more than max_num_nodes={}'.format(
                    num_nodes, max_num_nodes))

            # to_numpy_array replaces to_numpy_matrix (removed in networkx 3.0).
            adj = nx.to_numpy_array(G)

            # The diagonal entries are 1 since they denote node probability.
            # Filling (rather than adding the identity) keeps them exactly 1
            # even when the graph has self-loops.
            adj_with_diag = adj.copy()
            np.fill_diagonal(adj_with_diag, 1.0)
            self.adj_all.append(adj_with_diag)
            self.len_all.append(num_nodes)

            if features == 'id':
                self.feature_all.append(np.identity(max_num_nodes))
                continue

            # Row sums of the adjacency without the added diagonal.
            degs = self._pad_column(np.sum(adj, axis=1), max_num_nodes)
            if features == 'deg':
                self.feature_all.append(degs)
            else:  # 'struct'
                clusterings = np.array(list(nx.clustering(G).values()))
                self.feature_all.append(
                    np.hstack([degs, self._pad_column(clusterings, max_num_nodes)]))

    @staticmethod
    def _pad_column(values, length):
        """Zero-pad 1-D ``values`` at the end to ``length``; return shape ``(length, 1)``."""
        padded = np.pad(values, (0, length - len(values)), mode='constant')
        return padded[:, np.newaxis]

    def __len__(self):
        return len(self.adj_all)

    def __getitem__(self, idx):
        adj = self.adj_all[idx]
        num_nodes = adj.shape[0]
        adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_padded[:num_nodes, :num_nodes] = adj

        # Upper triangle including the diagonal, in row-major order. Assigning
        # this vector back at np.triu_indices(max_num_nodes) of a zero matrix
        # recovers the upper triangle of adj_padded.
        adj_vectorized = adj_padded[np.triu_indices(self.max_num_nodes)]

        return {'adj': adj_padded,
                'adj_decoded': adj_vectorized,
                'features': self.feature_all[idx].copy()}
-
|