import networkx as nx
import numpy as np
import torch

from data import bfs_seq


def _graph_to_adj(G):
    """Return the dense adjacency matrix of ``G`` as a plain 2-D ndarray."""
    # ``to_numpy_matrix`` was removed in networkx 3.0; prefer the array API
    # and only fall back to the matrix API on old releases that lack it.
    if hasattr(nx, 'to_numpy_array'):
        return np.asarray(nx.to_numpy_array(G))
    return np.asarray(nx.to_numpy_matrix(G))


def _adj_to_graph(adj):
    """Build a networkx graph from a dense adjacency array."""
    # Same version split as in ``_graph_to_adj``.
    if hasattr(nx, 'from_numpy_array'):
        return nx.from_numpy_array(np.asarray(adj))
    return nx.from_numpy_matrix(np.asmatrix(adj))


class GraphAdjSampler(torch.utils.data.Dataset):
    """Dataset serving zero-padded adjacency matrices and node features.

    Every graph in ``G_list`` is converted once, at construction time, to a
    dense adjacency matrix with ones added on the diagonal.  ``__getitem__``
    optionally re-orders the nodes (random permutation and/or BFS ordering)
    and pads the result to ``max_num_nodes`` x ``max_num_nodes``.

    Args:
        G_list: iterable of networkx graphs.
        max_num_nodes: size every adjacency matrix / feature matrix is padded
            to.
        permutation_mode: if truthy, apply a fresh random node permutation on
            every ``__getitem__`` call.
        bfs_mode: if truthy, re-order nodes by ``bfs_seq`` starting from a
            random node on every ``__getitem__`` call.
        features: node-feature scheme, one of ``'id'`` (identity matrix),
            ``'deg'`` (row sums of the adjacency matrix) or ``'struct'``
            (row sums and clustering coefficients).

    Raises:
        ValueError: if ``features`` is not one of the supported schemes.
    """

    def __init__(self, G_list, max_num_nodes, permutation_mode, bfs_mode, features='id'):
        # Fail early: an unknown scheme would otherwise leave ``feature_all``
        # empty and only surface as an IndexError inside ``__getitem__``.
        if features not in ('id', 'deg', 'struct'):
            raise ValueError(
                "features must be one of 'id', 'deg', 'struct'; got %r" % (features,))

        self.max_num_nodes = max_num_nodes
        self.adj_all = []
        self.len_all = []
        self.feature_all = []
        self.count = 0  # number of random permutations drawn so far
        self.permutation_mode = permutation_mode
        self.bfs_mode = bfs_mode

        for G in G_list:
            num_nodes = G.number_of_nodes()
            adj = _graph_to_adj(G)
            # the diagonal entries are 1 since they denote node probability
            self.adj_all.append(adj + np.identity(num_nodes))
            self.len_all.append(num_nodes)

            # Per-node vectors are zero-padded at the end up to max_num_nodes.
            pad_width = [0, max_num_nodes - num_nodes]
            if features == 'id':
                self.feature_all.append(np.identity(max_num_nodes))
            elif features == 'deg':
                degs = np.sum(adj, axis=1)
                # The pad mode must be the string 'constant'; the previous
                # positional ``0`` is not a valid mode and made this branch
                # raise on every graph.
                degs = np.expand_dims(np.pad(degs, pad_width, 'constant'), axis=1)
                self.feature_all.append(degs)
            else:  # features == 'struct'
                degs = np.sum(adj, axis=1)
                degs = np.expand_dims(np.pad(degs, pad_width, 'constant'), axis=1)
                clusterings = np.array(list(nx.clustering(G).values()))
                clusterings = np.expand_dims(
                    np.pad(clusterings, pad_width, 'constant'), axis=1)
                self.feature_all.append(np.hstack([degs, clusterings]))

    def __len__(self):
        """Return the number of graphs held by the dataset."""
        return len(self.adj_all)

    def __getitem__(self, idx):
        """Return the (optionally re-ordered) padded sample for graph ``idx``.

        Returns:
            dict with keys
            ``'adj'``: ``(max_num_nodes, max_num_nodes)`` zero-padded
            adjacency matrix,
            ``'adj_decoded'``: 1-D array holding the upper triangle of
            ``'adj'`` (diagonal included) in row-major order,
            ``'features'``: a copy of the feature matrix built in
            ``__init__``.
        """
        # NOTE(review): the original indentation was lost; the two modes are
        # treated as independent sibling branches here -- confirm against the
        # intended behaviour.
        adj = self.adj_all[idx]

        if self.permutation_mode:
            perm = np.random.permutation(adj.shape[0])
            self.count += 1
            adj = adj[np.ix_(perm, perm)]

        if self.bfs_mode:
            G = _adj_to_graph(adj)
            # then do bfs in the (possibly permuted) G
            start_idx = np.random.randint(adj.shape[0])
            bfs_order = np.array(bfs_seq(G, start_idx))
            adj = adj[np.ix_(bfs_order, bfs_order)]

        num_nodes = adj.shape[0]
        adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_padded[:num_nodes, :num_nodes] = adj

        # Row-major upper triangle (diagonal included).  The inverse mapping
        # is ``m = np.zeros((N, N)); m[np.triu_indices(N)] = adj_vectorized``.
        adj_vectorized = adj_padded[np.triu_indices(self.max_num_nodes)]

        # NOTE(review): the features are returned in the original node order;
        # they are not re-indexed by the permutation / BFS order applied to
        # ``adj`` above.  Harmless for 'id', but worth confirming for 'deg'
        # and 'struct'.
        return {'adj': adj_padded,
                'adj_decoded': adj_vectorized,
                'features': self.feature_all[idx].copy()}