# graphvae_model.py

import numpy as np
import scipy.optimize
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init
import networkx as nx
import matplotlib.pyplot as plt

import model
from model import sample_sigmoid
from baselines.graphvae import args
from random import random

def graph_show(G, title, colors):
    pos = nx.spring_layout(G, scale=2)
    nx.draw(G, pos, node_color=colors)
    fig = plt.gcf()
    fig.canvas.set_window_title(title)
    plt.savefig('foo.png')  # save before show(): show() flushes the figure, so saving afterwards writes a blank image
    plt.show()
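
# A minimal usage sketch (hypothetical graph and colors, not from the original
# repo): draws a small graph, writes it to foo.png, and shows it in a window.
#
#   G = nx.ladder_graph(4)
#   graph_show(G, "ladder", colors=["#1f77b4"] * G.number_of_nodes())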


class GraphVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes,
                 number_of_missing_nodes, completion_mode, pool='sum'):
        '''
        Args:
            input_dim: input feature dimension for node.
            hidden_dim: hidden dim for 2-layer gcn.
            latent_dim: dimension of the latent representation of graph.
        '''
        super(GraphVAE, self).__init__()
        self.hidden_dim = hidden_dim
        self.completion_mode = completion_mode
        self.number_of_missing_nodes = number_of_missing_nodes
        self.max_num_nodes = max_num_nodes
        self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=32)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.conv2 = model.GraphConv(input_dim=32, output_dim=hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()
        self.linear = nn.Linear(input_dim * hidden_dim, 128)
        if completion_mode:
            # Predict only the entries that involve the missing nodes: for the
            # i-th missing node (counting from the end) there are
            # max_num_nodes - i unknown entries in the upper triangle.
            output_dim = number_of_missing_nodes * max_num_nodes - \
                         (number_of_missing_nodes * number_of_missing_nodes -
                          (number_of_missing_nodes * (number_of_missing_nodes + 1) // 2))
            self.number_of_incomplete_nodes = max_num_nodes - number_of_missing_nodes
        else:
            # Predict the full upper triangle, diagonal included.
            output_dim = max_num_nodes * (max_num_nodes + 1) // 2
        self.output_dim = output_dim  # forward() needs this in completion mode
        # self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim)
        self.vae = model.MLP_VAE_plain(input_dim * hidden_dim, hidden_dim, output_dim)
        # self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim)
        for m in self.modules():
            if isinstance(m, model.GraphConv):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        self.pool = pool

    def recover_adj_lower(self, l):
        # NOTE: Assumes 1 per minibatch
        batch_size = l.size()[0]
        adj = torch.zeros(batch_size, self.max_num_nodes, self.max_num_nodes)
        if self.completion_mode:
            # Scatter the predicted entries back into the columns that belong
            # to the missing nodes (the last columns of the adjacency matrix).
            index = torch.zeros(self.max_num_nodes, self.max_num_nodes, dtype=torch.uint8)
            for i in range(self.number_of_missing_nodes):
                for j in range(self.max_num_nodes - i):
                    index[j, self.max_num_nodes - i - 1] = 1
            adj[:, index] = l
        else:
            adj[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
        return adj

    def recover_full_adj_from_lower(self, lower):
        batch_size = lower.size()[0]
        diag = torch.zeros(batch_size, lower.size()[1], lower.size()[1])
        transpose = torch.zeros(batch_size, lower.size()[1], lower.size()[1])
        for i, mat in enumerate(lower):
            diag[i, :, :] = torch.diag(torch.diag(mat))
            transpose[i, :, :] = torch.transpose(mat, 0, 1)
        # diag = torch.diag(torch.diag(lower, 0))
        return lower + transpose - diag
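
    # Worked example (illustrative values, not from the repo): for a single
    # 3-node triangular matrix
    #     [[a, 0, 0],
    #      [b, c, 0],
    #      [d, e, f]]
    # recover_full_adj_from_lower returns lower + lower^T - diag(lower), i.e.
    #     [[a, b, d],
    #      [b, c, e],
    #      [d, e, f]]
    # so the diagonal is kept once and the off-diagonal entries are mirrored.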

    def edge_similarity_matrix(self, adj, adj_recon, matching_features,
                               matching_features_recon, sim_func):
        S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
                        self.max_num_nodes, self.max_num_nodes)
        for i in range(self.max_num_nodes):
            for j in range(self.max_num_nodes):
                if i == j:
                    for a in range(self.max_num_nodes):
                        S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
                                        sim_func(matching_features[i], matching_features_recon[a])
                        # with feature not implemented
                        # if input_features is not None:
                else:
                    for a in range(self.max_num_nodes):
                        for b in range(self.max_num_nodes):
                            if b == a:
                                continue
                            S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
                                            adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
        return S
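
    # The affinity S[(i, j), (a, b)] follows the usual graph-matching form, as
    # a restatement of the loops above:
    #   node affinity: S[i, i, a, a] = A[i, i] * A'[a, a] * sim(f_i, f'_a)
    #   edge affinity: S[i, j, a, b] = A[i, j] A[i, i] A[j, j] * A'[a, b] A'[a, a] A'[b, b]  (i != j, a != b)
    # where A is the input adjacency, A' the reconstruction, and f the matching
    # features (here node degrees obtained from adjacency row sums).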

    def mpm(self, x_init, S, max_iters=50):
        x = x_init
        for it in range(max_iters):
            x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
            for i in range(self.max_num_nodes):
                for a in range(self.max_num_nodes):
                    x_new[i, a] = x[i, a] * S[i, i, a, a]
                    pooled = [torch.max(x[j, :] * S[i, j, a, :])
                              for j in range(self.max_num_nodes) if j != i]
                    neigh_sim = sum(pooled)
                    x_new[i, a] += neigh_sim
            norm = torch.norm(x_new)
            x = x_new / norm
        return x
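
    # Max-pooling matching update, written out (a restatement of the loop above):
    #     x'[i, a] = x[i, a] * S[i, i, a, a] + sum_{j != i} max_b x[j, b] * S[i, j, a, b]
    # followed by normalization x <- x' / ||x'||_F. After max_iters rounds the
    # soft assignment x is discretized with the Hungarian algorithm
    # (scipy.optimize.linear_sum_assignment) in forward().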

    def deg_feature_similarity(self, f1, f2):
        return 1 / (abs(f1 - f2) + 1)

    def permute_adj(self, adj, curr_ind, target_ind):
        ''' Permute adjacency matrix.
        The target_ind (connectivity) should be permuted to the curr_ind position.
        '''
        # order curr_ind according to target_ind
        ind = np.zeros(self.max_num_nodes, dtype=int)  # plain int: np.int is removed in recent NumPy
        ind[target_ind] = curr_ind
        adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_permuted[:, :] = adj[ind, :]
        adj_permuted[:, :] = adj_permuted[:, ind]
        return adj_permuted
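
    # Illustrative example (hypothetical indices): with curr_ind = [0, 1, 2]
    # and target_ind = [2, 0, 1], ind becomes [1, 2, 0], so row/column k of the
    # output is row/column ind[k] of the input -- a simultaneous row and column
    # permutation, which preserves symmetry of the adjacency matrix.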

    def pool_graph(self, x):
        if self.pool == 'max':
            out, _ = torch.max(x, dim=1, keepdim=False)
        elif self.pool == 'sum':
            out = torch.sum(x, dim=1, keepdim=False)
        return out

    def forward(self, input_features, adj):
        arg = args.GraphVAE_Args()
        graph_size = adj.size()[1]
        numpy_adj = adj.cpu().numpy()
        incomplete_matrix_size = adj.size()[1] - arg.number_of_missing_nodes
        for j in range(adj.size()[0]):
            index_list = []
            if arg.completion_mode:
                # In completion mode the last nodes are treated as missing,
                # so nothing needs to be zeroed out here.
                pass
                # incomplete_numpy_adj[j] = numpy_adj[j, :incomplete_matrix_size, :incomplete_matrix_size]
            else:
                # Knock out a random subset of nodes by zeroing their rows and columns.
                for i in range(arg.number_of_missing_nodes):
                    random_index = np.random.randint(graph_size)
                    while random_index in index_list:
                        random_index = np.random.randint(graph_size)
                    index_list.append(random_index)
                    numpy_adj[j, :, random_index] = 0
                    numpy_adj[j, random_index, :] = 0
        # print("********************************************************")
        # print(numpy_adj[0])
        if arg.completion_mode:
            incomplete_adj = torch.tensor(numpy_adj[:, :incomplete_matrix_size, :incomplete_matrix_size], device='cuda:0')
            input_features = torch.tensor(input_features[:, :incomplete_matrix_size, :incomplete_matrix_size], device='cuda:0')
        else:
            incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
        # colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
        # graph_show(nx.from_numpy_matrix(numpy_adj[0]), "input", colors)
        # Two-layer GCN encoder on the (incomplete) graph.
        x = self.conv1(input_features, incomplete_adj)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x, incomplete_adj)
        x = self.bn2(x)
        if self.completion_mode:
            x = x.view(-1, self.number_of_incomplete_nodes * self.hidden_dim)
        else:
            x = x.view(-1, self.max_num_nodes * self.hidden_dim)
        # vae
        h_decode, z_mu, z_lsgms = self.vae(x)
        out = F.sigmoid(h_decode)
        out_tensor = out.cpu().data
        recon_adj_lower = self.recover_adj_lower(out_tensor)
        recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)
        if arg.graph_matching_mode:
            out_features = torch.sum(recon_adj_tensor, 1)
            adj_data = adj.cpu().data
            adj_features = torch.sum(adj_data, 1)
            batch_size = adj_data.size(0)
            adj_permuted = torch.zeros(adj_data.size(0), adj_data.size(1), adj_data.size(2))
            # print(adj_data)
            # print("*************************")
            # print(adj_permuted)
            for i in range(batch_size):
                S = self.edge_similarity_matrix(adj_data[i].squeeze(), recon_adj_tensor[i].squeeze(),
                                                adj_features[i].squeeze(), out_features[i].squeeze(),
                                                self.deg_feature_similarity)
                # initialization strategies
                init_corr = 1 / self.max_num_nodes
                init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
                assignment = self.mpm(init_assignment, S)
                # matching
                # use negative of the assignment score since the alg finds min cost flow
                row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
                # order row index according to col index
                adj_permuted[i] = self.permute_adj(adj_data[i].squeeze(), row_ind, col_ind)
        else:
            adj_data = adj.cpu().data
            adj_permuted = adj_data
        adj_vectorized = adj_permuted[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1]
        if self.completion_mode:
            # Collect only the target entries that involve the missing (last)
            # nodes, writing each one to its own slot of the target vector.
            adj_vectorized = torch.zeros(adj_permuted.size(0), self.output_dim)
            # index = torch.zeros(self.max_num_nodes, self.max_num_nodes, dtype=torch.uint8)
            idx = 0
            for i in range(self.number_of_missing_nodes):
                for j in range(self.max_num_nodes - i):
                    adj_vectorized[:, idx] = adj_permuted[:, self.max_num_nodes - j - i - 1, self.max_num_nodes - i - 1]
                    idx += 1
        adj_vectorized_var = Variable(adj_vectorized).cuda()
        adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out)
        # x2 = h_decode[0].unsqueeze(0)
        # x3 = x2.unsqueeze(0)
        # sample = sample_sigmoid(x3, sample=True, sample_time=1)
        # y = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        # y[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = sample.data.cpu()
        loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
        loss_kl /= self.max_num_nodes * self.max_num_nodes  # normalize
        print('kl: ', loss_kl)
        # The KL term is computed and printed for monitoring only; just the
        # reconstruction term is returned and optimized.
        loss = adj_recon_loss
        return loss
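
    # For reference, the KL term above is the standard closed form for a
    # diagonal Gaussian posterior against a standard normal prior:
    #     KL(N(mu, sigma^2) || N(0, I)) = -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    # with z_lsgms = log sigma^2 and z_mu = mu.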

    def forward_test(self, input_features, adj):
        # Hard-coded sanity check of the matching pipeline on two 4-node graphs.
        self.max_num_nodes = 4
        adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
        adj_data[:4, :4] = torch.FloatTensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
        adj_features = torch.Tensor([2, 3, 3, 2])
        adj_data1 = torch.FloatTensor([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
        adj_features1 = torch.Tensor([3, 3, 2, 2])
        S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
                                        self.deg_feature_similarity)
        # initialization strategies
        init_corr = 1 / self.max_num_nodes
        init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
        # init_assignment = torch.FloatTensor(4, 4)
        # init.uniform(init_assignment)
        assignment = self.mpm(init_assignment, S)
        # print('Assignment: ', assignment)
        # matching
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
        # print('row: ', row_ind)
        # print('col: ', col_ind)
        permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
        # print('permuted: ', permuted_adj)
        adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1)
        # print(adj_data1)
        # print('diff: ', adj_recon_loss)
        return adj_recon_loss

    def adj_recon_loss(self, adj_truth, adj_pred):
        return F.binary_cross_entropy(adj_pred, adj_truth)
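

if __name__ == '__main__':
    # Minimal smoke-test sketch (hypothetical shapes and values, not part of
    # the original repo). Assumes the repo's model.GraphConv and
    # model.MLP_VAE_plain behave as used above, that baselines.graphvae.args
    # is importable, and that a CUDA device is available, since forward()
    # hard-codes 'cuda:0'.
    max_num_nodes = 10
    vae = GraphVAE(input_dim=max_num_nodes, hidden_dim=64, latent_dim=32,
                   max_num_nodes=max_num_nodes, number_of_missing_nodes=2,
                   completion_mode=False).cuda()
    # Random symmetric adjacency and one-hot node features for a single graph.
    adj = torch.randint(0, 2, (1, max_num_nodes, max_num_nodes)).float()
    adj = ((adj + adj.transpose(1, 2)) > 0).float()
    features = torch.eye(max_num_nodes).unsqueeze(0).cuda()
    loss = vae(features, adj.cuda())
    print('reconstruction loss:', loss)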