
main_DeepGMG.py (31 KB)

# An implementation of "Learning Deep Generative Models of Graphs" (Li et al., 2018)
import os, random, time as tm
from random import shuffle
from statistics import mean
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.metrics import roc_auc_score, average_precision_score
from baselines.graphvae.util import load_data
from main import *  # repo helpers: message_passing, calc_graph_embedding, calc_init_embedding, sample_tensor, gumbel_softmax, DGM_graphs, save_graph_list, graph loaders, etc.
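# Overview: DGMG grows a graph one node at a time. At every step the model
# (1) runs message passing over the partial graph to refresh node embeddings,
# (2) pools them into a graph embedding, and then makes up to three decisions:
#     f_an (add node?)   Bernoulli "grow or stop"
#     f_ae (add edge?)   Bernoulli "connect the new node further or not"
#     f_s  (to which?)   softmax over existing nodes for the edge endpoint
# Training supervises each decision with binary cross-entropy against the
# ground-truth construction sequence under a random node ordering.
#
# Compatibility note: the code targets the PyTorch 0.x and NetworkX 1.x APIs
# used by the surrounding repo (Variable, G.adjacency_list(), G.edge,
# G.selfloop_edges(), nx.connected_component_subgraphs); NetworkX 2.x removed
# or renamed these calls.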
class Args_DGMG():
    def __init__(self):
        ### CUDA
        self.cuda = 0

        ### model type
        self.note = 'Baseline_DGMG'  # do GCN after adding each edge
        # self.note = 'Baseline_DGMG_fast'  # do GCN only after adding each node

        ### data config
        # self.graph_type = 'caveman_small'
        # self.graph_type = 'grid_small'
        # self.graph_type = 'ENZYMES'
        # self.graph_type = 'ladder_small'
        self.graph_type = 'enzymes_small'
        # self.graph_type = 'protein'
        # self.graph_type = 'barabasi_small'
        # self.graph_type = 'citeseer_small'
        self.max_num_node = 20

        ### network config
        self.node_embedding_size = 64
        self.test_graph_num = 200

        ### training config
        self.epochs = 2000  # one epoch = one full pass over the training graphs
        self.load_epoch = 2000
        self.epochs_test_start = 100
        self.epochs_test = 100
        self.epochs_log = 2
        self.epochs_save = 100

        if 'fast' in self.note:
            self.is_fast = True
        else:
            self.is_fast = False

        self.lr = 0.001
        self.milestones = [300, 600, 1000]
        self.lr_rate = 0.3

        ### output config
        self.model_save_path = 'model_save/'
        self.graph_save_path = 'graphs/'
        self.figure_save_path = 'figures/'
        self.timing_save_path = 'timing/'
        self.figure_prediction_save_path = 'figures_prediction/'
        self.nll_save_path = 'nll/'
        self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size)
        self.fname_pred = self.fname + '_pred_'
        self.fname_train = self.fname + '_train_'
        self.fname_test = self.fname + '_test_'

        self.load = False
        self.save = True
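# NOTE: is_fast is derived from self.note inside __init__ ('fast' in the note
# string enables it), so switch variants by editing self.note above rather
# than toggling is_fast on an existing Args_DGMG instance.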
def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast=False):
    model.train()
    graph_num = len(dataset)
    order = list(range(graph_num))
    shuffle(order)

    loss_addnode = 0
    loss_addedge = 0
    loss_node = 0
    for i in order:
        model.zero_grad()

        graph = dataset[i]
        # do random ordering: relabel nodes
        node_order = list(range(graph.number_of_nodes()))
        shuffle(node_order)
        order_mapping = dict(zip(graph.nodes(), node_order))
        graph = nx.relabel_nodes(graph, order_mapping, copy=True)

        # NOTE: when starting loop, we assume a node has already been generated
        node_count = 1
        node_embedding = [Variable(torch.ones(1, args.node_embedding_size)).cuda()]  # list of torch tensors, each size: 1*hidden

        loss = 0
        while node_count <= graph.number_of_nodes():
            node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list()  # list of lists (first node is zero)
            node_neighbor_new = graph.subgraph(list(range(node_count + 1))).adjacency_list()[-1]  # list of new node's neighbors

            # 1 message passing
            # do 2 times message passing
            node_embedding = message_passing(node_neighbor, node_embedding, model)

            # 2 graph embedding and new node embedding
            node_embedding_cat = torch.cat(node_embedding, dim=0)
            graph_embedding = calc_graph_embedding(node_embedding_cat, model)
            init_embedding = calc_init_embedding(node_embedding_cat, model)

            # 3 f_addnode
            p_addnode = model.f_an(graph_embedding)
            if node_count < graph.number_of_nodes():
                # add node
                node_neighbor.append([])
                node_embedding.append(init_embedding)
                if is_fast:
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                # calc loss
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.ones((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
            else:
                # calc loss
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
                break

            edge_count = 0
            while edge_count <= len(node_neighbor_new):
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge
                p_addedge = model.f_ae(graph_embedding)

                if edge_count < len(node_neighbor_new):
                    # calc loss
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data

                    # 5 f_nodes
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    # get ground truth
                    a_node = torch.zeros((1, p_node.size(1)))
                    # print('node_neighbor_new', node_neighbor_new, edge_count)
                    a_node[0, node_neighbor_new[edge_count]] = 1
                    a_node = Variable(a_node).cuda()
                    # add edge
                    node_neighbor[-1].append(node_neighbor_new[edge_count])
                    node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor) - 1)
                    # calc loss
                    loss_node_step = F.binary_cross_entropy(p_node, a_node)
                    # loss_node_step.backward(retain_graph=True)
                    loss += loss_node_step
                    loss_node += loss_node_step.data
                else:
                    # calc loss
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data
                    break

                edge_count += 1
            node_count += 1

        # update deterministic and lstm
        loss.backward()
        optimizer.step()
        scheduler.step()

    loss_all = loss_addnode + loss_addedge + loss_node
    if epoch % args.epochs_log == 0:
        print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
            epoch, args.epochs, loss_all.item(), args.graph_type, args.node_embedding_size))
    # loss_sum += loss.data[0]*x.size(0)
    # return loss_sum
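# train_DGMG_forward_epoch below mirrors train_DGMG_epoch but never calls
# backward() or optimizer.step(): it only accumulates the per-decision BCE
# terms, so it serves as a forward-only NLL estimate under one sampled node
# ordering. (Unlike the training loop, its node-selection term is scaled by
# the number of candidate nodes, p_node.size(1).)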
def train_DGMG_forward_epoch(args, model, dataset, is_fast=False):
    model.train()
    graph_num = len(dataset)
    order = list(range(graph_num))
    shuffle(order)

    loss_addnode = 0
    loss_addedge = 0
    loss_node = 0
    for i in order:
        model.zero_grad()

        graph = dataset[i]
        # do random ordering: relabel nodes
        node_order = list(range(graph.number_of_nodes()))
        shuffle(node_order)
        order_mapping = dict(zip(graph.nodes(), node_order))
        graph = nx.relabel_nodes(graph, order_mapping, copy=True)

        # NOTE: when starting loop, we assume a node has already been generated
        node_count = 1
        node_embedding = [Variable(torch.ones(1, args.node_embedding_size)).cuda()]  # list of torch tensors, each size: 1*hidden

        loss = 0
        while node_count <= graph.number_of_nodes():
            node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list()  # list of lists (first node is zero)
            node_neighbor_new = graph.subgraph(list(range(node_count + 1))).adjacency_list()[-1]  # list of new node's neighbors

            # 1 message passing
            # do 2 times message passing
            node_embedding = message_passing(node_neighbor, node_embedding, model)

            # 2 graph embedding and new node embedding
            node_embedding_cat = torch.cat(node_embedding, dim=0)
            graph_embedding = calc_graph_embedding(node_embedding_cat, model)
            init_embedding = calc_init_embedding(node_embedding_cat, model)

            # 3 f_addnode
            p_addnode = model.f_an(graph_embedding)
            if node_count < graph.number_of_nodes():
                # add node
                node_neighbor.append([])
                node_embedding.append(init_embedding)
                if is_fast:
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                # calc loss
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.ones((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
            else:
                # calc loss
                loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda())
                # loss_addnode_step.backward(retain_graph=True)
                loss += loss_addnode_step
                loss_addnode += loss_addnode_step.data
                break

            edge_count = 0
            while edge_count <= len(node_neighbor_new):
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge
                p_addedge = model.f_ae(graph_embedding)

                if edge_count < len(node_neighbor_new):
                    # calc loss
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data

                    # 5 f_nodes
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    # get ground truth
                    a_node = torch.zeros((1, p_node.size(1)))
                    # print('node_neighbor_new', node_neighbor_new, edge_count)
                    a_node[0, node_neighbor_new[edge_count]] = 1
                    a_node = Variable(a_node).cuda()
                    # add edge
                    node_neighbor[-1].append(node_neighbor_new[edge_count])
                    node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor) - 1)
                    # calc loss
                    loss_node_step = F.binary_cross_entropy(p_node, a_node)
                    # loss_node_step.backward(retain_graph=True)
                    loss += loss_node_step
                    loss_node += loss_node_step.data * p_node.size(1)
                else:
                    # calc loss
                    loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda())
                    # loss_addedge_step.backward(retain_graph=True)
                    loss += loss_addedge_step
                    loss_addedge += loss_addedge_step.data
                    break

                edge_count += 1
            node_count += 1

    loss_all = loss_addnode + loss_addedge + loss_node
    # if epoch % args.epochs_log == 0:
    #     print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
    #         epoch, args.epochs, loss_all[0], args.graph_type, args.node_embedding_size))
    return loss_all.item() / len(dataset)  # mean accumulated BCE per graph
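# At generation time the model samples its own decisions instead of being
# teacher-forced: sample_tensor draws a Bernoulli outcome from p_addnode /
# p_addedge, and gumbel_softmax with a low temperature (0.01) makes a
# near-argmax pick of which existing node the new edge attaches to.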
def test_DGMG_epoch(args, model, is_fast=False):
    model.eval()
    graph_num = args.test_graph_num

    graphs_generated = []
    for i in range(graph_num):
        # NOTE: when starting loop, we assume a node has already been generated
        node_neighbor = [[]]  # list of lists (first node is zero)
        node_embedding = [Variable(torch.ones(1, args.node_embedding_size)).cuda()]  # list of torch tensors, each size: 1*hidden

        node_count = 1
        while node_count <= args.max_num_node:
            # 1 message passing
            # do 2 times message passing
            node_embedding = message_passing(node_neighbor, node_embedding, model)

            # 2 graph embedding and new node embedding
            node_embedding_cat = torch.cat(node_embedding, dim=0)
            graph_embedding = calc_graph_embedding(node_embedding_cat, model)
            init_embedding = calc_init_embedding(node_embedding_cat, model)

            # 3 f_addnode
            p_addnode = model.f_an(graph_embedding)
            a_addnode = sample_tensor(p_addnode)
            # print(a_addnode.data[0][0])

            if a_addnode.data[0][0] == 1:
                # print('add node')
                # add node
                node_neighbor.append([])
                node_embedding.append(init_embedding)
                if is_fast:
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
            else:
                break

            edge_count = 0
            while edge_count < args.max_num_node:
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge
                p_addedge = model.f_ae(graph_embedding)
                a_addedge = sample_tensor(p_addedge)
                # print(a_addedge.data[0][0])

                if a_addedge.data[0][0] == 1:
                    # print('add edge')
                    # 5 f_nodes
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    a_node = gumbel_softmax(p_node, temperature=0.01)
                    _, a_node_id = a_node.topk(1)
                    a_node_id = int(a_node_id.data[0][0])
                    # add edge
                    node_neighbor[-1].append(a_node_id)
                    node_neighbor[a_node_id].append(len(node_neighbor) - 1)
                else:
                    break

                edge_count += 1
            node_count += 1

        # save graph
        node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
        graph = nx.from_dict_of_lists(node_neighbor_dict)
        graphs_generated.append(graph)

    return graphs_generated
########### train function for DGMG
def train_DGMG(args, dataset_train, model):
    # check if load existing model
    if args.load:
        fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
        model.load_state_dict(torch.load(fname))
        args.lr = 0.00001
        epoch = args.load_epoch
        print('model loaded!, lr: {}'.format(args.lr))
    else:
        epoch = 1

    # initialize optimizer
    optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast)
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start
        # print('time used', time_all[epoch - 1])
        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            graphs = test_DGMG_epoch(args, model, is_fast=args.is_fast)
            fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat'
            save_graph_list(graphs, fname)
            # print('test done, graphs saved')
        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = args.model_save_path + args.fname + 'model_' + str(epoch) + '.dat'
                torch.save(model.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)
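# Minimal usage sketch (assumes the output directories configured in
# Args_DGMG, e.g. model_save/ and graphs/, already exist and a CUDA device
# is available):
#
#   args = Args_DGMG()
#   model = DGM_graphs(h_size=args.node_embedding_size).cuda()
#   train_DGMG(args, graphs_train, model)   # checkpoints + sampled graphs
#   graphs = test_DGMG_epoch(args, model)   # args.test_graph_num new graphs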
########### NLL evaluation for DGMG: forward passes only, on a loaded checkpoint
def train_DGMG_nll(args, dataset_train, dataset_test, model, max_iter=1000):
    # load existing model
    fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
    model.load_state_dict(torch.load(fname))

    fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
    with open(fname_output, 'w+') as f:
        f.write('train,test\n')
        # start main loop
        for iter in range(max_iter):
            nll_train = train_DGMG_forward_epoch(args, model, dataset_train, is_fast=args.is_fast)
            nll_test = train_DGMG_forward_epoch(args, model, dataset_test, is_fast=args.is_fast)
            print('train', nll_train, 'test', nll_test)
            f.write(str(nll_train) + ',' + str(nll_test) + '\n')
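# test_DGMG_2 is a link-prediction-style probe: it first runs the free
# generation loop to build up node embeddings, then swaps in the ground-truth
# adjacency of test_graph and asks the model to add just one more node, so
# the predicted neighborhood of that final node can be scored against the
# true graph (see calc_lable_result / evaluate below).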
def test_DGMG_2(args, model, test_graph, is_fast=False):
    model.eval()

    # graphs_generated = []
    # for i in range(graph_num):
    # NOTE: when starting loop, we assume a node has already been generated
    node_neighbor = [[]]  # list of lists (first node is zero)
    try:
        node_embedding = [Variable(torch.ones(1, args.node_embedding_size)).cuda()]  # list of torch tensors, each size: 1*hidden
        node_max = len(test_graph.nodes())

        node_count = 1
        while node_count <= node_max:
            # 1 message passing
            # do 2 times message passing
            node_embedding = message_passing(node_neighbor, node_embedding, model)

            # 2 graph embedding and new node embedding
            node_embedding_cat = torch.cat(node_embedding, dim=0)
            graph_embedding = calc_graph_embedding(node_embedding_cat, model)
            init_embedding = calc_init_embedding(node_embedding_cat, model)

            # 3 f_addnode
            p_addnode = model.f_an(graph_embedding)
            a_addnode = sample_tensor(p_addnode)

            if a_addnode.data[0][0] == 1:
                # add node
                node_neighbor.append([])
                node_embedding.append(init_embedding)
                if is_fast:
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
            else:
                break

            edge_count = 0
            while edge_count < args.max_num_node:
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge
                p_addedge = model.f_ae(graph_embedding)
                a_addedge = sample_tensor(p_addedge)

                if a_addedge.data[0][0] == 1:
                    # 5 f_nodes
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    a_node = gumbel_softmax(p_node, temperature=0.01)
                    _, a_node_id = a_node.topk(1)
                    a_node_id = int(a_node_id.data[0][0])
                    # add edge
                    node_neighbor[-1].append(a_node_id)
                    node_neighbor[a_node_id].append(len(node_neighbor) - 1)
                else:
                    break

                edge_count += 1
            node_count += 1

        # clear node_neighbor and build it again from the ground-truth edges
        node_neighbor = []
        for n in range(node_max):
            temp_neighbor = [k for k in test_graph.edge[n]]  # NetworkX 1.x: G.edge[n] maps node -> neighbor dict
            node_neighbor.append(temp_neighbor)

        # now add the last node for real
        # 1 message passing
        # do 2 times message passing
        node_embedding = message_passing(node_neighbor, node_embedding, model)

        # 2 graph embedding and new node embedding
        node_embedding_cat = torch.cat(node_embedding, dim=0)
        graph_embedding = calc_graph_embedding(node_embedding_cat, model)
        init_embedding = calc_init_embedding(node_embedding_cat, model)

        # 3 f_addnode
        p_addnode = model.f_an(graph_embedding)
        a_addnode = sample_tensor(p_addnode)

        if a_addnode.data[0][0] == 1:
            # add node
            node_neighbor.append([])
            node_embedding.append(init_embedding)
            if is_fast:
                node_embedding_cat = torch.cat(node_embedding, dim=0)

            edge_count = 0
            while edge_count < args.max_num_node:
                if not is_fast:
                    node_embedding = message_passing(node_neighbor, node_embedding, model)
                    node_embedding_cat = torch.cat(node_embedding, dim=0)
                    graph_embedding = calc_graph_embedding(node_embedding_cat, model)

                # 4 f_addedge
                p_addedge = model.f_ae(graph_embedding)
                a_addedge = sample_tensor(p_addedge)

                if a_addedge.data[0][0] == 1:
                    # 5 f_nodes
                    # excluding the last node (which is the new node)
                    node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, node_embedding_cat.size(1))
                    s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
                    p_node = F.softmax(s_node.permute(1, 0))
                    a_node = gumbel_softmax(p_node, temperature=0.01)
                    _, a_node_id = a_node.topk(1)
                    a_node_id = int(a_node_id.data[0][0])
                    # add edge
                    node_neighbor[-1].append(a_node_id)
                    node_neighbor[a_node_id].append(len(node_neighbor) - 1)
                else:
                    break

                edge_count += 1
            node_count += 1
    except Exception:
        print('error')

    # save graph
    node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
    graph = nx.from_dict_of_lists(node_neighbor_dict)
    return graph
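# Evaluation helpers: neigh_to_mat one-hot encodes a neighbor list into an
# adjacency row; calc_lable_result collects the true vs. predicted adjacency
# row of each test graph's final node; evaluate reports precision, recall,
# class-balanced ROC-AUC / AP, and the fraction of mismatched entries
# (printed as 'mae').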
def neigh_to_mat(neigh, size):
    ret_list = np.zeros(size)
    for i in neigh:
        ret_list[i] = 1
    return ret_list
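# e.g. neigh_to_mat([0, 2], 4) -> array([1., 0., 1., 0.])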
def calc_lable_result(test_graphs, returned_graphs):
    labels = []
    results = []
    i = 0
    for test_graph in test_graphs:
        n = len(test_graph.nodes())
        returned_graph = returned_graphs[i]
        label = neigh_to_mat([k for k in test_graph.edge[n - 1]], n)
        try:
            result = neigh_to_mat([k for k in returned_graph.edge[n - 1]], n)
        except Exception:
            # the model may not have added the final node at all
            result = np.zeros(n)
        labels.append(label)
        results.append(result)
        i += 1
    return labels, results
def evaluate(labels, results):
    mae_list = []
    roc_score_list = []
    ap_score_list = []
    precision_list = []
    recall_list = []
    iter = 0
    for result in results:
        label = labels[iter]
        iter += 1
        part1 = label[result == 1]
        part2 = part1[part1 == 1]
        part3 = part1[part1 == 0]
        part4 = label[result == 0]
        part5 = part4[part4 == 1]
        tp = len(part2)   # predicted 1, label 1
        fp = len(part3)   # predicted 1, label 0
        fn = part5.sum()  # predicted 0, label 1
        if tp + fp > 0:
            precision = tp / (tp + fp)
        else:
            precision = 0
        if tp + fn > 0:
            recall = tp / (tp + fn)
        else:
            recall = 0
        precision_list.append(precision)
        recall_list.append(recall)

        # balance positives and negatives before scoring
        positive = result[label == 1]
        if len(positive) <= len(list(result[label == 0])):
            negative = random.sample(list(result[label == 0]), len(positive))
        else:
            negative = result[label == 0]
            positive = random.sample(list(result[label == 1]), len(negative))
        preds_all = np.hstack([positive, negative])
        labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
        if len(labels_all) > 0:
            roc_score = roc_auc_score(labels_all, preds_all)
            ap_score = average_precision_score(labels_all, preds_all)
            roc_score_list.append(roc_score)
            ap_score_list.append(ap_score)

        # 'mae' here is the fraction of mismatched adjacency entries
        mae = 0
        for x in range(len(result)):
            if result[x] != label[x]:
                mae += 1
        mae = mae / len(label)
        mae_list.append(mae)

    mean_roc = mean(roc_score_list)
    mean_ap = mean(ap_score_list)
    mean_precision = mean(precision_list)
    mean_recall = mean(recall_list)
    mean_mae = mean(mae_list)
    print('roc_score ' + str(mean_roc))
    print('ap_score ' + str(mean_ap))
    print('precision ' + str(mean_precision))
    print('recall ' + str(mean_recall))
    print('mae ' + str(mean_mae))
    return mean_roc, mean_ap, mean_precision, mean_recall
if __name__ == '__main__':
    args = Args_DGMG()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)

    # placeholder dataset; replaced by the graph_type branch below
    graphs = []
    for i in range(4, 10):
        graphs.append(nx.ladder_graph(i))
    model = DGM_graphs(h_size=args.node_embedding_size).cuda()

    if args.graph_type == 'ladder_small':
        graphs = []
        for i in range(2, 11):
            graphs.append(nx.ladder_graph(i))
        args.max_prev_node = 10
    elif args.graph_type == 'caveman_small':
        graphs = []
        for i in range(2, 3):
            for j in range(6, 11):
                for k in range(20):
                    graphs.append(caveman_special(i, j, p_edge=0.8))
        args.max_prev_node = 20
    elif args.graph_type == 'grid_small':
        graphs = []
        for i in range(2, 4):
            for j in range(2, 4):
                graphs.append(nx.grid_2d_graph(i, j))
        args.max_prev_node = 15
    elif args.graph_type == 'barabasi_small':
        graphs = []
        for i in range(4, 21):
            for j in range(3, 4):
                for k in range(10):
                    graphs.append(nx.barabasi_albert_graph(i, j))
        args.max_prev_node = 20
    elif args.graph_type == 'enzymes_small':
        graphs_raw = Graph_load_batch(min_num_nodes=1, name='ENZYMES')
        graphs = []
        for G in graphs_raw:
            if G.number_of_nodes() <= 20:
                graphs.append(G)
        args.max_prev_node = 10
    elif args.graph_type == 'protein':
        graphs_raw = Graph_load_batch(min_num_nodes=1, name='PROTEINS_full')
        graphs = []
        for G in graphs_raw:
            if G.number_of_nodes() <= 15:
                graphs.append(G)
        args.max_prev_node = 10
    elif args.graph_type == 'citeseer_small':
        _, _, G = Graph_load(dataset='citeseer')
        G = max(nx.connected_component_subgraphs(G), key=len)
        G = nx.convert_node_labels_to_integers(G)
        graphs = []
        for i in range(G.number_of_nodes()):
            G_ego = nx.ego_graph(G, i, radius=1)
            if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
                graphs.append(G_ego)
        shuffle(graphs)
        graphs = graphs[0:200]
        args.max_prev_node = 15
    else:
        graphs, num_classes = load_data(args.graph_type, True)
        small_graphs = []
        for i in range(len(graphs)):
            if graphs[i].number_of_nodes() < 16:
                small_graphs.append(graphs[i])
        graphs = small_graphs
        args.max_prev_node = 21

    # remove self loops
    for graph in graphs:
        edges_with_selfloops = graph.selfloop_edges()
        if len(edges_with_selfloops) > 0:
            graph.remove_edges_from(edges_with_selfloops)

    # split datasets
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graph_statistics(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])

    # show graphs statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max previous node: {}'.format(args.max_prev_node))

    ### train
    train_DGMG(args, graphs_train, model)

    # fname = args.model_save_path + args.fname + 'model_' + str(9) + '.dat'
    # model.load_state_dict(torch.load(fname))
    #
    # all_tests = list()
    # all_ret_test = list()
    # iter_count = 0
    # for test_graph in graphs_test:
    #     test_graph = nx.convert_node_labels_to_integers(test_graph)
    #     original_graph = test_graph.copy()
    #     test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1])
    #     ret_test = test_DGMG_2(args, model, test_graph)
    #     all_tests.append(original_graph)
    #     all_ret_test.append(ret_test)
    #     iter_count += 1
    #     print(iter_count)
    # labels, results = calc_lable_result(all_tests, all_ret_test)
    # evaluate(labels, results)
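    # NLL evaluation sketch (hypothetical invocation; assumes a checkpoint
    # saved at epoch args.load_epoch exists under model_save/; max_iter=10
    # shortens the default 1000 forward passes):
    # train_DGMG_nll(args, graphs_train, graphs_test, model, max_iter=10)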