You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main_DeepGMG.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
# An implementation of "Learning Deep Generative Models of Graphs" (DGMG).
  2. import os
  3. import random
  4. from statistics import mean
  5. import networkx as nx
  6. import numpy as np
  7. from sklearn.metrics import roc_auc_score, average_precision_score
  8. from main import *
  9. class Args_DGMG():
  10. def __init__(self):
  11. ### CUDA
  12. self.cuda = 0
  13. ### model type
  14. self.note = 'Baseline_DGMG' # do GCN after adding each edge
  15. # self.note = 'Baseline_DGMG_fast' # do GCN only after adding each node
  16. ### data config
  17. # self.graph_type = 'caveman_small'
  18. self.graph_type = 'grid_small'
  19. # self.graph_type = 'ladder_small'
  20. # self.graph_type = 'enzymes_small'
  21. # self.graph_type = 'barabasi_small'
  22. # self.graph_type = 'citeseer_small'
  23. self.max_num_node = 20
  24. ### network config
  25. self.node_embedding_size = 64
  26. self.test_graph_num = 200
  27. ### training config
  28. self.epochs = 100 # now one epoch means self.batch_ratio x batch_size
  29. self.load_epoch = 100
  30. self.epochs_test_start = 10
  31. self.epochs_test = 10
  32. self.epochs_log = 10
  33. self.epochs_save = 10
  34. if 'fast' in self.note:
  35. self.is_fast = True
  36. else:
  37. self.is_fast = False
  38. self.lr = 0.001
  39. self.milestones = [300, 600, 1000]
  40. self.lr_rate = 0.3
  41. ### output config
  42. self.model_save_path = 'model_save/'
  43. self.graph_save_path = 'graphs/'
  44. self.figure_save_path = 'figures/'
  45. self.timing_save_path = 'timing/'
  46. self.figure_prediction_save_path = 'figures_prediction/'
  47. self.nll_save_path = 'nll/'
  48. self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size)
  49. self.fname_pred = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_pred_'
  50. self.fname_train = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_train_'
  51. self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_test_'
  52. self.load = False
  53. self.save = True
  54. def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast=False):
  55. model.train()
  56. graph_num = len(dataset)
  57. order = list(range(graph_num))
  58. shuffle(order)
  59. loss_addnode = 0
  60. loss_addedge = 0
  61. loss_node = 0
  62. for i in order:
  63. model.zero_grad()
  64. graph = dataset[i]
  65. # do random ordering: relabel nodes
  66. node_order = list(range(graph.number_of_nodes()))
  67. shuffle(node_order)
  68. order_mapping = dict(zip(graph.nodes(), node_order))
  69. graph = nx.relabel_nodes(graph, order_mapping, copy=True)
  70. # NOTE: when starting loop, we assume a node has already been generated
  71. node_count = 1
  72. node_embedding = [
  73. Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
  74. loss = 0
  75. while node_count <= graph.number_of_nodes():
  76. node_neighbor = graph.subgraph(
  77. list(range(node_count))).adjacency_list() # list of lists (first node is zero)
  78. node_neighbor_new = graph.subgraph(list(range(node_count + 1))).adjacency_list()[
  79. -1] # list of new node's neighbors
  80. # 1 message passing
  81. # do 2 times message passing
  82. node_embedding = message_passing(node_neighbor, node_embedding, model)
  83. # 2 graph embedding and new node embedding
  84. node_embedding_cat = torch.cat(node_embedding, dim=0)
  85. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  86. init_embedding = calc_init_embedding(node_embedding_cat, model)
  87. # 3 f_addnode
  88. p_addnode = model.f_an(graph_embedding)
  89. if node_count < graph.number_of_nodes():
  90. # add node
  91. node_neighbor.append([])
  92. node_embedding.append(init_embedding)
  93. if is_fast:
  94. node_embedding_cat = torch.cat(node_embedding, dim=0)
  95. # calc loss
  96. loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.ones((1, 1))).cuda())
  97. # loss_addnode_step.backward(retain_graph=True)
  98. loss += loss_addnode_step
  99. loss_addnode += loss_addnode_step.data
  100. else:
  101. # calc loss
  102. loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda())
  103. # loss_addnode_step.backward(retain_graph=True)
  104. loss += loss_addnode_step
  105. loss_addnode += loss_addnode_step.data
  106. break
  107. edge_count = 0
  108. while edge_count <= len(node_neighbor_new):
  109. if not is_fast:
  110. node_embedding = message_passing(node_neighbor, node_embedding, model)
  111. node_embedding_cat = torch.cat(node_embedding, dim=0)
  112. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  113. # 4 f_addedge
  114. p_addedge = model.f_ae(graph_embedding)
  115. if edge_count < len(node_neighbor_new):
  116. # calc loss
  117. loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda())
  118. # loss_addedge_step.backward(retain_graph=True)
  119. loss += loss_addedge_step
  120. loss_addedge += loss_addedge_step.data
  121. # 5 f_nodes
  122. # excluding the last node (which is the new node)
  123. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  124. node_embedding_cat.size(1))
  125. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  126. p_node = F.softmax(s_node.permute(1, 0))
  127. # get ground truth
  128. a_node = torch.zeros((1, p_node.size(1)))
  129. # print('node_neighbor_new',node_neighbor_new, edge_count)
  130. a_node[0, node_neighbor_new[edge_count]] = 1
  131. a_node = Variable(a_node).cuda()
  132. # add edge
  133. node_neighbor[-1].append(node_neighbor_new[edge_count])
  134. node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor) - 1)
  135. # calc loss
  136. loss_node_step = F.binary_cross_entropy(p_node, a_node)
  137. # loss_node_step.backward(retain_graph=True)
  138. loss += loss_node_step
  139. loss_node += loss_node_step.data
  140. else:
  141. # calc loss
  142. loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda())
  143. # loss_addedge_step.backward(retain_graph=True)
  144. loss += loss_addedge_step
  145. loss_addedge += loss_addedge_step.data
  146. break
  147. edge_count += 1
  148. node_count += 1
  149. # update deterministic and lstm
  150. loss.backward()
  151. optimizer.step()
  152. scheduler.step()
  153. loss_all = loss_addnode + loss_addedge + loss_node
  154. if epoch % args.epochs_log == 0:
  155. print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
  156. epoch, args.epochs, loss_all, args.graph_type, args.node_embedding_size))
  157. # loss_sum += loss.data[0]*x.size(0)
  158. # return loss_sum
  159. def test_DGMG_epoch(args, model, is_fast=False):
  160. model.eval()
  161. graph_num = args.test_graph_num
  162. graphs_generated = []
  163. for i in range(graph_num):
  164. # NOTE: when starting loop, we assume a node has already been generated
  165. node_neighbor = [[]] # list of lists (first node is zero)
  166. node_embedding = [
  167. Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
  168. node_count = 1
  169. while node_count <= args.max_num_node:
  170. # 1 message passing
  171. # do 2 times message passing
  172. node_embedding = message_passing(node_neighbor, node_embedding, model)
  173. # 2 graph embedding and new node embedding
  174. node_embedding_cat = torch.cat(node_embedding, dim=0)
  175. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  176. init_embedding = calc_init_embedding(node_embedding_cat, model)
  177. # 3 f_addnode
  178. p_addnode = model.f_an(graph_embedding)
  179. a_addnode = sample_tensor(p_addnode)
  180. # print(a_addnode.data[0][0])
  181. if a_addnode.data[0][0] == 1:
  182. # print('add node')
  183. # add node
  184. node_neighbor.append([])
  185. node_embedding.append(init_embedding)
  186. if is_fast:
  187. node_embedding_cat = torch.cat(node_embedding, dim=0)
  188. else:
  189. break
  190. edge_count = 0
  191. while edge_count < args.max_num_node:
  192. if not is_fast:
  193. node_embedding = message_passing(node_neighbor, node_embedding, model)
  194. node_embedding_cat = torch.cat(node_embedding, dim=0)
  195. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  196. # 4 f_addedge
  197. p_addedge = model.f_ae(graph_embedding)
  198. a_addedge = sample_tensor(p_addedge)
  199. # print(a_addedge.data[0][0])
  200. if a_addedge.data[0][0] == 1:
  201. # print('add edge')
  202. # 5 f_nodes
  203. # excluding the last node (which is the new node)
  204. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  205. node_embedding_cat.size(1))
  206. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  207. p_node = F.softmax(s_node.permute(1, 0))
  208. a_node = gumbel_softmax(p_node, temperature=0.01)
  209. _, a_node_id = a_node.topk(1)
  210. a_node_id = int(a_node_id.data[0][0])
  211. # add edge
  212. node_neighbor[-1].append(a_node_id)
  213. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  214. else:
  215. break
  216. edge_count += 1
  217. node_count += 1
  218. # save graph
  219. node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
  220. graph = nx.from_dict_of_lists(node_neighbor_dict)
  221. graphs_generated.append(graph)
  222. return graphs_generated
  223. def test_DGMG_2(args, model, test_graph, is_fast=False):
  224. model.eval()
  225. graph_num = args.test_graph_num
  226. graphs_generated = []
  227. # for i in range(graph_num):
  228. # NOTE: when starting loop, we assume a node has already been generated
  229. node_neighbor = [[]] # list of lists (first node is zero)
  230. node_embedding = [
  231. Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
  232. node_max = len(test_graph.nodes())
  233. node_count = 1
  234. while node_count <= node_max:
  235. # 1 message passing
  236. # do 2 times message passing
  237. node_embedding = message_passing(node_neighbor, node_embedding, model)
  238. # 2 graph embedding and new node embedding
  239. node_embedding_cat = torch.cat(node_embedding, dim=0)
  240. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  241. init_embedding = calc_init_embedding(node_embedding_cat, model)
  242. # 3 f_addnode
  243. p_addnode = model.f_an(graph_embedding)
  244. a_addnode = sample_tensor(p_addnode)
  245. if a_addnode.data[0][0] == 1:
  246. # add node
  247. node_neighbor.append([])
  248. node_embedding.append(init_embedding)
  249. if is_fast:
  250. node_embedding_cat = torch.cat(node_embedding, dim=0)
  251. else:
  252. break
  253. edge_count = 0
  254. while edge_count < args.max_num_node:
  255. if not is_fast:
  256. node_embedding = message_passing(node_neighbor, node_embedding, model)
  257. node_embedding_cat = torch.cat(node_embedding, dim=0)
  258. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  259. # 4 f_addedge
  260. p_addedge = model.f_ae(graph_embedding)
  261. a_addedge = sample_tensor(p_addedge)
  262. if a_addedge.data[0][0] == 1:
  263. # 5 f_nodes
  264. # excluding the last node (which is the new node)
  265. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  266. node_embedding_cat.size(1))
  267. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  268. p_node = F.softmax(s_node.permute(1, 0))
  269. a_node = gumbel_softmax(p_node, temperature=0.01)
  270. _, a_node_id = a_node.topk(1)
  271. a_node_id = int(a_node_id.data[0][0])
  272. # add edge
  273. node_neighbor[-1].append(a_node_id)
  274. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  275. else:
  276. break
  277. edge_count += 1
  278. node_count += 1
  279. # clear node_neighbor and build it again
  280. node_neighbor = []
  281. for n in range(node_max):
  282. temp_neighbor = [k for k in test_graph.edge[n]]
  283. node_neighbor.append(temp_neighbor)
  284. # now add the last node for real
  285. # 1 message passing
  286. # do 2 times message passing
  287. try:
  288. node_embedding = message_passing(node_neighbor, node_embedding, model)
  289. # 2 graph embedding and new node embedding
  290. node_embedding_cat = torch.cat(node_embedding, dim=0)
  291. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  292. init_embedding = calc_init_embedding(node_embedding_cat, model)
  293. # 3 f_addnode
  294. p_addnode = model.f_an(graph_embedding)
  295. a_addnode = sample_tensor(p_addnode)
  296. if a_addnode.data[0][0] == 1:
  297. # add node
  298. node_neighbor.append([])
  299. node_embedding.append(init_embedding)
  300. if is_fast:
  301. node_embedding_cat = torch.cat(node_embedding, dim=0)
  302. edge_count = 0
  303. while edge_count < args.max_num_node:
  304. if not is_fast:
  305. node_embedding = message_passing(node_neighbor, node_embedding, model)
  306. node_embedding_cat = torch.cat(node_embedding, dim=0)
  307. graph_embedding = calc_graph_embedding(node_embedding_cat, model)
  308. # 4 f_addedge
  309. p_addedge = model.f_ae(graph_embedding)
  310. a_addedge = sample_tensor(p_addedge)
  311. if a_addedge.data[0][0] == 1:
  312. # 5 f_nodes
  313. # excluding the last node (which is the new node)
  314. node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
  315. node_embedding_cat.size(1))
  316. s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
  317. p_node = F.softmax(s_node.permute(1, 0))
  318. a_node = gumbel_softmax(p_node, temperature=0.01)
  319. _, a_node_id = a_node.topk(1)
  320. a_node_id = int(a_node_id.data[0][0])
  321. # add edge
  322. node_neighbor[-1].append(a_node_id)
  323. node_neighbor[a_node_id].append(len(node_neighbor) - 1)
  324. else:
  325. break
  326. edge_count += 1
  327. node_count += 1
  328. except:
  329. print('error')
  330. # save graph
  331. node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
  332. graph = nx.from_dict_of_lists(node_neighbor_dict)
  333. graphs_generated.append(graph)
  334. return graphs_generated
########### train function for DGMG
  336. def train_DGMG(args, dataset_train, model):
  337. # check if load existing model
  338. if args.load:
  339. fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
  340. model.load_state_dict(torch.load(fname))
  341. args.lr = 0.00001
  342. epoch = args.load_epoch
  343. print('model loaded!, lr: {}'.format(args.lr))
  344. else:
  345. epoch = 1
  346. # initialize optimizer
  347. optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
  348. scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_rate)
  349. # start main loop
  350. time_all = np.zeros(args.epochs)
  351. while epoch <= args.epochs:
  352. time_start = tm.time()
  353. # train
  354. train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast)
  355. time_end = tm.time()
  356. time_all[epoch - 1] = time_end - time_start
  357. print('time used', time_all[epoch - 1])
  358. # test
  359. if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
  360. graphs = test_DGMG_epoch(args, model, is_fast=args.is_fast)
  361. fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat'
  362. save_graph_list(graphs, fname)
  363. # print('test done, graphs saved')
  364. # save model checkpoint
  365. if args.save:
  366. if epoch % args.epochs_save == 0:
  367. fname = args.model_save_path + args.fname + 'model_' + str(epoch) + '.dat'
  368. torch.save(model.state_dict(), fname)
  369. epoch += 1
  370. np.save(args.timing_save_path + args.fname, time_all)
  371. def neigh_to_mat(neigh, size):
  372. ret_list = np.zeros(size)
  373. for i in neigh:
  374. ret_list[i] = 1
  375. return ret_list
  376. def calc_lable_result(test_graphs, returned_graphs):
  377. labels = []
  378. results = []
  379. i = 0
  380. for test_graph in test_graphs:
  381. n = len(test_graph.nodes())
  382. returned_graph = returned_graphs[i]
  383. label = neigh_to_mat([k for k in test_graph.edge[n - 1]], n)
  384. try:
  385. result = neigh_to_mat([k for k in returned_graph.edge[n - 1]], n)
  386. except:
  387. result = np.zeros(n)
  388. labels.append(label)
  389. results.append(result)
  390. i += 1
  391. return labels, results
  392. def evaluate(labels, results):
  393. mae_list = []
  394. roc_score_list = []
  395. ap_score_list = []
  396. precision_list = []
  397. recall_list = []
  398. iter = 0
  399. for result in results:
  400. label = labels[iter]
  401. iter += 1
  402. part1 = label[result == 1]
  403. part2 = part1[part1 == 1]
  404. part3 = part1[part1 == 0]
  405. part4 = label[result == 0]
  406. part5 = part4[part4 == 1]
  407. tp = len(part2)
  408. fp = len(part3)
  409. fn = part5.sum()
  410. if tp + fp > 0:
  411. precision = tp / (tp + fp)
  412. else:
  413. precision = 0
  414. recall = tp / (tp + fn)
  415. precision_list.append(precision)
  416. recall_list.append(recall)
  417. positive = result[label == 1]
  418. if len(positive) <= len(list(result[label == 0])):
  419. negative = random.sample(list(result[label == 0]), len(positive))
  420. else:
  421. negative = result[label == 0]
  422. positive = random.sample(list(result[label == 1]), len(negative))
  423. preds_all = np.hstack([positive, negative])
  424. labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
  425. if len(labels_all) > 0:
  426. roc_score = roc_auc_score(labels_all, preds_all)
  427. ap_score = average_precision_score(labels_all, preds_all)
  428. roc_score_list.append(roc_score)
  429. ap_score_list.append(ap_score)
  430. mae = 0
  431. for x in range(len(result)):
  432. if result[x] != label[x]:
  433. mae += 1
  434. mae = mae / len(label)
  435. mae_list.append(mae)
  436. mean_roc = mean(roc_score_list)
  437. mean_ap = mean(ap_score_list)
  438. mean_precision = mean(precision_list)
  439. mean_recall = mean(recall_list)
  440. mean_mae = mean(mae_list)
  441. print('roc_score ' + str(mean_roc))
  442. print('ap_score ' + str(mean_ap))
  443. print('precision ' + str(mean_precision))
  444. print('recall ' + str(mean_recall))
  445. print('mae ' + str(mean_mae))
  446. return mean_roc, mean_ap, mean_precision, mean_recall
  447. if __name__ == '__main__':
  448. args = Args_DGMG()
  449. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
  450. print('CUDA', args.cuda)
  451. print('File name prefix', args.fname)
  452. graphs = []
  453. for i in range(4, 10):
  454. graphs.append(nx.ladder_graph(i))
  455. model = DGM_graphs(h_size=args.node_embedding_size).cuda()
  456. if args.graph_type == 'grid_small':
  457. graphs = []
  458. for i in range(2, 3):
  459. for j in range(2, 4):
  460. graphs.append(nx.grid_2d_graph(i, j))
  461. args.max_prev_node = 5
  462. # remove self loops
  463. for graph in graphs:
  464. edges_with_selfloops = graph.selfloop_edges()
  465. if len(edges_with_selfloops) > 0:
  466. graph.remove_edges_from(edges_with_selfloops)
  467. # split datasets
  468. random.seed(123)
  469. shuffle(graphs)
  470. # graphs_len = len(graphs)
  471. # graphs_test = graphs[int(0.8 * graphs_len):]
  472. # graphs_validate = graphs[int(0.7 * graphs_len):int(0.8 * graphs_len)]
  473. # graphs_train = graphs[0:int(0.7 * graphs_len)]
  474. args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  475. print('max number node: {}'.format(args.max_num_node))
  476. print('max previous node: {}'.format(args.max_prev_node))
  477. test_graph = nx.grid_2d_graph(2, 3)
  478. test_graph.remove_node(test_graph.nodes()[5])
  479. train_DGMG(args, graphs, model)
  480. test_graph = nx.convert_node_labels_to_integers(test_graph)
  481. test_DGMG_2(args, model, test_graph)
  482. labels, results = calc_lable_result(test_graphs, eval_graphs)
  483. evaluate(labels, results)