123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- import numpy as np
- import scipy.optimize
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- from torch import optim
- import torch.nn.functional as F
- import torch.nn.init as init
- import networkx as nx
- import matplotlib.pyplot as plt
- import model
- import baselines.graphvae.diffpoolClassesAndFunctions as diffpool
- from model import sample_sigmoid
- from baselines.graphvae import args
- from random import random
def graph_show(G, title, colors):
    """Draw ``G`` with a spring layout, save it to ``foo.png`` and show it.

    Args:
        G: graph passed to ``nx.spring_layout`` and ``nx.draw``.
        title: text used as the figure window title.
        colors: forwarded to ``nx.draw`` as ``node_color``.
    """
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, node_color=colors)
    fig = plt.gcf()
    # The window title lives on the figure manager; the canvas-level
    # ``set_window_title`` shortcut is no longer available in recent
    # Matplotlib releases.  Some canvases have no manager, so guard for it.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(title)
    # Save *before* showing: with a blocking backend the figure is torn down
    # once the window is closed and a later savefig would write a blank image.
    plt.savefig('foo.png')
    plt.show()
class GraphConv(nn.Module):
    """Dense graph-convolution layer: ``y = (adj @ x [+ x]) @ W [+ b]``.

    Args:
        input_dim: size of the last dimension of the input features.
        output_dim: size of the last dimension of the output.
        add_self: when true, the (dropped-out) input is added to the
            aggregated neighbourhood before the linear map.
        normalize_embedding: when true, the output is L2-normalised along
            dimension 2 (so the layer then expects 3-D input).
        dropout: dropout probability applied to the input; values not above
            0.001 disable dropout entirely.
        bias: whether to add a learnable bias vector.
    """

    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True):
        super(GraphConv, self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Parameters are still created on the GPU when one is present (the
        # historical behaviour), but fall back to the CPU instead of crashing.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # The tensors used to be left as uninitialised memory; give them a
        # well-defined start (Xavier-uniform with ReLU gain, zero bias).
        self.weight = nn.Parameter(
            torch.empty(input_dim, output_dim, dtype=torch.float32, device=device))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(output_dim, dtype=torch.float32, device=device))
        else:
            self.bias = None

    def forward(self, x, adj):
        """Apply the layer.

        Args:
            x: node features; multiplied on the left by ``adj``.
            adj: adjacency tensor compatible with ``torch.matmul(adj, x)``.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        if self.dropout > 0.001:
            x = self.dropout_layer(x)
        y = torch.matmul(adj, x)
        if self.add_self:
            y = y + x
        y = torch.matmul(y, self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            y = F.normalize(y, p=2, dim=2)
        return y
class InnerProductDecoder(nn.Module):
    """Inner-product decoder: scores every node pair as ``act(z_i . z_j)``."""

    def __init__(self, dropout, act=torch.sigmoid):
        super(InnerProductDecoder, self).__init__()
        self.act = act
        self.dropout = dropout

    def forward(self, z):
        """Return one square score matrix per graph in the batch.

        ``z`` is indexed as ``(batch, nodes, ...)``; each ``z[i]`` is fed to
        ``torch.mm`` with its own transpose.  The result buffer is allocated
        with ``torch.zeros`` using default placement and dtype.
        """
        n_graphs, n_nodes = z.size(0), z.size(1)
        embeddings = F.dropout(z, self.dropout, training=self.training)
        scores = torch.zeros(n_graphs, n_nodes, n_nodes)
        for idx, emb in enumerate(embeddings):
            scores[idx] = self.act(torch.mm(emb, emb.t()))
        return scores
class GraphVAE(nn.Module):
    """GCN-encoder / MLP-VAE model used for graph reconstruction and completion.

    ``forward`` hides part of the input adjacency (which part depends on the
    flags in ``vae_args``), encodes the remaining graph and returns a scalar
    loss for predicting the hidden part.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, vae_args, pool='sum'):
        '''
        Args:
            input_dim: input feature dimension for node.
            hidden_dim: hidden dim for 2-layer gcn.
            latent_dim: dimension of the latent representation of graph.
                Accepted for interface compatibility; not read by this
                constructor.
            max_num_nodes: number of node slots in every (padded) adjacency.
            vae_args: option object.  The constructor reads
                ``completion_mode_small_parameter_size``,
                ``number_of_missing_nodes``, ``GRAN_linear`` and
                ``graph_pooling_mode``; ``forward`` reads further flags.
            pool: ``'sum'`` or ``'max'``; node pooling used by ``pool_graph``.
        '''
        super(GraphVAE, self).__init__()
        self.hidden_dim = hidden_dim
        self.vae_args = vae_args

        # NOTE: sub-modules are created in the historical order so parameter
        # names and the random-initialisation sequence stay unchanged.
        self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=32)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.conv2 = model.GraphConv(input_dim=32, output_dim=hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()
        self.linear = nn.Linear(input_dim * hidden_dim, 128)
        self.dropout = 0
        self.dc = InnerProductDecoder(self.dropout, act=lambda x: x)

        self.embedding_dim = 5
        self.label_dim = 2
        self.num_layers = 3
        self.DiffpoolGcnEncoderGraph = self._to_gpu_if_available(
            diffpool.GcnEncoderGraph(input_dim, 5, self.embedding_dim, self.label_dim,
                                     self.num_layers))

        if vae_args.completion_mode_small_parameter_size:
            # Only the columns of the missing nodes are predicted.
            output_dim = vae_args.number_of_missing_nodes * max_num_nodes - \
                (vae_args.number_of_missing_nodes * vae_args.number_of_missing_nodes -
                 (vae_args.number_of_missing_nodes * (vae_args.number_of_missing_nodes + 1) // 2))
            self.number_of_incomplete_nodes = max_num_nodes - vae_args.number_of_missing_nodes
            self.output_dim = output_dim
        elif vae_args.GRAN_linear:
            output_dim = max_num_nodes - 1
        else:
            # Upper triangle (diagonal included) of the full adjacency.
            output_dim = max_num_nodes * (max_num_nodes + 1) // 2

        if self.vae_args.graph_pooling_mode:
            self.vae = model.MLP_VAE_plain(hidden_dim, hidden_dim, output_dim)
        else:
            self.vae = model.MLP_VAE_plain(input_dim * hidden_dim, hidden_dim, output_dim)

        self.max_num_nodes = max_num_nodes
        self.gran = model.GRAN(self.max_num_nodes)

        for m in self.modules():
            if isinstance(m, model.GraphConv):
                # In-place initialiser; the non-underscore name is deprecated.
                m.weight.data = init.xavier_uniform_(
                    m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.pool = pool

        # Extra conv stack; only referenced by code that is currently disabled
        # in ``forward``, but kept so saved parameter sets keep their layout.
        self.bn = True
        self.num_layers = 3
        self.num_aggs = 1
        self.bias = True
        in_dim = input_dim
        hidden_dim1 = 20
        embedding_dim1 = 20
        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            in_dim, hidden_dim1, embedding_dim1, self.num_layers,
            False, normalize=True, dropout=0.0)

    # ------------------------------------------------------------------ devices
    @staticmethod
    def _to_gpu_if_available(obj):
        """Return ``obj.cuda()`` when CUDA is available, otherwise ``obj``."""
        return obj.cuda() if torch.cuda.is_available() else obj

    @staticmethod
    def _adj_device():
        """Device used for the masked adjacency: ``cuda:0`` if present, else CPU."""
        return 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # ------------------------------------------------------ adjacency utilities
    def recover_adj_lower(self, l):
        """Scatter flat triangle entries back into square matrices.

        Args:
            l: tensor of shape ``(batch, max_num_nodes * (max_num_nodes + 1) / 2)``.

        Returns:
            ``(batch, max_num_nodes, max_num_nodes)`` tensor whose positions
            selected by ``torch.triu`` (row-major order) hold ``l``; every
            other entry is zero.
        """
        batch_size = l.size()[0]
        adj = torch.zeros(batch_size, self.max_num_nodes, self.max_num_nodes)
        adj[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
        return adj

    def recover_full_adj_from_lower(self, lower):
        """Symmetrise a batch of triangular matrices.

        Computes ``lower + lower^T - diag(lower)`` for every matrix in the
        batch, so the diagonal is counted once.
        """
        diag = torch.diag_embed(torch.diagonal(lower, dim1=1, dim2=2))
        return lower + lower.transpose(1, 2) - diag

    def edge_similarity_matrix(self, adj, adj_recon, matching_features,
                               matching_features_recon, sim_func):
        """Build the 4-D affinity tensor ``S[i, j, a, b]`` used by ``mpm``.

        Entries ``S[i, i, a, a]`` are ``adj[i, i] * adj_recon[a, a]`` times
        ``sim_func`` of the two nodes' matching features.  For ``i != j`` and
        ``a != b`` the entry is the product of the two edge values and the
        four diagonal values involved.  All other entries stay zero.
        """
        S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
                        self.max_num_nodes, self.max_num_nodes)
        for i in range(self.max_num_nodes):
            for j in range(self.max_num_nodes):
                if i == j:
                    for a in range(self.max_num_nodes):
                        S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
                            sim_func(matching_features[i], matching_features_recon[a])
                        # with feature not implemented
                else:
                    for a in range(self.max_num_nodes):
                        for b in range(self.max_num_nodes):
                            if b == a:
                                continue
                            S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
                                adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
        return S

    def mpm(self, x_init, S, max_iters=50):
        """Iteratively refine a soft node assignment from the affinity tensor.

        Each iteration sets
        ``x_new[i, a] = x[i, a] * S[i, i, a, a] + sum_{j != i} max_b(x[j, b] * S[i, j, a, b])``
        and rescales ``x_new`` by its ``torch.norm``.

        Args:
            x_init: ``(max_num_nodes, max_num_nodes)`` starting assignment.
            S: tensor produced by ``edge_similarity_matrix``.
            max_iters: number of iterations to run.
        """
        x = x_init
        for _ in range(max_iters):
            x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            for i in range(self.max_num_nodes):
                for a in range(self.max_num_nodes):
                    x_new[i, a] = x[i, a] * S[i, i, a, a]
                    pooled = [torch.max(x[j, :] * S[i, j, a, :])
                              for j in range(self.max_num_nodes) if j != i]
                    neigh_sim = sum(pooled)
                    x_new[i, a] += neigh_sim
            norm = torch.norm(x_new)
            x = x_new / norm
        return x

    def deg_feature_similarity(self, f1, f2):
        """Similarity of two scalar features: ``1 / (|f1 - f2| + 1)``."""
        return 1 / (abs(f1 - f2) + 1)

    def permute_adj(self, adj, curr_ind, target_ind):
        ''' Permute adjacency matrix.
          The target_ind (connectivity) should be permuted to the curr_ind position.
        '''
        # order curr_ind according to target ind
        # (``np.int`` no longer exists in NumPy; use an explicit integer dtype.)
        ind = np.zeros(self.max_num_nodes, dtype=np.int64)
        ind[target_ind] = curr_ind
        order = torch.as_tensor(ind, dtype=torch.long)
        adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_permuted[:, :] = adj[order, :][:, order]
        return adj_permuted

    def pool_graph(self, x):
        """Pool over dimension 1 of ``x`` using ``self.pool`` (``'max'`` or ``'sum'``).

        Raises:
            ValueError: if ``self.pool`` is neither ``'max'`` nor ``'sum'``.
        """
        if self.pool == 'max':
            out, _ = torch.max(x, dim=1, keepdim=False)
        elif self.pool == 'sum':
            out = torch.sum(x, dim=1, keepdim=False)
        else:
            raise ValueError("unsupported pool mode: %r (expected 'max' or 'sum')" % (self.pool,))
        return out

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x
        '''
        # A fresh BatchNorm1d is built on every call, so it carries no learned
        # statistics; it is placed on the same device as ``x``.
        bn_module = nn.BatchNorm1d(x.size()[1]).to(x.device)
        return bn_module(x)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0):
        """Create a first / middle / last stack of ``GraphConv`` layers.

        Returns:
            ``(conv_first, conv_block, conv_last)`` where ``conv_block`` is an
            ``nn.ModuleList`` holding ``num_layers - 2`` hidden layers.  Only
            the middle layers receive the ``dropout`` argument.
        """
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias)
             for _ in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias)
        return conv_first, conv_block, conv_last

    # ----------------------------------------------------------- forward helpers
    def _mask_gran_target(self, numpy_adj, batch_num_nodes):
        """Hide one node per graph and return its connectivity as the label.

        ``numpy_adj`` is modified in place: the row and column of the chosen
        node are overwritten with ones (over the first ``num_nodes`` entries).

        Returns:
            ``(labels, chosen)``: ``labels`` is a float tensor of shape
            ``(batch, max_num_nodes - 1)`` with the chosen node's original
            edges to the other nodes; ``chosen`` lists the randomly drawn node
            indices and is empty unless ``vae_args.GRAN_arbitrary_node`` is set
            (then the last real node is used).
        """
        batch_size = numpy_adj.shape[0]
        labels = np.zeros((batch_size, self.max_num_nodes - 1))
        chosen = []
        for i in range(batch_size):
            num_nodes = int(batch_num_nodes[i])
            if self.vae_args.GRAN_arbitrary_node:
                target = np.random.randint(num_nodes)
                chosen.append(target)
                # Edges to nodes before and after the target, with the target's
                # own entry skipped.
                labels[i, :target] = numpy_adj[i, :target, target]
                labels[i, target:num_nodes - 1] = numpy_adj[i, target + 1:num_nodes, target]
            else:
                target = num_nodes - 1
                labels[i, :num_nodes - 1] = numpy_adj[i, :num_nodes - 1, target]
            numpy_adj[i, :num_nodes, target] = 1
            numpy_adj[i, target, :num_nodes] = 1
        return torch.tensor(labels).float(), chosen

    def _remove_random_nodes(self, numpy_adj):
        """Zero the rows and columns of ``number_of_missing_nodes`` distinct random nodes.

        The same node indices are removed from every graph in the batch;
        ``numpy_adj`` is modified in place.

        Raises:
            ValueError: if more nodes are requested than ``max_num_nodes``
                (the rejection sampling below could never finish).
        """
        missing = self.vae_args.number_of_missing_nodes
        if missing > self.max_num_nodes:
            raise ValueError(
                "number_of_missing_nodes (%d) exceeds max_num_nodes (%d)"
                % (missing, self.max_num_nodes))
        index_list = []
        for _ in range(missing):
            random_index = np.random.randint(self.max_num_nodes)
            while random_index in index_list:
                random_index = np.random.randint(self.max_num_nodes)
            index_list.append(random_index)
            numpy_adj[:, :, random_index] = 0
            numpy_adj[:, random_index, :] = 0

    def _gran_pos_weight(self, adj):
        """Per-graph positive-class weight ``(n * n - edges) / edges``.

        ``n`` is the number of rows of one adjacency matrix and ``edges`` the
        sum of its entries.  Returns a CPU float tensor of shape ``(batch, 1)``.
        NOTE(review): a graph whose adjacency sums to zero yields a non-finite
        weight, exactly as before; confirm whether such inputs can occur.
        """
        adj_cpu = adj.detach().float().cpu()
        cells = adj_cpu.shape[1] * adj_cpu.shape[1]
        edges = adj_cpu.reshape(adj_cpu.shape[0], -1).sum(dim=1)
        return ((cells - edges) / edges).unsqueeze(1)

    def _match_and_permute(self, adj_data, out_tensor):
        """Align every ground-truth adjacency with its reconstruction.

        Runs ``mpm`` on degree-based affinities and a linear-sum assignment,
        then permutes each ground-truth matrix accordingly.
        """
        recon_adj_lower = self.recover_adj_lower(out_tensor)
        recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)
        # Matching features are the column sums of each adjacency.
        out_features = torch.sum(recon_adj_tensor, 1)
        adj_features = torch.sum(adj_data, 1)
        adj_permuted = torch.zeros(adj_data.size(0), adj_data.size(1), adj_data.size(2))
        for i in range(adj_data.size(0)):
            S = self.edge_similarity_matrix(adj_data[i].squeeze(), recon_adj_tensor[i].squeeze(),
                                            adj_features[i].squeeze(), out_features[i].squeeze(),
                                            self.deg_feature_similarity)
            # initialization strategies
            init_corr = 1 / self.max_num_nodes
            init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
            assignment = self.mpm(init_assignment, S)
            # use negative of the assignment score since the alg finds min cost flow
            row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
            # order row index according to col index
            adj_permuted[i] = self.permute_adj(adj_data[i].squeeze(), row_ind, col_ind)
        return adj_permuted

    def _vectorize_missing_columns(self, adj_permuted, size):
        """Flatten the columns of the missing nodes into a ``size``-shaped target.

        Walks the last ``number_of_missing_nodes`` columns from right to left
        and, within each, from row ``max_num_nodes - i - 1`` upwards to row 0.
        """
        adj_vectorized = torch.zeros(size)
        position = 0
        for i in range(self.vae_args.number_of_missing_nodes):  # iterates over columns
            for j in range(self.max_num_nodes - i):
                adj_vectorized[:, position] = adj_permuted[:, self.max_num_nodes - j - i - 1,
                                                           self.max_num_nodes - i - 1]
                position += 1
        return adj_vectorized

    # ------------------------------------------------------------------ forward
    def forward(self, input_features, adj, batch_num_nodes):
        """Mask part of ``adj``, encode the rest and return the training loss.

        Args:
            input_features: node features fed to the encoder.  In
                ``completion_mode_small_parameter_size`` they are sliced to the
                first ``number_of_incomplete_nodes`` entries on both trailing
                dimensions (presumably one feature per node slot; verify
                against the caller).
            adj: batch of adjacency matrices; never modified by this method.
            batch_num_nodes: per-graph node counts, read only in ``GRAN`` mode.

        Returns:
            Scalar loss tensor.  Depending on ``vae_args`` this is the GRAN
            edge-prediction loss, the inner-product reconstruction loss or the
            VAE adjacency reconstruction loss.

        Raises:
            ValueError: if ``GRAN`` is combined with
                ``completion_mode_small_parameter_size`` (no GRAN labels are
                produced in that mode), or from the helpers above.
        """
        device = self._adj_device()
        # Work on a private copy: for a CPU ``adj`` the NumPy view would share
        # memory, and the masking below would also corrupt the ground truth.
        numpy_adj = adj.detach().cpu().numpy().copy()

        gran_lable = None
        random_index_list = []
        if self.vae_args.completion_mode_small_parameter_size:
            kept = self.number_of_incomplete_nodes
            incomplete_adj = torch.tensor(numpy_adj[:, :kept, :kept], device=device)
            input_features = torch.tensor(input_features[:, :kept, :kept], device=device)
        elif self.vae_args.GRAN:
            gran_lable, random_index_list = self._mask_gran_target(numpy_adj, batch_num_nodes)
            incomplete_adj = torch.tensor(numpy_adj, device=device)
        else:
            self._remove_random_nodes(numpy_adj)
            incomplete_adj = torch.tensor(numpy_adj, device=device)

        # --------------------------------------------------------- embedding
        if self.vae_args.diffpoolGCN:
            x = self.DiffpoolGcnEncoderGraph.gcn_forward(input_features, incomplete_adj,
                                                         self.DiffpoolGcnEncoderGraph.conv_first,
                                                         self.DiffpoolGcnEncoderGraph.conv_block,
                                                         self.DiffpoolGcnEncoderGraph.conv_last)
        else:
            x = self.conv1(input_features, incomplete_adj)
            x = self.act(x)
            x = self.bn1(x)
            x = self.conv2(x, incomplete_adj)

        # --------------------------------------------------------- GRAN loss
        if self.vae_args.GRAN:
            if gran_lable is None:
                raise ValueError(
                    "GRAN loss requested but no GRAN labels were built; "
                    "GRAN cannot be combined with completion_mode_small_parameter_size")
            if self.vae_args.GRAN_linear:
                x = x.view(-1, self.max_num_nodes * self.hidden_dim)
                h_decode, z_mu, z_lsgms = self.vae(x)
                logits = h_decode.cpu()
            else:
                logits = self.gran(x, batch_num_nodes, random_index_list,
                                   self.vae_args.GRAN_arbitrary_node).squeeze().cpu()
            pos_weight = self._gran_pos_weight(adj)
            return F.binary_cross_entropy_with_logits(logits, gran_lable,
                                                      pos_weight=pos_weight)

        # ------------------------------------------- inner-product decoder
        if self.vae_args.reconstruction:
            loss = F.binary_cross_entropy_with_logits(self.dc(x), incomplete_adj.cpu())
            return loss

        # --------------------------------------------------------------- vae
        if self.vae_args.completion_mode_small_parameter_size:
            x = x.view(-1, self.number_of_incomplete_nodes * self.hidden_dim)
        elif self.vae_args.graph_pooling_mode:
            x = self.pool_graph(x)
        else:
            x = x.view(-1, self.max_num_nodes * self.hidden_dim)

        h_decode, z_mu, z_lsgms = self.vae(x)
        out = torch.sigmoid(h_decode)

        adj_data = adj.detach().cpu()
        if self.vae_args.graph_matching_mode:
            adj_permuted = self._match_and_permute(adj_data, out.detach().cpu())
        else:
            adj_permuted = adj_data

        if self.vae_args.completion_mode_small_parameter_size:
            adj_vectorized = self._vectorize_missing_columns(adj_permuted, out.size())
        else:
            adj_vectorized = adj_permuted[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1]
        # Put the target wherever the prediction lives.
        adj_vectorized_var = adj_vectorized.to(out.device)
        adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out)

        # The KL term is computed and printed for monitoring only; it is not
        # part of the returned loss.
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= self.max_num_nodes * self.max_num_nodes  # normalize
        print('kl: ', loss_kl)

        loss = adj_recon_loss
        return loss

    def forward_test(self, input_features, adj):
        """Run the matching pipeline on a fixed pair of 4-node graphs.

        Both arguments are ignored.  ``max_num_nodes`` is set to 4 for the
        duration of the call and restored afterwards, so the model stays
        usable.

        Returns:
            The binary-cross-entropy between the permuted first graph and the
            second graph (previously computed but discarded).
        """
        saved_max_num_nodes = self.max_num_nodes
        self.max_num_nodes = 4
        try:
            adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            adj_data[:4, :4] = torch.FloatTensor(
                [[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
            adj_features = torch.Tensor([2, 3, 3, 2])

            adj_data1 = torch.FloatTensor(
                [[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
            adj_features1 = torch.Tensor([3, 3, 2, 2])

            S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
                                            self.deg_feature_similarity)

            # initialization strategies
            init_corr = 1 / self.max_num_nodes
            init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
            assignment = self.mpm(init_assignment, S)

            # matching
            row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
            permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
            return self.adj_recon_loss(permuted_adj, adj_data1)
        finally:
            self.max_num_nodes = saved_max_num_nodes

    def adj_recon_loss(self, adj_truth, adj_pred):
        """Binary cross-entropy of ``adj_pred`` (probabilities) against ``adj_truth``."""
        return F.binary_cross_entropy(adj_pred, adj_truth)