
graphvae_model.py 8.6KB

import numpy as np
import scipy.optimize
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init

import model
# from baselines.graphvae.faez_test import graph_show
# import networkx as nx

class GraphVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, pool='sum'):
        '''
        Args:
            input_dim: input feature dimension for node.
            hidden_dim: hidden dim for 2-layer gcn.
            latent_dim: dimension of the latent representation of graph.
            max_num_nodes: maximum number of nodes in a graph.
            pool: readout used by pool_graph, 'sum' or 'max'.
        '''
        super(GraphVAE, self).__init__()
        self.hidden_dim = hidden_dim
        self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=hidden_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.conv2 = model.GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()

        # number of entries in the upper triangle (incl. diagonal) of the adjacency matrix
        output_dim = max_num_nodes * (max_num_nodes + 1) // 2
        # self.vae = model.Main_MLP_VAE_plain(hidden_dim, latent_dim, output_dim)
        self.vae = model.Main_MLP_VAE_plain(input_dim * hidden_dim, latent_dim, output_dim)
        # self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim)
        self.max_num_nodes = max_num_nodes

        for m in self.modules():
            if isinstance(m, model.GraphConv):
                init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.pool = pool
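
    # The VAE decoder emits the max_num_nodes * (max_num_nodes + 1) // 2 entries of the
    # upper triangle (including the diagonal) of the predicted adjacency matrix as one flat
    # vector; the two helpers below rebuild a symmetric max_num_nodes x max_num_nodes
    # matrix from that vector.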
    def recover_adj_lower(self, l):
        # NOTE: Assumes 1 per minibatch
        adj = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
        return adj

    def recover_full_adj_from_lower(self, lower):
        diag = torch.diag(torch.diag(lower, 0))
        return lower + torch.transpose(lower, 0, 1) - diag
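
    # Affinity tensor for graph matching: S[i, j, a, b] scores how well edge (i, j) of the
    # ground-truth graph matches edge (a, b) of the reconstruction. Diagonal entries
    # S[i, i, a, a] compare nodes through sim_func on the matching features; off-diagonal
    # entries require both edge endpoints to be present in both graphs.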
    def edge_similarity_matrix(self, adj, adj_recon, matching_features,
                               matching_features_recon, sim_func):
        S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
                        self.max_num_nodes, self.max_num_nodes)
        for i in range(self.max_num_nodes):
            for j in range(self.max_num_nodes):
                if i == j:
                    for a in range(self.max_num_nodes):
                        S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
                                        sim_func(matching_features[i], matching_features_recon[a])
                        # print("***")
                        # print(S[i, i, a, a])
                        # with feature not implemented
                        # if input_features is not None:
                else:
                    for a in range(self.max_num_nodes):
                        for b in range(self.max_num_nodes):
                            if b == a:
                                continue
                            S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
                                            adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
                            # print("^^^")
                            # print(S[i, j, a, b])
        # print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        return S
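
    # Max-pooling matching: x[i, a] is the soft assignment between ground-truth node i and
    # reconstructed node a. Each iteration combines the node affinity S[i, i, a, a] with the
    # best-matching neighbour affinities, then renormalizes x to unit norm.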
    def mpm(self, x_init, S, max_iters=50):
        x = x_init
        for it in range(max_iters):
            x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            for i in range(self.max_num_nodes):
                for a in range(self.max_num_nodes):
                    x_new[i, a] = x[i, a] * S[i, i, a, a]
                    pooled = [torch.max(x[j, :] * S[i, j, a, :])
                              for j in range(self.max_num_nodes) if j != i]
                    neigh_sim = sum(pooled)
                    x_new[i, a] += neigh_sim
            norm = torch.norm(x_new)
            x = x_new / norm
        return x
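
    # Degree-similarity kernel: equals 1 when the two degrees match and decays as they differ.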
    def deg_feature_similarity(self, f1, f2):
        return 1 / (abs(f1 - f2) + 1)
    def permute_adj(self, adj, curr_ind, target_ind):
        ''' Permute adjacency matrix.
        The target_ind (connectivity) should be permuted to the curr_ind position.
        '''
        # order curr_ind according to target_ind
        ind = np.zeros(self.max_num_nodes, dtype=int)  # plain int: np.int is removed in recent NumPy
        ind[target_ind] = curr_ind
        adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_permuted[:, :] = adj[ind, :]
        adj_permuted[:, :] = adj_permuted[:, ind]
        return adj_permuted
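
    # Graph-level readout over nodes ('sum' or 'max'); currently unused because forward()
    # flattens the node embeddings instead.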
    def pool_graph(self, x):
        if self.pool == 'max':
            out, _ = torch.max(x, dim=1, keepdim=False)
        elif self.pool == 'sum':
            out = torch.sum(x, dim=1, keepdim=False)
        return out
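
    # Full training pass: encode with two GraphConv layers, run the flattened node embeddings
    # through the MLP VAE, interpret the sigmoid output as the upper-triangular part of a
    # predicted adjacency matrix, align prediction and ground truth via max-pooling matching
    # plus a hard assignment from scipy.optimize.linear_sum_assignment, and return the summed
    # reconstruction (BCE) and KL losses.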
    def forward(self, input_features, adj):
        # cpu_numpy = adj.squeeze(0).cpu().numpy()
        # print("********** 0)")
        # print(cpu_numpy)
        #
        # graph_show(nx.from_numpy_matrix(cpu_numpy), "1")
        x = self.conv1(input_features, adj)
        x = self.act(x)
        x = self.bn1(x)
        x = self.conv2(x, adj)
        # x = self.bn2(x)

        # pool over all nodes
        # graph_h = self.pool_graph(x)
        graph_h = x.view(-1, self.max_num_nodes * self.hidden_dim)

        # vae
        h_decode, z_mu, z_lsgms = self.vae(graph_h)
        out = torch.sigmoid(h_decode)
        out_tensor = out.cpu().data
        print("*** h_decode")
        print(h_decode)
        print("*** out_tensor")
        print(out_tensor)
        recon_adj_lower = self.recover_adj_lower(out_tensor)
        recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)

        # set matching features to be the node degrees
        out_features = torch.sum(recon_adj_tensor, 1)
        adj_data = adj.cpu().data[0]
        adj_features = torch.sum(adj_data, 1)

        S = self.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
                                        self.deg_feature_similarity)

        # initialization strategies
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        assignment = self.mpm(init_assignment, S)

        # matching
        # use the negative of the assignment score since the algorithm finds the minimum-cost assignment
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())

        # order row index according to col index
        adj_permuted = self.permute_adj(adj_data, row_ind, col_ind)
        adj_vectorized = adj_permuted[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1].squeeze_()
        adj_vectorized_var = Variable(adj_vectorized).cuda()

        adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out[0])
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= self.max_num_nodes * self.max_num_nodes  # normalize
        # print('kl: ', loss_kl)
        loss = adj_recon_loss + loss_kl
        return loss
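
    # Self-contained sanity check of the matching pipeline on two hard-coded 4-node graphs;
    # the input_features and adj arguments are not used.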
    def forward_test(self, input_features, adj):
        self.max_num_nodes = 4
        adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data[:4, :4] = torch.FloatTensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
        adj_features = torch.Tensor([2, 3, 3, 2])

        adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data1 = torch.FloatTensor([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
        adj_features1 = torch.Tensor([3, 3, 2, 2])

        S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
                                        self.deg_feature_similarity)

        # initialization strategies
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        # init_assignment = torch.FloatTensor(4, 4)
        # init.uniform(init_assignment)
        assignment = self.mpm(init_assignment, S)
        # print('Assignment: ', assignment)

        # matching
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        print('row: ', row_ind)
        print('col: ', col_ind)

        permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
        print('permuted: ', permuted_adj)

        adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1)
        print(adj_data1)
        print('diff: ', adj_recon_loss)
    def adj_recon_loss(self, adj_truth, adj_pred):
        return F.binary_cross_entropy(adj_pred, adj_truth)
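

# Minimal usage sketch (not from the original repository). It assumes the local `model`
# module imported above provides GraphConv and Main_MLP_VAE_plain with the constructor
# signatures used in __init__, that node features have input_dim == max_num_nodes (so the
# flattened embedding matches the VAE input size), and that a CUDA device is available,
# since forward() moves the matched adjacency target to the GPU. The class itself is written
# against older PyTorch idioms (Variable, .data), so recent versions may need minor tweaks.
if __name__ == '__main__':
    max_num_nodes = 4
    input_dim = max_num_nodes  # e.g. identity node features, one row per node

    vae = GraphVAE(input_dim=input_dim, hidden_dim=16, latent_dim=8,
                   max_num_nodes=max_num_nodes).cuda()

    # toy batch of one graph: a 4-cycle with self-loops, identity node features
    adj = torch.FloatTensor([[1, 1, 0, 1],
                             [1, 1, 1, 0],
                             [0, 1, 1, 1],
                             [1, 0, 1, 1]]).unsqueeze(0).cuda()  # (1, N, N)
    features = torch.eye(max_num_nodes).unsqueeze(0).cuda()      # (1, N, input_dim)

    loss = vae(features, adj)  # matching + BCE reconstruction + KL
    loss.backward()
    print('loss:', loss.item())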