import numpy as np import scipy.optimize import torch import torch.nn as nn from torch.autograd import Variable from torch import optim import torch.nn.functional as F import torch.nn.init as init import networkx as nx import matplotlib.pyplot as plt import model import baselines.graphvae.diffpoolClassesAndFunctions as diffpool from model import sample_sigmoid from baselines.graphvae import args from random import random def graph_show(G, title, colors): pos = nx.spring_layout(G, scale=2) nx.draw(G, pos, node_color=colors) fig = plt.gcf() fig.canvas.set_window_title(title) plt.show() plt.savefig('foo.png') class GraphConv(nn.Module): def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False, dropout=0.0, bias=True): super(GraphConv, self).__init__() self.add_self = add_self self.dropout = dropout if dropout > 0.001: self.dropout_layer = nn.Dropout(p=dropout) self.normalize_embedding = normalize_embedding self.input_dim = input_dim self.output_dim = output_dim self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda()) if bias: self.bias = nn.Parameter(torch.FloatTensor(output_dim).cuda()) else: self.bias = None def forward(self, x, adj): if self.dropout > 0.001: x = self.dropout_layer(x) y = torch.matmul(adj, x) if self.add_self: y += x y = torch.matmul(y, self.weight) if self.bias is not None: y = y + self.bias if self.normalize_embedding: y = F.normalize(y, p=2, dim=2) # print(y[0][0]) return y class InnerProductDecoder(nn.Module): """Decoder for using inner product for prediction.""" def __init__(self, dropout, act=torch.sigmoid): super(InnerProductDecoder, self).__init__() self.dropout = dropout self.act = act def forward(self, z): batch_size = z.size()[0] adj_size = z.size()[1] z = F.dropout(z, self.dropout, training=self.training) adj = torch.zeros(batch_size, adj_size, adj_size) for i in range(batch_size): adj[i] = self.act(torch.mm(z[i], z[i].t())) return adj class GraphVAE(nn.Module): def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, vae_args, pool='sum'): ''' Args: input_dim: input feature dimension for node. hidden_dim: hidden dim for 2-layer gcn. latent_dim: dimension of the latent representation of graph. ''' self.hidden_dim = hidden_dim self.vae_args = vae_args super(GraphVAE, self).__init__() self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=32) self.bn1 = nn.BatchNorm1d(input_dim) self.conv2 = model.GraphConv(input_dim=32, output_dim=hidden_dim) self.bn2 = nn.BatchNorm1d(input_dim) self.act = nn.ReLU() self.linear = nn.Linear(input_dim * hidden_dim, 128) self.dropout = 0 self.dc = InnerProductDecoder(self.dropout, act=lambda x: x) self.embedding_dim = 5 self.label_dim = 2 self.num_layers = 3 self.DiffpoolGcnEncoderGraph = diffpool.GcnEncoderGraph(input_dim, 5, self.embedding_dim, self.label_dim, self.num_layers).cuda() if vae_args.completion_mode_small_parameter_size: output_dim = vae_args.number_of_missing_nodes * max_num_nodes - \ (vae_args.number_of_missing_nodes * vae_args.number_of_missing_nodes - (vae_args.number_of_missing_nodes * (vae_args.number_of_missing_nodes + 1) // 2)) self.number_of_incomplete_nodes = max_num_nodes - vae_args.number_of_missing_nodes self.output_dim = output_dim elif vae_args.GRAN_linear: output_dim = max_num_nodes - 1 else: output_dim = max_num_nodes * (max_num_nodes + 1) // 2 # self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim) if self.vae_args.graph_pooling_mode: self.vae = model.MLP_VAE_plain(hidden_dim, hidden_dim, output_dim) else: self.vae = model.MLP_VAE_plain(input_dim * hidden_dim, hidden_dim, output_dim) # self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim) self.max_num_nodes = max_num_nodes self.gran = model.GRAN(self.max_num_nodes) for m in self.modules(): if isinstance(m, model.GraphConv): m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu')) elif isinstance(m, nn.BatchNorm1d): m.weight.data.fill_(1) m.bias.data.zero_() self.pool = pool ############################################################# self.bn = True self.num_layers = 3 self.num_aggs = 1 self.bias = True in_dim = input_dim hidden_dim1 = 20 embedding_dim1 = 20 self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers( in_dim, hidden_dim1, embedding_dim1, self.num_layers, False, normalize=True, dropout=0.0) ############################################################# def recover_adj_lower(self, l): # NOTE: Assumes 1 per minibatch batch_size = l.size()[0] adj = torch.zeros(batch_size, self.max_num_nodes, self.max_num_nodes) adj[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l return adj def recover_full_adj_from_lower(self, lower): batch_size = lower.size()[0] diag = torch.zeros(batch_size, lower.size()[1], lower.size()[1]) transpose = torch.zeros(batch_size, lower.size()[1], lower.size()[1]) i = 0 for mat in lower: diag[i, :, :] = torch.diag(torch.diag(mat)) transpose[i, :, :] = torch.transpose(mat, 0, 1) i += 1 # diag = torch.diag(torch.diag(lower, 0)) return lower + transpose - diag def edge_similarity_matrix(self, adj, adj_recon, matching_features, matching_features_recon, sim_func): S = torch.zeros(self.max_num_nodes, self.max_num_nodes, self.max_num_nodes, self.max_num_nodes) for i in range(self.max_num_nodes): for j in range(self.max_num_nodes): if i == j: for a in range(self.max_num_nodes): S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \ sim_func(matching_features[i], matching_features_recon[a]) # with feature not implemented # if input_features is not None: else: for a in range(self.max_num_nodes): for b in range(self.max_num_nodes): if b == a: continue S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \ adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b] return S def mpm(self, x_init, S, max_iters=50): x = x_init for it in range(max_iters): x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes) for i in range(self.max_num_nodes): for a in range(self.max_num_nodes): x_new[i, a] = x[i, a] * S[i, i, a, a] pooled = [torch.max(x[j, :] * S[i, j, a, :]) for j in range(self.max_num_nodes) if j != i] neigh_sim = sum(pooled) x_new[i, a] += neigh_sim norm = torch.norm(x_new) x = x_new / norm return x def deg_feature_similarity(self, f1, f2): return 1 / (abs(f1 - f2) + 1) def permute_adj(self, adj, curr_ind, target_ind): ''' Permute adjacency matrix. The target_ind (connectivity) should be permuted to the curr_ind position. ''' # order curr_ind according to target ind ind = np.zeros(self.max_num_nodes, dtype=np.int) ind[target_ind] = curr_ind adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes)) adj_permuted[:, :] = adj[ind, :] adj_permuted[:, :] = adj_permuted[:, ind] return adj_permuted def pool_graph(self, x): if self.pool == 'max': out, _ = torch.max(x, dim=1, keepdim=False) elif self.pool == 'sum': out = torch.sum(x, dim=1, keepdim=False) return out def apply_bn(self, x): ''' Batch normalization of 3D tensor x ''' bn_module = nn.BatchNorm1d(x.size()[1]).cuda() return bn_module(x) def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self, normalize=False, dropout=0.0): conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self, normalize_embedding=normalize, bias=self.bias) conv_block = nn.ModuleList( [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self, normalize_embedding=normalize, dropout=dropout, bias=self.bias) for i in range(num_layers - 2)]) conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self, normalize_embedding=normalize, bias=self.bias) return conv_first, conv_block, conv_last def forward(self, input_features, adj, batch_num_nodes): # embedding_mask = self.DiffpoolGcnEncoderGraph.construct_mask(self.max_num_nodes, batch_num_nodes) numpy_adj = adj.cpu().numpy() # print("**** numpy_adj") # print(numpy_adj) if self.vae_args.completion_mode_small_parameter_size: incomplete_adj = torch.tensor( numpy_adj[:, :self.number_of_incomplete_nodes, :self.number_of_incomplete_nodes], device='cuda:0') input_features = torch.tensor( input_features[:, :self.number_of_incomplete_nodes, :self.number_of_incomplete_nodes], device='cuda:0') elif self.vae_args.GRAN: batch_size = numpy_adj.shape[0] gran_lable = np.zeros((batch_size, self.max_num_nodes - 1)) random_index_list = [] # print("*** numpy_adj before change") # print(numpy_adj) for i in range(batch_size): num_nodes = batch_num_nodes[i] if self.vae_args.GRAN_arbitrary_node: random_index_list.append(np.random.randint(num_nodes)) gran_lable[i, :random_index_list[i]] = \ numpy_adj[i, :random_index_list[i], random_index_list[i]] gran_lable[i, random_index_list[i]:num_nodes-1] = \ numpy_adj[i, random_index_list[i]+1:num_nodes, random_index_list[i]] numpy_adj[i, :num_nodes, random_index_list[i]] = 1 numpy_adj[i, random_index_list[i], :num_nodes] = 1 # print("*** random_index_list") # print(random_index_list) # print("*** gran_lable") # print(gran_lable) # print("*** numpy_adj") # print(numpy_adj) else: gran_lable[i, :num_nodes - 1] = numpy_adj[i, :num_nodes - 1, num_nodes - 1] numpy_adj[i, :num_nodes, num_nodes - 1] = 1 numpy_adj[i, num_nodes - 1, :num_nodes] = 1 gran_lable = torch.tensor(gran_lable).float() incomplete_adj = torch.tensor(numpy_adj, device='cuda:0') # print("*** in model : random_index_list") # print(random_index_list) # print("*** random_index_list") # print(random_index_list) # print("*** gran_lable") # print(gran_lable) # print("*** numpy_adj after change") # print(numpy_adj) else: index_list = [] for i in range(self.vae_args.number_of_missing_nodes): random_index = np.random.randint(self.max_num_nodes) while random_index in index_list: random_index = np.random.randint(self.max_num_nodes) index_list.append(random_index) numpy_adj[:, :, random_index] = 0 numpy_adj[:, random_index, :] = 0 incomplete_adj = torch.tensor(numpy_adj, device='cuda:0') if self.vae_args.diffpoolGCN: x = self.DiffpoolGcnEncoderGraph.gcn_forward(input_features, incomplete_adj, self.DiffpoolGcnEncoderGraph.conv_first, self.DiffpoolGcnEncoderGraph.conv_block, self.DiffpoolGcnEncoderGraph.conv_last) else: ################################################################### embedding x = self.conv1(input_features, incomplete_adj) x = self.act(x) x = self.bn1(x) x = self.conv2(x, incomplete_adj) # x = self.bn2(x) # x = self.act(x) if (self.vae_args.GRAN): # x = self.conv_first(input_features, incomplete_adj) # x = self.act(x) # if self.bn: # x = self.apply_bn(x) # out_all = [] # out, _ = torch.max(x, dim=1) # out_all.append(out) # for i in range(self.num_layers - 2): # x = self.conv_block[i](x, adj) # x = self.act(x) # if self.bn: # x = self.apply_bn(x) # out, _ = torch.max(x, dim=1) # out_all.append(out) # if self.num_aggs == 2: # out = torch.sum(x, dim=1) # out_all.append(out) # x = self.conv_last(x, adj) ################################################################### end of embedding if self.vae_args.GRAN_linear: x = x.view(-1, self.max_num_nodes * self.hidden_dim) h_decode, z_mu, z_lsgms = self.vae(x) pos_weight = torch.ones(batch_size) for i in range(batch_size): # / adj[i].shape[0] pos_weight[i] = torch.Tensor( [1 * float(adj[i].shape[0] * adj[i].shape[0] - adj[i].sum()) / adj[i].sum()]) pos_weight = pos_weight.unsqueeze(1) return F.binary_cross_entropy_with_logits(h_decode.cpu(), gran_lable, pos_weight=pos_weight) else: pos_weight = torch.ones(batch_size) for i in range(batch_size): # / adj[i].shape[0] pos_weight[i] = torch.Tensor( [float(adj[i].shape[0] * adj[i].shape[0] - adj[i].sum()) / adj[i].sum()]) pos_weight = pos_weight.unsqueeze(1) # print("************************") # print(pos_weight.size()) # print(self.gran(x, batch_num_nodes).squeeze().cpu().size()) # print(gran_lable.size()) # print("***** in model *****") # print("*** model result : ") # print(self.gran(x, batch_num_nodes, random_index_list, self.vae_args.GRAN_arbitrary_node).squeeze().cpu()) # print("***** gran_lable") # print(gran_lable) return F.binary_cross_entropy_with_logits( self.gran(x, batch_num_nodes, random_index_list, self.vae_args.GRAN_arbitrary_node).squeeze().cpu(), gran_lable, pos_weight=pos_weight) if self.vae_args.reconstruction: loss = F.binary_cross_entropy_with_logits(self.dc(x), incomplete_adj.cpu()) return loss if self.vae_args.completion_mode_small_parameter_size: x = x.view(-1, self.number_of_incomplete_nodes * self.hidden_dim) else: if self.vae_args.graph_pooling_mode: x = self.pool_graph(x) else: x = x.view(-1, self.max_num_nodes * self.hidden_dim) # vae h_decode, z_mu, z_lsgms = self.vae(x) out = F.sigmoid(h_decode) out_tensor = out.cpu().data if self.vae_args.graph_matching_mode: recon_adj_lower = self.recover_adj_lower(out_tensor) recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower) out_features = torch.sum(recon_adj_tensor, 1) adj_data = adj.cpu().data adj_features = torch.sum(adj_data, 1) batch_size = adj_data.size(0) adj_permuted = torch.zeros(adj_data.size(0), adj_data.size(1), adj_data.size(2)) for i in range(batch_size): S = self.edge_similarity_matrix(adj_data[i].squeeze(), recon_adj_tensor[i].squeeze(), adj_features[i].squeeze(), out_features[i].squeeze(), self.deg_feature_similarity) # initialization strategies init_corr = 1 / self.max_num_nodes init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr assignment = self.mpm(init_assignment, S) # matching # use negative of the assignment score since the alg finds min cost flow row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy()) # order row index according to col index adj_permuted[i] = self.permute_adj(adj_data[i].squeeze(), row_ind, col_ind) adj_permuted[i] = adj_permuted[i].unsqueeze(0) else: adj_data = adj.cpu().data adj_permuted = adj_data if self.vae_args.completion_mode_small_parameter_size: adj_vectorized = torch.zeros(out.size()) adj_vectorized_index = 0 for i in range(self.vae_args.number_of_missing_nodes): # iterates over columns for j in range(self.max_num_nodes - i): adj_vectorized[:, adj_vectorized_index] = adj_permuted[:, self.max_num_nodes - j - i - 1, self.max_num_nodes - i - 1] adj_vectorized_index += 1 else: adj_vectorized = adj_permuted[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] adj_vectorized_var = Variable(adj_vectorized).cuda() adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out) # x2 = h_decode[0].unsqueeze(0) # x3 = x2.unsqueeze(0) # sample = sample_sigmoid(x3, sample=True, sample_time=1) # y = torch.zeros(self.max_num_nodes, self.max_num_nodes) # y[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = sample.data.cpu() loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp()) loss_kl /= self.max_num_nodes * self.max_num_nodes # normalize print('kl: ', loss_kl) loss = adj_recon_loss return loss def forward_test(self, input_features, adj): self.max_num_nodes = 4 adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes) adj_data[:4, :4] = torch.FloatTensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]]) adj_features = torch.Tensor([2, 3, 3, 2]) adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes) adj_data1 = torch.FloatTensor([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) adj_features1 = torch.Tensor([3, 3, 2, 2]) S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1, self.deg_feature_similarity) # initialization strategies init_corr = 1 / self.max_num_nodes init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr # init_assignment = torch.FloatTensor(4, 4) # init.uniform(init_assignment) assignment = self.mpm(init_assignment, S) # print('Assignment: ', assignment) # matching row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy()) # print('row: ', row_ind) # print('col: ', col_ind) permuted_adj = self.permute_adj(adj_data, row_ind, col_ind) # print('permuted: ', permuted_adj) adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1) # print(adj_data1) # print('diff: ', adj_recon_loss) def adj_recon_loss(self, adj_truth, adj_pred): return F.binary_cross_entropy(adj_pred, adj_truth)