You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_model.py 21KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. import numpy as np
  2. import scipy.optimize
  3. import torch
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch import optim
  7. import torch.nn.functional as F
  8. import torch.nn.init as init
  9. import networkx as nx
  10. import matplotlib.pyplot as plt
  11. import model
  12. import baselines.graphvae.diffpoolClassesAndFunctions as diffpool
  13. from model import sample_sigmoid
  14. from baselines.graphvae import args
  15. from random import random
  16. def graph_show(G, title, colors):
  17. pos = nx.spring_layout(G, scale=2)
  18. nx.draw(G, pos, node_color=colors)
  19. fig = plt.gcf()
  20. fig.canvas.set_window_title(title)
  21. plt.show()
  22. plt.savefig('foo.png')
  23. class GraphConv(nn.Module):
  24. def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
  25. dropout=0.0, bias=True):
  26. super(GraphConv, self).__init__()
  27. self.add_self = add_self
  28. self.dropout = dropout
  29. if dropout > 0.001:
  30. self.dropout_layer = nn.Dropout(p=dropout)
  31. self.normalize_embedding = normalize_embedding
  32. self.input_dim = input_dim
  33. self.output_dim = output_dim
  34. self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
  35. if bias:
  36. self.bias = nn.Parameter(torch.FloatTensor(output_dim).cuda())
  37. else:
  38. self.bias = None
  39. def forward(self, x, adj):
  40. if self.dropout > 0.001:
  41. x = self.dropout_layer(x)
  42. y = torch.matmul(adj, x)
  43. if self.add_self:
  44. y += x
  45. y = torch.matmul(y, self.weight)
  46. if self.bias is not None:
  47. y = y + self.bias
  48. if self.normalize_embedding:
  49. y = F.normalize(y, p=2, dim=2)
  50. # print(y[0][0])
  51. return y
  52. class InnerProductDecoder(nn.Module):
  53. """Decoder for using inner product for prediction."""
  54. def __init__(self, dropout, act=torch.sigmoid):
  55. super(InnerProductDecoder, self).__init__()
  56. self.dropout = dropout
  57. self.act = act
  58. def forward(self, z):
  59. batch_size = z.size()[0]
  60. adj_size = z.size()[1]
  61. z = F.dropout(z, self.dropout, training=self.training)
  62. adj = torch.zeros(batch_size, adj_size, adj_size)
  63. for i in range(batch_size):
  64. adj[i] = self.act(torch.mm(z[i], z[i].t()))
  65. return adj
  66. class GraphVAE(nn.Module):
  67. def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, vae_args, pool='sum'):
  68. '''
  69. Args:
  70. input_dim: input feature dimension for node.
  71. hidden_dim: hidden dim for 2-layer gcn.
  72. latent_dim: dimension of the latent representation of graph.
  73. '''
  74. self.hidden_dim = hidden_dim
  75. self.vae_args = vae_args
  76. super(GraphVAE, self).__init__()
  77. self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=32)
  78. self.bn1 = nn.BatchNorm1d(input_dim)
  79. self.conv2 = model.GraphConv(input_dim=32, output_dim=hidden_dim)
  80. self.bn2 = nn.BatchNorm1d(input_dim)
  81. self.act = nn.ReLU()
  82. self.linear = nn.Linear(input_dim * hidden_dim, 128)
  83. self.dropout = 0
  84. self.dc = InnerProductDecoder(self.dropout, act=lambda x: x)
  85. self.embedding_dim = 5
  86. self.label_dim = 2
  87. self.num_layers = 3
  88. self.DiffpoolGcnEncoderGraph = diffpool.GcnEncoderGraph(input_dim, 5, self.embedding_dim, self.label_dim,
  89. self.num_layers).cuda()
  90. if vae_args.completion_mode_small_parameter_size:
  91. output_dim = vae_args.number_of_missing_nodes * max_num_nodes - \
  92. (vae_args.number_of_missing_nodes * vae_args.number_of_missing_nodes -
  93. (vae_args.number_of_missing_nodes * (vae_args.number_of_missing_nodes + 1) // 2))
  94. self.number_of_incomplete_nodes = max_num_nodes - vae_args.number_of_missing_nodes
  95. self.output_dim = output_dim
  96. elif vae_args.GRAN_linear:
  97. output_dim = max_num_nodes - 1
  98. else:
  99. output_dim = max_num_nodes * (max_num_nodes + 1) // 2
  100. # self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim)
  101. if self.vae_args.graph_pooling_mode:
  102. self.vae = model.MLP_VAE_plain(hidden_dim, hidden_dim, output_dim)
  103. else:
  104. self.vae = model.MLP_VAE_plain(input_dim * hidden_dim, hidden_dim, output_dim)
  105. # self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim)
  106. self.max_num_nodes = max_num_nodes
  107. self.gran = model.GRAN(self.max_num_nodes)
  108. for m in self.modules():
  109. if isinstance(m, model.GraphConv):
  110. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  111. elif isinstance(m, nn.BatchNorm1d):
  112. m.weight.data.fill_(1)
  113. m.bias.data.zero_()
  114. self.pool = pool
  115. #############################################################
  116. self.bn = True
  117. self.num_layers = 3
  118. self.num_aggs = 1
  119. self.bias = True
  120. in_dim = input_dim
  121. hidden_dim1 = 20
  122. embedding_dim1 = 20
  123. self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
  124. in_dim, hidden_dim1, embedding_dim1, self.num_layers,
  125. False, normalize=True, dropout=0.0)
  126. #############################################################
  127. def recover_adj_lower(self, l):
  128. # NOTE: Assumes 1 per minibatch
  129. batch_size = l.size()[0]
  130. adj = torch.zeros(batch_size, self.max_num_nodes, self.max_num_nodes)
  131. adj[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
  132. return adj
  133. def recover_full_adj_from_lower(self, lower):
  134. batch_size = lower.size()[0]
  135. diag = torch.zeros(batch_size, lower.size()[1], lower.size()[1])
  136. transpose = torch.zeros(batch_size, lower.size()[1], lower.size()[1])
  137. i = 0
  138. for mat in lower:
  139. diag[i, :, :] = torch.diag(torch.diag(mat))
  140. transpose[i, :, :] = torch.transpose(mat, 0, 1)
  141. i += 1
  142. # diag = torch.diag(torch.diag(lower, 0))
  143. return lower + transpose - diag
  144. def edge_similarity_matrix(self, adj, adj_recon, matching_features,
  145. matching_features_recon, sim_func):
  146. S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
  147. self.max_num_nodes, self.max_num_nodes)
  148. for i in range(self.max_num_nodes):
  149. for j in range(self.max_num_nodes):
  150. if i == j:
  151. for a in range(self.max_num_nodes):
  152. S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
  153. sim_func(matching_features[i], matching_features_recon[a])
  154. # with feature not implemented
  155. # if input_features is not None:
  156. else:
  157. for a in range(self.max_num_nodes):
  158. for b in range(self.max_num_nodes):
  159. if b == a:
  160. continue
  161. S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
  162. adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
  163. return S
  164. def mpm(self, x_init, S, max_iters=50):
  165. x = x_init
  166. for it in range(max_iters):
  167. x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
  168. for i in range(self.max_num_nodes):
  169. for a in range(self.max_num_nodes):
  170. x_new[i, a] = x[i, a] * S[i, i, a, a]
  171. pooled = [torch.max(x[j, :] * S[i, j, a, :])
  172. for j in range(self.max_num_nodes) if j != i]
  173. neigh_sim = sum(pooled)
  174. x_new[i, a] += neigh_sim
  175. norm = torch.norm(x_new)
  176. x = x_new / norm
  177. return x
  178. def deg_feature_similarity(self, f1, f2):
  179. return 1 / (abs(f1 - f2) + 1)
  180. def permute_adj(self, adj, curr_ind, target_ind):
  181. ''' Permute adjacency matrix.
  182. The target_ind (connectivity) should be permuted to the curr_ind position.
  183. '''
  184. # order curr_ind according to target ind
  185. ind = np.zeros(self.max_num_nodes, dtype=np.int)
  186. ind[target_ind] = curr_ind
  187. adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
  188. adj_permuted[:, :] = adj[ind, :]
  189. adj_permuted[:, :] = adj_permuted[:, ind]
  190. return adj_permuted
  191. def pool_graph(self, x):
  192. if self.pool == 'max':
  193. out, _ = torch.max(x, dim=1, keepdim=False)
  194. elif self.pool == 'sum':
  195. out = torch.sum(x, dim=1, keepdim=False)
  196. return out
  197. def apply_bn(self, x):
  198. ''' Batch normalization of 3D tensor x
  199. '''
  200. bn_module = nn.BatchNorm1d(x.size()[1]).cuda()
  201. return bn_module(x)
  202. def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
  203. normalize=False, dropout=0.0):
  204. conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
  205. normalize_embedding=normalize, bias=self.bias)
  206. conv_block = nn.ModuleList(
  207. [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
  208. normalize_embedding=normalize, dropout=dropout, bias=self.bias)
  209. for i in range(num_layers - 2)])
  210. conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
  211. normalize_embedding=normalize, bias=self.bias)
  212. return conv_first, conv_block, conv_last
  213. def forward(self, input_features, adj, batch_num_nodes):
  214. # embedding_mask = self.DiffpoolGcnEncoderGraph.construct_mask(self.max_num_nodes, batch_num_nodes)
  215. numpy_adj = adj.cpu().numpy()
  216. # print("**** numpy_adj")
  217. # print(numpy_adj)
  218. if self.vae_args.completion_mode_small_parameter_size:
  219. incomplete_adj = torch.tensor(
  220. numpy_adj[:, :self.number_of_incomplete_nodes, :self.number_of_incomplete_nodes],
  221. device='cuda:0')
  222. input_features = torch.tensor(
  223. input_features[:, :self.number_of_incomplete_nodes, :self.number_of_incomplete_nodes],
  224. device='cuda:0')
  225. elif self.vae_args.GRAN:
  226. batch_size = numpy_adj.shape[0]
  227. gran_lable = np.zeros((batch_size, self.max_num_nodes - 1))
  228. random_index_list = []
  229. # print("*** numpy_adj before change")
  230. # print(numpy_adj)
  231. for i in range(batch_size):
  232. num_nodes = batch_num_nodes[i]
  233. if self.vae_args.GRAN_arbitrary_node:
  234. random_index_list.append(np.random.randint(num_nodes))
  235. gran_lable[i, :random_index_list[i]] = \
  236. numpy_adj[i, :random_index_list[i], random_index_list[i]]
  237. gran_lable[i, random_index_list[i]:num_nodes-1] = \
  238. numpy_adj[i, random_index_list[i]+1:num_nodes, random_index_list[i]]
  239. numpy_adj[i, :num_nodes, random_index_list[i]] = 1
  240. numpy_adj[i, random_index_list[i], :num_nodes] = 1
  241. # print("*** random_index_list")
  242. # print(random_index_list)
  243. # print("*** gran_lable")
  244. # print(gran_lable)
  245. # print("*** numpy_adj")
  246. # print(numpy_adj)
  247. else:
  248. gran_lable[i, :num_nodes - 1] = numpy_adj[i, :num_nodes - 1, num_nodes - 1]
  249. numpy_adj[i, :num_nodes, num_nodes - 1] = 1
  250. numpy_adj[i, num_nodes - 1, :num_nodes] = 1
  251. gran_lable = torch.tensor(gran_lable).float()
  252. incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
  253. # print("*** in model : random_index_list")
  254. # print(random_index_list)
  255. # print("*** random_index_list")
  256. # print(random_index_list)
  257. # print("*** gran_lable")
  258. # print(gran_lable)
  259. # print("*** numpy_adj after change")
  260. # print(numpy_adj)
  261. else:
  262. index_list = []
  263. for i in range(self.vae_args.number_of_missing_nodes):
  264. random_index = np.random.randint(self.max_num_nodes)
  265. while random_index in index_list:
  266. random_index = np.random.randint(self.max_num_nodes)
  267. index_list.append(random_index)
  268. numpy_adj[:, :, random_index] = 0
  269. numpy_adj[:, random_index, :] = 0
  270. incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
  271. if self.vae_args.diffpoolGCN:
  272. x = self.DiffpoolGcnEncoderGraph.gcn_forward(input_features, incomplete_adj,
  273. self.DiffpoolGcnEncoderGraph.conv_first,
  274. self.DiffpoolGcnEncoderGraph.conv_block,
  275. self.DiffpoolGcnEncoderGraph.conv_last)
  276. else:
  277. ################################################################### embedding
  278. x = self.conv1(input_features, incomplete_adj)
  279. x = self.act(x)
  280. x = self.bn1(x)
  281. x = self.conv2(x, incomplete_adj)
  282. # x = self.bn2(x)
  283. # x = self.act(x)
  284. if (self.vae_args.GRAN):
  285. # x = self.conv_first(input_features, incomplete_adj)
  286. # x = self.act(x)
  287. # if self.bn:
  288. # x = self.apply_bn(x)
  289. # out_all = []
  290. # out, _ = torch.max(x, dim=1)
  291. # out_all.append(out)
  292. # for i in range(self.num_layers - 2):
  293. # x = self.conv_block[i](x, adj)
  294. # x = self.act(x)
  295. # if self.bn:
  296. # x = self.apply_bn(x)
  297. # out, _ = torch.max(x, dim=1)
  298. # out_all.append(out)
  299. # if self.num_aggs == 2:
  300. # out = torch.sum(x, dim=1)
  301. # out_all.append(out)
  302. # x = self.conv_last(x, adj)
  303. ################################################################### end of embedding
  304. if self.vae_args.GRAN_linear:
  305. x = x.view(-1, self.max_num_nodes * self.hidden_dim)
  306. h_decode, z_mu, z_lsgms = self.vae(x)
  307. pos_weight = torch.ones(batch_size)
  308. for i in range(batch_size):
  309. # / adj[i].shape[0]
  310. pos_weight[i] = torch.Tensor(
  311. [1 * float(adj[i].shape[0] * adj[i].shape[0] - adj[i].sum()) / adj[i].sum()])
  312. pos_weight = pos_weight.unsqueeze(1)
  313. return F.binary_cross_entropy_with_logits(h_decode.cpu(), gran_lable,
  314. pos_weight=pos_weight)
  315. else:
  316. pos_weight = torch.ones(batch_size)
  317. for i in range(batch_size):
  318. # / adj[i].shape[0]
  319. pos_weight[i] = torch.Tensor(
  320. [float(adj[i].shape[0] * adj[i].shape[0] - adj[i].sum()) / adj[i].sum()])
  321. pos_weight = pos_weight.unsqueeze(1)
  322. # print("************************")
  323. # print(pos_weight.size())
  324. # print(self.gran(x, batch_num_nodes).squeeze().cpu().size())
  325. # print(gran_lable.size())
  326. # print("***** in model *****")
  327. # print("*** model result : ")
  328. # print(self.gran(x, batch_num_nodes, random_index_list, self.vae_args.GRAN_arbitrary_node).squeeze().cpu())
  329. # print("***** gran_lable")
  330. # print(gran_lable)
  331. return F.binary_cross_entropy_with_logits(
  332. self.gran(x, batch_num_nodes, random_index_list, self.vae_args.GRAN_arbitrary_node).squeeze().cpu(), gran_lable,
  333. pos_weight=pos_weight)
  334. if self.vae_args.reconstruction:
  335. loss = F.binary_cross_entropy_with_logits(self.dc(x), incomplete_adj.cpu())
  336. return loss
  337. if self.vae_args.completion_mode_small_parameter_size:
  338. x = x.view(-1, self.number_of_incomplete_nodes * self.hidden_dim)
  339. else:
  340. if self.vae_args.graph_pooling_mode:
  341. x = self.pool_graph(x)
  342. else:
  343. x = x.view(-1, self.max_num_nodes * self.hidden_dim)
  344. # vae
  345. h_decode, z_mu, z_lsgms = self.vae(x)
  346. out = F.sigmoid(h_decode)
  347. out_tensor = out.cpu().data
  348. if self.vae_args.graph_matching_mode:
  349. recon_adj_lower = self.recover_adj_lower(out_tensor)
  350. recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)
  351. out_features = torch.sum(recon_adj_tensor, 1)
  352. adj_data = adj.cpu().data
  353. adj_features = torch.sum(adj_data, 1)
  354. batch_size = adj_data.size(0)
  355. adj_permuted = torch.zeros(adj_data.size(0), adj_data.size(1), adj_data.size(2))
  356. for i in range(batch_size):
  357. S = self.edge_similarity_matrix(adj_data[i].squeeze(), recon_adj_tensor[i].squeeze(),
  358. adj_features[i].squeeze(), out_features[i].squeeze(),
  359. self.deg_feature_similarity)
  360. # initialization strategies
  361. init_corr = 1 / self.max_num_nodes
  362. init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
  363. assignment = self.mpm(init_assignment, S)
  364. # matching
  365. # use negative of the assignment score since the alg finds min cost flow
  366. row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
  367. # order row index according to col index
  368. adj_permuted[i] = self.permute_adj(adj_data[i].squeeze(), row_ind, col_ind)
  369. adj_permuted[i] = adj_permuted[i].unsqueeze(0)
  370. else:
  371. adj_data = adj.cpu().data
  372. adj_permuted = adj_data
  373. if self.vae_args.completion_mode_small_parameter_size:
  374. adj_vectorized = torch.zeros(out.size())
  375. adj_vectorized_index = 0
  376. for i in range(self.vae_args.number_of_missing_nodes): # iterates over columns
  377. for j in range(self.max_num_nodes - i):
  378. adj_vectorized[:, adj_vectorized_index] = adj_permuted[:, self.max_num_nodes - j - i - 1,
  379. self.max_num_nodes - i - 1]
  380. adj_vectorized_index += 1
  381. else:
  382. adj_vectorized = adj_permuted[:, torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1]
  383. adj_vectorized_var = Variable(adj_vectorized).cuda()
  384. adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out)
  385. # x2 = h_decode[0].unsqueeze(0)
  386. # x3 = x2.unsqueeze(0)
  387. # sample = sample_sigmoid(x3, sample=True, sample_time=1)
  388. # y = torch.zeros(self.max_num_nodes, self.max_num_nodes)
  389. # y[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = sample.data.cpu()
  390. loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
  391. loss_kl /= self.max_num_nodes * self.max_num_nodes # normalize
  392. print('kl: ', loss_kl)
  393. loss = adj_recon_loss
  394. return loss
  395. def forward_test(self, input_features, adj):
  396. self.max_num_nodes = 4
  397. adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
  398. adj_data[:4, :4] = torch.FloatTensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
  399. adj_features = torch.Tensor([2, 3, 3, 2])
  400. adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes)
  401. adj_data1 = torch.FloatTensor([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
  402. adj_features1 = torch.Tensor([3, 3, 2, 2])
  403. S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
  404. self.deg_feature_similarity)
  405. # initialization strategies
  406. init_corr = 1 / self.max_num_nodes
  407. init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
  408. # init_assignment = torch.FloatTensor(4, 4)
  409. # init.uniform(init_assignment)
  410. assignment = self.mpm(init_assignment, S)
  411. # print('Assignment: ', assignment)
  412. # matching
  413. row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
  414. # print('row: ', row_ind)
  415. # print('col: ', col_ind)
  416. permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
  417. # print('permuted: ', permuted_adj)
  418. adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1)
  419. # print(adj_data1)
  420. # print('diff: ', adj_recon_loss)
  421. def adj_recon_loss(self, adj_truth, adj_pred):
  422. return F.binary_cross_entropy(adj_pred, adj_truth)