You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main_DeepGMG.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # an implementation for "Learning Deep Generative Models of Graphs"
  2. import os
  3. from main import *
class Args_DGMG():
    """Configuration container for DeepGMG training and evaluation.

    Every setting is a plain instance attribute assigned in __init__;
    the rest of the file reads them via an ``args`` object.
    """
    def __init__(self):
        ### CUDA
        # GPU id, written into CUDA_VISIBLE_DEVICES by the __main__ block
        self.cuda = 0
        ### model type
        self.note = 'Baseline_DGMG' # do GCN after adding each edge
        # self.note = 'Baseline_DGMG_fast' # do GCN only after adding each node
        ### data config
        # self.graph_type = 'caveman_small'
        self.graph_type = 'grid_small'
        # self.graph_type = 'ladder_small'
        # self.graph_type = 'enzymes_small'
        # self.graph_type = 'barabasi_small'
        # self.graph_type = 'citeseer_small'
        # cap on generated graph size; overwritten in __main__ from the dataset
        self.max_num_node = 20
        ### network config
        self.node_embedding_size = 64
        # number of graphs sampled per test pass
        self.test_graph_num = 200

        ### training config
        self.epochs = 100 # now one epoch means self.batch_ratio x batch_size
        self.load_epoch = 100
        self.epochs_test_start = 10
        self.epochs_test = 10
        self.epochs_log = 10
        self.epochs_save = 10
        # 'fast' variant skips per-edge message passing (see train/test loops)
        if 'fast' in self.note:
            self.is_fast = True
        else:
            self.is_fast = False

        self.lr = 0.001
        self.milestones = [300, 600, 1000]
        self.lr_rate = 0.3

        ### output config
        self.model_save_path = 'model_save/'
        self.graph_save_path = 'graphs/'
        self.figure_save_path = 'figures/'
        self.timing_save_path = 'timing/'
        self.figure_prediction_save_path = 'figures_prediction/'
        self.nll_save_path = 'nll/'

        # filename prefixes used for checkpoints and saved graph lists
        self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size)
        self.fname_pred = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_pred_'
        self.fname_train = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_train_'
        self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_test_'

        self.load = False
        self.save = True
def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast=False):
    """Run one training epoch of DeepGMG with teacher forcing.

    For every graph in ``dataset`` (visited in random order, with a random
    node relabeling), the generation sequence add-node / add-edge / pick-node
    is replayed and the three binary-cross-entropy losses are accumulated,
    then backpropagated once per graph.

    epoch:     current epoch number (used only for logging)
    is_fast:   if True, skip the extra message passing before each edge step
    """
    model.train()
    graph_num = len(dataset)
    order = list(range(graph_num))
    shuffle(order)

    # running loss totals over the whole epoch, for logging only
    loss_addnode = 0
    loss_addedge = 0
    loss_node = 0
    for i in order:
        model.zero_grad()
        graph = dataset[i]
        # do random ordering: relabel nodes
        node_order = list(range(graph.number_of_nodes()))
        shuffle(node_order)
        order_mapping = dict(zip(graph.nodes(), node_order))
        graph = nx.relabel_nodes(graph, order_mapping, copy=True)

        # NOTE: when starting loop, we assume a node has already been generated
        node_count = 1
        node_embedding = [
            Variable(torch.ones(1, args.node_embedding_size)).cuda()]  # list of torch tensors, each size: 1*hidden

        loss = 0
        while node_count <= graph.number_of_nodes():
            # ground-truth adjacency of the partial graph and of the new node
            # NOTE(review): .adjacency_list() is the networkx 1.x API
            node_neighbor = graph.subgraph(
                list(range(node_count))).adjacency_list()  # list of lists (first node is zero)
            node_neighbor_new = graph.subgraph(list(range(node_count + 1))).adjacency_list()[
                -1]  # list of new node's neighbors

            # 1 message passing
            # do 2 times message passing
            node_embedding = message_passing(node_neighbor, node_embedding, model)

            # 2 graph embedding and new node embedding
            node_embedding_cat = torch.cat(node_embedding, dim=0)
            graph_embedding = calc_graph_embedding(node_embedding_cat, model)
            init_embedding = calc_init_embedding(node_embedding_cat, model)

            # 3 f_addnode: probability of adding one more node
            p_addnode = model.f_an(graph_embedding)
            if node_count < graph.number_of_nodes():
                # add node
                node_neighbor.append([])
                node_embedding.append(init_embedding)
                if is_fast:
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                # calc loss (ground truth: should add a node -> target 1)
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.ones((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
            else:
                # calc loss (ground truth: generation should stop -> target 0)
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
                break

            edge_count = 0
            # one extra iteration (<=) so the final "stop adding edges" step is trained
            while edge_count <= len(node_neighbor_new):
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge: probability of adding one more edge to the new node
                p_addedge = model.f_ae(graph_embedding)

                if edge_count < len(node_neighbor_new):
                    # calc loss (target 1: another edge exists)
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data

                    # 5 f_nodes: score every existing node as the edge endpoint
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
                                                                              node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    # get ground truth: one-hot over the true neighbor
                    a_node = torch.zeros((1, p_node.size(1)))
                    # print('node_neighbor_new',node_neighbor_new, edge_count)
                    a_node[0, node_neighbor_new[edge_count]] = 1
                    a_node = Variable(a_node).cuda()
                    # add edge (both directions of the undirected adjacency list)
                    node_neighbor[-1].append(node_neighbor_new[edge_count])
                    node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor) - 1)
                    # calc loss
                    loss_node_step = F.binary_cross_entropy(p_node, a_node)
                    # loss_node_step.backward(retain_graph=True)
                    loss += loss_node_step
                    loss_node += loss_node_step.data
                else:
                    # calc loss (target 0: no more edges for this node)
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data
                    break

                edge_count += 1
            node_count += 1

        # update deterministic and lstm
        # NOTE(review): scheduler.step() runs once per graph, not per epoch,
        # so the milestones count graphs — presumably intentional; confirm.
        loss.backward()
        optimizer.step()
        scheduler.step()

    loss_all = loss_addnode + loss_addedge + loss_node

    if epoch % args.epochs_log == 0:
        print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
            epoch, args.epochs, loss_all, args.graph_type, args.node_embedding_size))

    # loss_sum += loss.data[0]*x.size(0)
    # return loss_sum
  154. def test_DGMG_epoch(args, model, is_fast=False):
  155. model.eval()
  156. graph_num = args.test_graph_num
  157. graphs_generated = []
  158. for i in range(graph_num):
  159. # NOTE: when starting loop, we assume a node has already been generated
  160. node_neighbor = [[]] # list of lists (first node is zero)
  161. node_embedding = [
  162. Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
  163. node_count = 1
  164. while node_count <= args.max_num_node:
  165. # 1 message passing
  166. # do 2 times message passing
  167. node_embedding = message_passing(node_neighbor, node_embedding, model)
  168. # 2 graph embedding and new node embedding
  169. node_embedding_cat = torch.cat(node_embedding, dim=0)
  170. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  171. init_embedding = calc_init_embedding(node_embedding_cat, model)
  172. # 3 f_addnode
  173. p_addnode = model.f_an(graph_embedding)
  174. a_addnode = sample_tensor(p_addnode)
  175. # print(a_addnode.data[0][0])
  176. if a_addnode.data[0][0] == 1:
  177. # print('add node')
  178. # add node
  179. node_neighbor.append([])
  180. node_embedding.append(init_embedding)
  181. if is_fast:
  182. node_embedding_cat = torch.cat(node_embedding, dim=0)
  183. else:
  184. break
  185. edge_count = 0
  186. while edge_count < args.max_num_node:
  187. if not is_fast:
  188. node_embedding = message_passing(node_neighbor, node_embedding, model)
  189. node_embedding_cat = torch.cat(node_embedding, dim=0)
  190. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  191. # 4 f_addedge
  192. p_addedge = model.f_ae(graph_embedding)
  193. a_addedge = sample_tensor(p_addedge)
  194. # print(a_addedge.data[0][0])
  195. if a_addedge.data[0][0] == 1:
  196. # print('add edge')
  197. # 5 f_nodes
  198. # excluding the last node (which is the new node)
  199. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  200. node_embedding_cat.size(1))
  201. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  202. p_node = F.softmax(s_node.permute(1, 0))
  203. a_node = gumbel_softmax(p_node, temperature=0.01)
  204. _, a_node_id = a_node.topk(1)
  205. a_node_id = int(a_node_id.data[0][0])
  206. # add edge
  207. node_neighbor[-1].append(a_node_id)
  208. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  209. else:
  210. break
  211. edge_count += 1
  212. node_count += 1
  213. # save graph
  214. node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
  215. graph = nx.from_dict_of_lists(node_neighbor_dict)
  216. graphs_generated.append(graph)
  217. return graphs_generated
  218. def test_DGMG_2(args, model, test_graph, is_fast=False):
  219. model.eval()
  220. graph_num = args.test_graph_num
  221. graphs_generated = []
  222. for i in range(graph_num):
  223. # NOTE: when starting loop, we assume a node has already been generated
  224. node_neighbor = [[]] # list of lists (first node is zero)
  225. node_embedding = [
  226. Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
  227. node_max = len(test_graph.nodes())
  228. node_count = 1
  229. while node_count <= node_max:
  230. # 1 message passing
  231. # do 2 times message passing
  232. node_embedding = message_passing(node_neighbor, node_embedding, model)
  233. # 2 graph embedding and new node embedding
  234. node_embedding_cat = torch.cat(node_embedding, dim=0)
  235. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  236. init_embedding = calc_init_embedding(node_embedding_cat, model)
  237. # 3 f_addnode
  238. p_addnode = model.f_an(graph_embedding)
  239. a_addnode = sample_tensor(p_addnode)
  240. if a_addnode.data[0][0] == 1:
  241. # add node
  242. node_neighbor.append([])
  243. node_embedding.append(init_embedding)
  244. if is_fast:
  245. node_embedding_cat = torch.cat(node_embedding, dim=0)
  246. else:
  247. break
  248. edge_count = 0
  249. while edge_count < args.max_num_node:
  250. if not is_fast:
  251. node_embedding = message_passing(node_neighbor, node_embedding, model)
  252. node_embedding_cat = torch.cat(node_embedding, dim=0)
  253. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  254. # 4 f_addedge
  255. p_addedge = model.f_ae(graph_embedding)
  256. a_addedge = sample_tensor(p_addedge)
  257. if a_addedge.data[0][0] == 1:
  258. # 5 f_nodes
  259. # excluding the last node (which is the new node)
  260. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  261. node_embedding_cat.size(1))
  262. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  263. p_node = F.softmax(s_node.permute(1, 0))
  264. a_node = gumbel_softmax(p_node, temperature=0.01)
  265. _, a_node_id = a_node.topk(1)
  266. a_node_id = int(a_node_id.data[0][0])
  267. # add edge
  268. node_neighbor[-1].append(a_node_id)
  269. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  270. else:
  271. break
  272. edge_count += 1
  273. node_count += 1
  274. # clear node_neighbor and build it again
  275. node_neighbor = []
  276. for n in range(node_max):
  277. temp_neighbor = [k for k in test_graph.edge[n]]
  278. node_neighbor.append(temp_neighbor)
  279. # now add the last node for real
  280. # 1 message passing
  281. # do 2 times message passing
  282. try:
  283. node_embedding = message_passing(node_neighbor, node_embedding, model)
  284. # 2 graph embedding and new node embedding
  285. node_embedding_cat = torch.cat(node_embedding, dim=0)
  286. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  287. init_embedding = calc_init_embedding(node_embedding_cat, model)
  288. # 3 f_addnode
  289. p_addnode = model.f_an(graph_embedding)
  290. a_addnode = sample_tensor(p_addnode)
  291. if a_addnode.data[0][0] == 1:
  292. # add node
  293. node_neighbor.append([])
  294. node_embedding.append(init_embedding)
  295. if is_fast:
  296. node_embedding_cat = torch.cat(node_embedding, dim=0)
  297. edge_count = 0
  298. while edge_count < args.max_num_node:
  299. if not is_fast:
  300. node_embedding = message_passing(node_neighbor, node_embedding, model)
  301. node_embedding_cat = torch.cat(node_embedding, dim=0)
  302. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  303. # 4 f_addedge
  304. p_addedge = model.f_ae(graph_embedding)
  305. a_addedge = sample_tensor(p_addedge)
  306. if a_addedge.data[0][0] == 1:
  307. # 5 f_nodes
  308. # excluding the last node (which is the new node)
  309. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  310. node_embedding_cat.size(1))
  311. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  312. p_node = F.softmax(s_node.permute(1, 0))
  313. a_node = gumbel_softmax(p_node, temperature=0.01)
  314. _, a_node_id = a_node.topk(1)
  315. a_node_id = int(a_node_id.data[0][0])
  316. # add edge
  317. node_neighbor[-1].append(a_node_id)
  318. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  319. else:
  320. break
  321. edge_count += 1
  322. node_count += 1
  323. except:
  324. print('error')
  325. # save graph
  326. node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
  327. graph = nx.from_dict_of_lists(node_neighbor_dict)
  328. graphs_generated.append(graph)
  329. return graphs_generated
  330. ########### train function for LSTM + VAE
  331. def train_DGMG(args, dataset_train, model):
  332. # check if load existing model
  333. if args.load:
  334. fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
  335. model.load_state_dict(torch.load(fname))
  336. args.lr = 0.00001
  337. epoch = args.load_epoch
  338. print('model loaded!, lr: {}'.format(args.lr))
  339. else:
  340. epoch = 1
  341. # initialize optimizer
  342. optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
  343. scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_rate)
  344. # start main loop
  345. time_all = np.zeros(args.epochs)
  346. while epoch <= args.epochs:
  347. time_start = tm.time()
  348. # train
  349. train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast)
  350. time_end = tm.time()
  351. time_all[epoch - 1] = time_end - time_start
  352. print('time used', time_all[epoch - 1])
  353. # test
  354. if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
  355. graphs = test_DGMG_epoch(args, model, is_fast=args.is_fast)
  356. fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat'
  357. save_graph_list(graphs, fname)
  358. # print('test done, graphs saved')
  359. # save model checkpoint
  360. if args.save:
  361. if epoch % args.epochs_save == 0:
  362. fname = args.model_save_path + args.fname + 'model_' + str(epoch) + '.dat'
  363. torch.save(model.state_dict(), fname)
  364. epoch += 1
  365. np.save(args.timing_save_path + args.fname, time_all)
if __name__ == '__main__':
    args = Args_DGMG()
    # pin the process to the configured GPU before any CUDA init
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)

    # default dataset: small ladder graphs (replaced below per graph_type)
    graphs = []
    for i in range(4, 10):
        graphs.append(nx.ladder_graph(i))
    model = DGM_graphs(h_size=args.node_embedding_size).cuda()

    if args.graph_type == 'grid_small':
        graphs = []
        for i in range(2, 3):
            for j in range(2, 4):
                graphs.append(nx.grid_2d_graph(i, j))
        # NOTE(review): max_prev_node is only set in this branch but is
        # printed unconditionally below — other graph_type values would
        # raise AttributeError; confirm intended.
        args.max_prev_node = 6

    # remove self loops
    # NOTE(review): selfloop_edges() returning a len()-able list is the
    # networkx 1.x API; confirm installed version.
    for graph in graphs:
        edges_with_selfloops = graph.selfloop_edges()
        if len(edges_with_selfloops) > 0:
            graph.remove_edges_from(edges_with_selfloops)

    # split datasets
    random.seed(123)
    shuffle(graphs)
    # graphs_len = len(graphs)
    # graphs_test = graphs[int(0.8 * graphs_len):]
    # graphs_validate = graphs[int(0.7 * graphs_len):int(0.8 * graphs_len)]
    # graphs_train = graphs[0:int(0.7 * graphs_len)]

    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])

    print('max number node: {}'.format(args.max_num_node))
    print('max previous node: {}'.format(args.max_prev_node))

    # held-out conditioning graph: a 2x3 grid with one node removed
    # NOTE(review): test_graph.nodes()[5] indexes nodes() — networkx 1.x only
    test_graph = nx.grid_2d_graph(2, 3)
    test_graph.remove_node(test_graph.nodes()[5])
    train_DGMG(args, graphs, model)

    test_graph = nx.convert_node_labels_to_integers(test_graph)
    test_DGMG_2(args, model, test_graph)