| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 | import numpy as np
import scipy.optimize
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init
import networkx as nx
import matplotlib.pyplot as plt
import model
import baselines.graphvae.diffpoolClassesAndFunctions as diffpool
from model import sample_sigmoid
from baselines.graphvae import args
from random import random
def graph_show(G, title, colors, filename='foo.png'):
    """Draw ``G`` with a spring layout, save the figure to disk, then display it.

    Args:
        G: networkx graph to draw.
        title: window title for the figure (applied when the canvas has a
            figure manager).
        colors: passed to ``nx.draw`` as ``node_color``.
        filename: path the figure is saved to; defaults to ``'foo.png'``,
            the name that was previously hard-coded.
    """
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, node_color=colors)
    fig = plt.gcf()
    # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
    # title is set through the figure manager, which may be absent for
    # figures not created through pyplot.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(title)
    # Save *before* show(): once an interactive show() returns, the figure has
    # typically been closed and savefig would write a new, blank figure.
    fig.savefig(filename)
    plt.show()
class GraphConv(nn.Module):
    """Dense graph convolution: ``y = (adj @ x [+ x]) @ W [+ b]``.

    Args:
        input_dim: size of the last dimension of the node features ``x``.
        output_dim: size of the last dimension of the output.
        add_self: if True, add each node's own features to the aggregated
            ``adj @ x`` before the linear map.
        normalize_embedding: if True, L2-normalise the output along dim 2
            (so the input must be at least 3-D, e.g. (batch, nodes, features)).
        dropout: dropout probability applied to ``x``; values <= 0.001 disable
            dropout entirely.
        bias: whether to learn an additive bias.
        device: device the parameters are allocated on. Defaults to ``'cuda'``,
            matching the previous hard-coded ``.cuda()`` placement.
    """

    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True, device='cuda'):
        super(GraphConv, self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim
        # torch.empty / torch.FloatTensor(shape) return uninitialised memory,
        # so the parameters must be filled explicitly. Xavier-uniform with the
        # ReLU gain matches how GraphConv weights are initialised elsewhere.
        self.weight = nn.Parameter(
            torch.empty(input_dim, output_dim, dtype=torch.float32, device=device))
        init.xavier_uniform_(self.weight, gain=init.calculate_gain('relu'))
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim, dtype=torch.float32, device=device))
        else:
            self.bias = None

    def forward(self, x, adj):
        """Apply the convolution.

        Args:
            x: node features whose last dimension is ``input_dim``
                (presumably (batch, nodes, input_dim) — required to be >= 3-D
                when ``normalize_embedding`` is set).
            adj: adjacency matrices compatible with ``torch.matmul(adj, x)``.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        if self.dropout > 0.001:
            x = self.dropout_layer(x)
        y = torch.matmul(adj, x)
        if self.add_self:
            y += x
        y = torch.matmul(y, self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            y = F.normalize(y, p=2, dim=2)
        return y
class InnerProductDecoder(nn.Module):
    """Decode node embeddings into dense score matrices via pairwise inner products.

    For each graph ``g`` of the batch the output slice is ``act(z[g] @ z[g].T)``.
    """

    def __init__(self, dropout, act=torch.sigmoid):
        super(InnerProductDecoder, self).__init__()
        # Dropout probability applied to the embeddings (active only in training mode).
        self.dropout = dropout
        # Elementwise activation applied to every inner-product matrix.
        self.act = act

    def forward(self, z):
        """Return the activated inner-product matrix of every graph in the batch.

        Args:
            z: batched node embeddings; ``z[g]`` must be a 2-D (nodes, dim)
                matrix for ``torch.mm``.

        Returns:
            A ``(batch, nodes, nodes)`` tensor freshly allocated with
            ``torch.zeros`` (default dtype/device, not necessarily those of
            ``z``) and filled slice by slice.
        """
        n_graphs, n_nodes = z.size()[:2]
        z = F.dropout(z, self.dropout, training=self.training)
        scores = torch.zeros(n_graphs, n_nodes, n_nodes)
        for g, emb in enumerate(z):
            scores[g] = self.act(torch.mm(emb, emb.t()))
        return scores
class GraphVAE(nn.Module):
    """GCN encoder followed by an MLP VAE that decodes flattened adjacency entries.

    Which loss ``forward`` computes is selected by boolean flags on
    ``vae_args`` (graph completion, GRAN-style single-node prediction,
    inner-product reconstruction, optional graph matching). Several
    sub-modules and intermediate tensors are placed on CUDA explicitly
    (``.cuda()`` / ``device='cuda:0'``), so the model needs a GPU.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, vae_args, pool='sum'):
        '''
        Args:
            input_dim: input feature dimension for node. NOTE(review): it is
                also the channel size of ``bn1``/``bn2`` (which normalise the
                node axis of a (batch, nodes, features) tensor) and appears in
                ``input_dim * hidden_dim`` for the flattened embedding, which
                suggests node features are ``max_num_nodes``-dimensional —
                confirm.
            hidden_dim: hidden dim for 2-layer gcn.
            latent_dim: dimension of the latent representation of graph
                (not used by this constructor).
            max_num_nodes: padded number of nodes per graph.
            vae_args: configuration object; its flags select the model variant
                here and in ``forward``.
            pool: graph pooling used by ``pool_graph``: ``'sum'`` or ``'max'``.
        '''
        # nn.Module.__init__ must run before any attribute is set on the module.
        super(GraphVAE, self).__init__()
        self.hidden_dim = hidden_dim
        self.vae_args = vae_args
        self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=32)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.conv2 = model.GraphConv(input_dim=32, output_dim=hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()
        self.linear = nn.Linear(input_dim * hidden_dim, 128)
        self.dropout = 0
        # Identity activation: the decoder yields raw logits, which forward()
        # feeds to binary_cross_entropy_with_logits in the reconstruction branch.
        self.dc = InnerProductDecoder(self.dropout, act=lambda x: x)
        self.embedding_dim = 5
        self.label_dim = 2
        self.num_layers = 3
        self.DiffpoolGcnEncoderGraph = diffpool.GcnEncoderGraph(input_dim, 5, self.embedding_dim, self.label_dim,
                                                                self.num_layers).cuda()
        if vae_args.completion_mode_small_parameter_size:
            n_missing = vae_args.number_of_missing_nodes
            # Entries on or above the diagonal in the last n_missing columns
            # (= n_missing * max_num_nodes - n_missing * (n_missing - 1) / 2);
            # forward() builds its target vector over exactly these entries.
            output_dim = n_missing * max_num_nodes - \
                         (n_missing * n_missing - (n_missing * (n_missing + 1) // 2))
            self.number_of_incomplete_nodes = max_num_nodes - n_missing
            self.output_dim = output_dim
        elif vae_args.GRAN_linear:
            # One logit per potential edge of the predicted node.
            output_dim = max_num_nodes - 1
        else:
            # Size of the upper triangle including the diagonal.
            output_dim = max_num_nodes * (max_num_nodes + 1) // 2
        if self.vae_args.graph_pooling_mode:
            self.vae = model.MLP_VAE_plain(hidden_dim, hidden_dim, output_dim)
        else:
            self.vae = model.MLP_VAE_plain(input_dim * hidden_dim, hidden_dim, output_dim)
        self.max_num_nodes = max_num_nodes
        self.gran = model.GRAN(self.max_num_nodes)
        for m in self.modules():
            if isinstance(m, model.GraphConv):
                init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        self.pool = pool
        # Secondary stack built from this file's GraphConv; forward() does not
        # use it, but its parameters are registered on the module.
        self.bn = True
        self.num_aggs = 1
        self.bias = True  # read by build_conv_layers
        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            input_dim, 20, 20, self.num_layers,
            False, normalize=True, dropout=0.0)

    def recover_adj_lower(self, l):
        """Scatter flattened triangle vectors back into square matrices.

        Despite the name, the values of ``l`` fill the *upper* triangle
        (``torch.triu`` mask, diagonal included) in row-major order; the
        remaining entries are zero.

        Args:
            l: (batch, max_num_nodes * (max_num_nodes + 1) / 2) tensor.

        Returns:
            (batch, max_num_nodes, max_num_nodes) tensor allocated with
            ``torch.zeros`` on the default device.
        """
        batch_size = l.size()[0]
        adj = torch.zeros(batch_size, self.max_num_nodes, self.max_num_nodes)
        adj[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
        return adj

    def recover_full_adj_from_lower(self, lower):
        """Symmetrise a batch of triangular matrices: ``T + T^T - diag(T)``.

        Args:
            lower: (batch, n, n) tensor with one triangle (plus diagonal) filled.

        Returns:
            (batch, n, n) symmetric tensor; the diagonal is counted once.
        """
        diag = torch.diag_embed(torch.diagonal(lower, dim1=1, dim2=2))
        return lower + lower.transpose(1, 2) - diag

    def edge_similarity_matrix(self, adj, adj_recon, matching_features,
                               matching_features_recon, sim_func):
        """Build the 4-D affinity tensor ``S[i, j, a, b]`` between two graphs.

        * ``S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * sim_func(matching_features[i], matching_features_recon[a])``
        * for ``i != j`` and ``a != b``:
          ``S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]``
        * every other entry stays 0.

        Uses nested Python loops, so the cost grows as ``max_num_nodes ** 4``.
        """
        S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
                        self.max_num_nodes, self.max_num_nodes)
        for i in range(self.max_num_nodes):
            for j in range(self.max_num_nodes):
                if i == j:
                    for a in range(self.max_num_nodes):
                        S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
                                        sim_func(matching_features[i], matching_features_recon[a])
                else:
                    for a in range(self.max_num_nodes):
                        for b in range(self.max_num_nodes):
                            if b == a:
                                continue
                            S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
                                            adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
        return S

    def mpm(self, x_init, S, max_iters=50):
        """Iteratively refine a soft node-assignment matrix using affinities ``S``.

        Each of the ``max_iters`` rounds computes
        ``x_new[i, a] = x[i, a] * S[i, i, a, a] + sum_{j != i} max_b x[j, b] * S[i, j, a, b]``
        and rescales ``x_new`` to unit Frobenius norm. NOTE(review): a
        zero-norm ``x_new`` would make the division produce NaNs.

        Returns:
            (max_num_nodes, max_num_nodes) assignment scores.
        """
        x = x_init
        for _ in range(max_iters):
            x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            for i in range(self.max_num_nodes):
                for a in range(self.max_num_nodes):
                    x_new[i, a] = x[i, a] * S[i, i, a, a]
                    pooled = [torch.max(x[j, :] * S[i, j, a, :])
                              for j in range(self.max_num_nodes) if j != i]
                    neigh_sim = sum(pooled)
                    x_new[i, a] += neigh_sim
            norm = torch.norm(x_new)
            x = x_new / norm
        return x

    def deg_feature_similarity(self, f1, f2):
        """Similarity in (0, 1]: ``1 / (|f1 - f2| + 1)`` (1 when the features are equal)."""
        return 1 / (abs(f1 - f2) + 1)

    def permute_adj(self, adj, curr_ind, target_ind):
        ''' Permute adjacency matrix.
          The target_ind (connectivity) should be permuted to the curr_ind position.

        Builds ``ind`` with ``ind[target_ind] = curr_ind`` and returns
        ``out[k, l] = adj[ind[k], ind[l]]`` as a new
        (max_num_nodes, max_num_nodes) tensor.
        '''
        # dtype=int: the np.int alias was removed in NumPy 1.24.
        ind = np.zeros(self.max_num_nodes, dtype=int)
        ind[target_ind] = curr_ind
        adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_permuted[:, :] = adj[ind, :]
        # The right-hand side is an advanced-indexing copy, so this
        # self-assignment does not read already-overwritten columns.
        adj_permuted[:, :] = adj_permuted[:, ind]
        return adj_permuted

    def pool_graph(self, x):
        """Pool node embeddings over dim 1 with ``self.pool`` (``'max'`` or ``'sum'``).

        Raises:
            ValueError: if ``self.pool`` is neither ``'max'`` nor ``'sum'``.
        """
        if self.pool == 'max':
            out, _ = torch.max(x, dim=1, keepdim=False)
        elif self.pool == 'sum':
            out = torch.sum(x, dim=1, keepdim=False)
        else:
            raise ValueError("unsupported pool mode %r (expected 'max' or 'sum')" % (self.pool,))
        return out

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x

        NOTE: a fresh ``nn.BatchNorm1d`` is created on CUDA on every call and
        is not registered on the module, so its affine parameters are never
        trained and its running statistics are discarded.
        '''
        bn_module = nn.BatchNorm1d(x.size()[1]).cuda()
        return bn_module(x)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0):
        """Create a first layer, ``num_layers - 2`` middle layers and a last layer of this file's GraphConv.

        Only the middle layers receive ``dropout``; all layers use
        ``self.bias``.

        Returns:
            ``(conv_first, conv_block, conv_last)`` with ``conv_block`` an
            ``nn.ModuleList``.
        """
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias)
             for _ in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias)
        return conv_first, conv_block, conv_last

    def _gran_pos_weight(self, adj, batch_size):
        """Per-graph positive-class weight ``(n * n - sum(adj[i])) / sum(adj[i])``, shaped (batch, 1).

        For a binary adjacency this is #zero entries / #one entries.
        NOTE(review): an all-zero ``adj[i]`` divides by zero.
        """
        pos_weight = torch.ones(batch_size)
        for i in range(batch_size):
            pos_weight[i] = torch.Tensor(
                [float(adj[i].shape[0] * adj[i].shape[0] - adj[i].sum()) / adj[i].sum()])
        return pos_weight.unsqueeze(1)

    def forward(self, input_features, adj, batch_num_nodes):
        """Compute the training loss for one batch.

        Input corruption, selected by ``self.vae_args``:

        * ``completion_mode_small_parameter_size``: only the first
          ``self.number_of_incomplete_nodes`` nodes are given to the encoder.
        * ``GRAN``: for each graph one node (the last real node, or a random
          one when ``GRAN_arbitrary_node``) has its true column stored as the
          label and its row/column among the real nodes overwritten with ones.
        * otherwise: ``number_of_missing_nodes`` distinct random node indices
          have their rows/columns zeroed in every graph of the batch.

        After encoding (diffpool GCN or ``conv1``/``conv2``):

        * ``GRAN`` returns a pos-weighted BCE-with-logits loss on the labels.
        * ``reconstruction`` returns BCE-with-logits of the inner-product
          decoder against the (corrupted) input adjacency.
        * otherwise the VAE decodes a vector of adjacency entries, optionally
          aligned to the target by graph matching, and the BCE reconstruction
          loss is returned. The KL term is computed and printed but NOT added
          to the returned loss.

        Args:
            input_features: node features indexed ``[batch, node, feature]``.
            adj: batched adjacency indexed ``[batch, node, node]``; it is read
                via ``.cpu().numpy()`` and is not modified.
            batch_num_nodes: per-graph number of real nodes (GRAN branch).

        Returns:
            A scalar loss tensor.
        """
        # Work on a private copy: for a CPU tensor .cpu().numpy() shares memory
        # with ``adj``, and the in-place edits below would otherwise corrupt the
        # caller's adjacency, which is reused further down as the target.
        numpy_adj = adj.cpu().numpy().copy()
        if self.vae_args.completion_mode_small_parameter_size:
            n_keep = self.number_of_incomplete_nodes
            incomplete_adj = torch.tensor(numpy_adj[:, :n_keep, :n_keep], device='cuda:0')
            # NOTE(review): the feature axis is truncated to n_keep as well,
            # presumably because node features are node-indexed — confirm.
            input_features = torch.tensor(input_features[:, :n_keep, :n_keep], device='cuda:0')
        elif self.vae_args.GRAN:
            batch_size = numpy_adj.shape[0]
            gran_label = np.zeros((batch_size, self.max_num_nodes - 1))
            random_index_list = []
            for i in range(batch_size):
                num_nodes = batch_num_nodes[i]
                if self.vae_args.GRAN_arbitrary_node:
                    node = np.random.randint(num_nodes)
                    random_index_list.append(node)
                    # Label: the node's column among real nodes, minus its diagonal entry.
                    gran_label[i, :node] = numpy_adj[i, :node, node]
                    gran_label[i, node:num_nodes - 1] = numpy_adj[i, node + 1:num_nodes, node]
                    # Overwrite the node's true edges with ones before encoding.
                    numpy_adj[i, :num_nodes, node] = 1
                    numpy_adj[i, node, :num_nodes] = 1
                else:
                    last = num_nodes - 1
                    gran_label[i, :last] = numpy_adj[i, :last, last]
                    numpy_adj[i, :num_nodes, last] = 1
                    numpy_adj[i, last, :num_nodes] = 1
            gran_label = torch.tensor(gran_label).float()
            incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
        else:
            index_list = []
            for _ in range(self.vae_args.number_of_missing_nodes):
                # Rejection-sample a node index not chosen yet.
                random_index = np.random.randint(self.max_num_nodes)
                while random_index in index_list:
                    random_index = np.random.randint(self.max_num_nodes)
                index_list.append(random_index)
                numpy_adj[:, :, random_index] = 0
                numpy_adj[:, random_index, :] = 0
            incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')

        # Node embedding.
        if self.vae_args.diffpoolGCN:
            x = self.DiffpoolGcnEncoderGraph.gcn_forward(input_features, incomplete_adj,
                                                         self.DiffpoolGcnEncoderGraph.conv_first,
                                                         self.DiffpoolGcnEncoderGraph.conv_block,
                                                         self.DiffpoolGcnEncoderGraph.conv_last)
        else:
            x = self.conv1(input_features, incomplete_adj)
            x = self.act(x)
            x = self.bn1(x)
            x = self.conv2(x, incomplete_adj)

        if self.vae_args.GRAN:
            if self.vae_args.GRAN_linear:
                x = x.view(-1, self.max_num_nodes * self.hidden_dim)
                h_decode, _, _ = self.vae(x)
                pos_weight = self._gran_pos_weight(adj, batch_size)
                return F.binary_cross_entropy_with_logits(h_decode.cpu(), gran_label,
                                                          pos_weight=pos_weight)
            pos_weight = self._gran_pos_weight(adj, batch_size)
            logits = self.gran(x, batch_num_nodes, random_index_list,
                               self.vae_args.GRAN_arbitrary_node).squeeze().cpu()
            return F.binary_cross_entropy_with_logits(logits, gran_label, pos_weight=pos_weight)

        if self.vae_args.reconstruction:
            # self.dc returns a CPU tensor, hence the CPU target.
            return F.binary_cross_entropy_with_logits(self.dc(x), incomplete_adj.cpu())

        if self.vae_args.completion_mode_small_parameter_size:
            x = x.view(-1, self.number_of_incomplete_nodes * self.hidden_dim)
        elif self.vae_args.graph_pooling_mode:
            x = self.pool_graph(x)
        else:
            x = x.view(-1, self.max_num_nodes * self.hidden_dim)

        # VAE decode.
        h_decode, z_mu, z_lsgms = self.vae(x)
        out = torch.sigmoid(h_decode)
        out_tensor = out.cpu().data

        if self.vae_args.graph_matching_mode:
            # Align each target graph to its reconstruction before comparing.
            recon_adj_lower = self.recover_adj_lower(out_tensor)
            recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)
            out_features = torch.sum(recon_adj_tensor, 1)
            adj_data = adj.cpu().data
            adj_features = torch.sum(adj_data, 1)
            batch_size = adj_data.size(0)
            adj_permuted = torch.zeros(adj_data.size(0), adj_data.size(1), adj_data.size(2))
            for i in range(batch_size):
                S = self.edge_similarity_matrix(adj_data[i].squeeze(), recon_adj_tensor[i].squeeze(),
                                                adj_features[i].squeeze(), out_features[i].squeeze(),
                                                self.deg_feature_similarity)
                # Uniform initial assignment.
                init_corr = 1 / self.max_num_nodes
                init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
                assignment = self.mpm(init_assignment, S)
                # Negate the scores: linear_sum_assignment minimises total cost.
                row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
                adj_permuted[i] = self.permute_adj(adj_data[i].squeeze(), row_ind, col_ind)
        else:
            adj_permuted = adj.cpu().data

        if self.vae_args.completion_mode_small_parameter_size:
            # Walk the last number_of_missing_nodes columns from the right,
            # each from its diagonal entry upward, in the order the decoder
            # output is laid out.
            adj_vectorized = torch.zeros(out.size())
            adj_vectorized_index = 0
            for i in range(self.vae_args.number_of_missing_nodes):
                for j in range(self.max_num_nodes - i):
                    adj_vectorized[:, adj_vectorized_index] = adj_permuted[:, self.max_num_nodes - j - i - 1,
                                                                           self.max_num_nodes - i - 1]
                    adj_vectorized_index += 1
        else:
            adj_vectorized = adj_permuted[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1]
        adj_recon_loss = self.adj_recon_loss(adj_vectorized.cuda(), out)

        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= self.max_num_nodes * self.max_num_nodes  # normalize
        print('kl: ', loss_kl)
        return adj_recon_loss

    def forward_test(self, input_features, adj):
        """Debug helper: match two fixed 4-node graphs and compute their reconstruction loss.

        NOTE: overwrites ``self.max_num_nodes`` with 4. ``input_features`` and
        ``adj`` are ignored.

        Returns:
            The BCE loss between the permuted first graph and the second graph.
        """
        self.max_num_nodes = 4
        adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data[:4, :4] = torch.FloatTensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
        adj_features = torch.Tensor([2, 3, 3, 2])
        adj_data1 = torch.FloatTensor([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
        adj_features1 = torch.Tensor([3, 3, 2, 2])
        S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
                                        self.deg_feature_similarity)
        # Uniform initial assignment.
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        assignment = self.mpm(init_assignment, S)
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
        return self.adj_recon_loss(permuted_adj, adj_data1)

    def adj_recon_loss(self, adj_truth, adj_pred):
        """Binary cross-entropy of predicted probabilities ``adj_pred`` against ``adj_truth``."""
        return F.binary_cross_entropy(adj_pred, adj_truth)
 |