
model.py

import numpy as np
import scipy.optimize
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init

import model

class GraphVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, pool='sum'):
        '''
        Args:
            input_dim: input feature dimension for node.
            hidden_dim: hidden dim for 2-layer gcn.
            latent_dim: dimension of the latent representation of graph.
            max_num_nodes: maximum number of nodes; fixes the decoder output size.
            pool: graph pooling mode, 'sum' or 'max'.
        '''
        super(GraphVAE, self).__init__()
        self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=hidden_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.conv2 = model.GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
        # NOTE: if the GCN encoder in forward() is re-enabled, bn1/bn2 should
        # normalize the conv outputs, i.e. BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()

        # The decoder predicts the upper triangle of the adjacency matrix
        # (including the diagonal), flattened into a vector.
        output_dim = max_num_nodes * (max_num_nodes + 1) // 2
        #self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim)
        self.vae = model.MLP_VAE_plain(input_dim * input_dim, latent_dim, output_dim)
        #self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim)
        self.max_num_nodes = max_num_nodes

        for m in self.modules():
            if isinstance(m, model.GraphConv):
                # xavier_uniform (without the underscore) is deprecated
                init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.pool = pool
    def recover_adj_lower(self, l):
        # NOTE: assumes 1 graph per minibatch.
        # Scatter the predicted vector into the upper triangle; boolean mask
        # indexing fills the masked entries in row-major order.
        adj = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
        return adj

    def recover_full_adj_from_lower(self, lower):
        # Symmetrize: A = T + T^t - diag(T), so the diagonal is not doubled.
        diag = torch.diag(torch.diag(lower, 0))
        return lower + torch.transpose(lower, 0, 1) - diag
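
    # Example for the two helpers above: with max_num_nodes = 3, a predicted
    # vector l = [l00, l01, l02, l11, l12, l22] is scattered row-major into
    # the upper triangle and then mirrored into a symmetric matrix:
    #
    #   [[l00, l01, l02],          [[l00, l01, l02],
    #    [ 0 , l11, l12],   -->     [l01, l11, l12],
    #    [ 0 ,  0 , l22]]           [l02, l12, l22]]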

    def edge_similarity_matrix(self, adj, adj_recon, matching_features,
                               matching_features_recon, sim_func):
        S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
                        self.max_num_nodes, self.max_num_nodes)
        for i in range(self.max_num_nodes):
            for j in range(self.max_num_nodes):
                if i == j:
                    for a in range(self.max_num_nodes):
                        S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
                                        sim_func(matching_features[i], matching_features_recon[a])
                        # with feature not implemented
                        # if input_features is not None:
                else:
                    for a in range(self.max_num_nodes):
                        for b in range(self.max_num_nodes):
                            if b == a:
                                continue
                            S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
                                            adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
        return S
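
    # The 4-D affinity tensor S generalizes a node-similarity matrix:
    # S[i, j, a, b] scores matching input edge (i, j) to reconstructed edge
    # (a, b), while S[i, i, a, a] scores matching input node i to
    # reconstructed node a (cf. Simonovsky & Komodakis, GraphVAE, 2018).
    # For n = max_num_nodes this takes O(n^4) memory, which is one reason
    # this approach only scales to small graphs.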

    def mpm(self, x_init, S, max_iters=50):
        x = x_init
        for it in range(max_iters):
            x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            for i in range(self.max_num_nodes):
                for a in range(self.max_num_nodes):
                    # node-affinity term
                    x_new[i, a] = x[i, a] * S[i, i, a, a]
                    # pool the single best-matching candidate for each neighbor j
                    pooled = [torch.max(x[j, :] * S[i, j, a, :])
                              for j in range(self.max_num_nodes) if j != i]
                    neigh_sim = sum(pooled)
                    x_new[i, a] += neigh_sim
            norm = torch.norm(x_new)
            x = x_new / norm
        return x
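
    # mpm ("max-pooling matching") is a power iteration on S: x[i, a] estimates
    # the quality of assigning input node i to reconstructed node a, and each
    # sweep combines the node affinity S[i, i, a, a] with the best-matching
    # neighbor pairs (cf. Cho et al., 2014). The normalized scores then seed
    # the exact assignment solved by scipy.optimize.linear_sum_assignment in
    # forward() below.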

    def deg_feature_similarity(self, f1, f2):
        # similarity of two node degrees, in (0, 1]; equal degrees give 1
        return 1 / (abs(f1 - f2) + 1)

    def permute_adj(self, adj, curr_ind, target_ind):
        ''' Permute adjacency matrix.
        The target_ind (connectivity) should be permuted to the curr_ind position.
        '''
        # order curr_ind according to target_ind
        ind = np.zeros(self.max_num_nodes, dtype=int)  # np.int was removed from numpy
        ind[target_ind] = curr_ind
        adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_permuted[:, :] = adj[ind, :]
        adj_permuted[:, :] = adj_permuted[:, ind]
        return adj_permuted

    def pool_graph(self, x):
        if self.pool == 'max':
            out, _ = torch.max(x, dim=1, keepdim=False)
        elif self.pool == 'sum':
            out = torch.sum(x, dim=1, keepdim=False)
        return out

    def forward(self, input_features, adj):
        # GCN encoder, currently disabled in favor of the raw adjacency
        #x = self.conv1(input_features, adj)
        #x = self.bn1(x)
        #x = self.act(x)
        #x = self.conv2(x, adj)
        #x = self.bn2(x)
        # pool over all nodes
        #graph_h = self.pool_graph(x)
        graph_h = input_features.view(-1, self.max_num_nodes * self.max_num_nodes)

        # vae
        h_decode, z_mu, z_lsgms = self.vae(graph_h)
        out = torch.sigmoid(h_decode)  # F.sigmoid is deprecated
        out_tensor = out.cpu().data
        recon_adj_lower = self.recover_adj_lower(out_tensor)
        recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)

        # set the matching features to be the node degrees
        out_features = torch.sum(recon_adj_tensor, 1)

        adj_data = adj.cpu().data[0]
        adj_features = torch.sum(adj_data, 1)

        S = self.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
                                        self.deg_feature_similarity)

        # initialization strategies
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        #init_assignment = torch.FloatTensor(4, 4)
        #init.uniform(init_assignment)
        assignment = self.mpm(init_assignment, S)
        #print('Assignment: ', assignment)

        # matching: negate the assignment scores, since the algorithm
        # solves a minimum-cost assignment problem
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        print('row: ', row_ind)
        print('col: ', col_ind)

        # order row index according to col index
        #adj_permuted = self.permute_adj(adj_data, row_ind, col_ind)
        adj_permuted = adj_data
        adj_vectorized = adj_permuted[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1].squeeze_()
        adj_vectorized_var = Variable(adj_vectorized).cuda()  # assumes CUDA is available

        #print(adj)
        #print('permuted: ', adj_permuted)
        #print('recon: ', recon_adj_tensor)
        adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out[0])
        print('recon: ', adj_recon_loss)
        print(adj_vectorized_var)
        print(out[0])

        # closed-form KL(q(z|x) || N(0, I)), normalized by the adjacency size
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= self.max_num_nodes * self.max_num_nodes  # normalize
        print('kl: ', loss_kl)

        loss = adj_recon_loss + loss_kl
        return loss

    def forward_test(self, input_features, adj):
        # debugging helper: matches two hand-built 4-node graphs
        # (NOTE: this overwrites self.max_num_nodes)
        self.max_num_nodes = 4
        adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data[:4, :4] = torch.FloatTensor([[1,1,0,0], [1,1,1,0], [0,1,1,1], [0,0,1,1]])
        adj_features = torch.Tensor([2,3,3,2])

        adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data1 = torch.FloatTensor([[1,1,1,0], [1,1,0,1], [1,0,1,0], [0,1,0,1]])
        adj_features1 = torch.Tensor([3,3,2,2])
        S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
                                        self.deg_feature_similarity)

        # initialization strategies
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        #init_assignment = torch.FloatTensor(4, 4)
        #init.uniform(init_assignment)
        assignment = self.mpm(init_assignment, S)
        #print('Assignment: ', assignment)

        # matching
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        print('row: ', row_ind)
        print('col: ', col_ind)

        permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
        print('permuted: ', permuted_adj)

        adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1)
        print(adj_data1)
        print('diff: ', adj_recon_loss)

    def adj_recon_loss(self, adj_truth, adj_pred):
        # F.binary_cross_entropy expects (input=predictions, target=ground truth)
        return F.binary_cross_entropy(adj_pred, adj_truth)
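
# A minimal smoke test, as a sketch: it assumes the companion `model` module
# (imported above) provides GraphConv and MLP_VAE_plain with the constructor
# signatures used in __init__, and it runs on CPU since forward_test never
# touches the VAE or CUDA. The dimensions below are arbitrary placeholders.
if __name__ == '__main__':
    vae = GraphVAE(input_dim=4, hidden_dim=8, latent_dim=5, max_num_nodes=4)
    # forward_test builds its own pair of 4-node graphs internally, so the
    # input_features/adj arguments are unused placeholders here
    vae.forward_test(None, None)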