| # an implementation for "Learning Deep Generative Models of Graphs" | # an implementation for "Learning Deep Generative Models of Graphs" | ||||
| import os | import os | ||||
| import random | |||||
| from statistics import mean | |||||
| import networkx as nx | |||||
| import numpy as np | |||||
| from sklearn.metrics import roc_auc_score, average_precision_score | |||||
| from main import * | from main import * | ||||
| graph_num = args.test_graph_num | graph_num = args.test_graph_num | ||||
| graphs_generated = [] | graphs_generated = [] | ||||
| for i in range(graph_num): | |||||
| # NOTE: when starting loop, we assume a node has already been generated | |||||
| node_neighbor = [[]] # list of lists (first node is zero) | |||||
| node_embedding = [ | |||||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
| node_max = len(test_graph.nodes()) | |||||
| node_count = 1 | |||||
| while node_count <= node_max: | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| else: | |||||
| break | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| # clear node_neighbor and build it again | |||||
| node_neighbor = [] | |||||
| for n in range(node_max): | |||||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||||
| node_neighbor.append(temp_neighbor) | |||||
| # now add the last node for real | |||||
| # for i in range(graph_num): | |||||
| # NOTE: when starting loop, we assume a node has already been generated | |||||
| node_neighbor = [[]] # list of lists (first node is zero) | |||||
| node_embedding = [ | |||||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
| node_max = len(test_graph.nodes()) | |||||
| node_count = 1 | |||||
| while node_count <= node_max: | |||||
| # 1 message passing | # 1 message passing | ||||
| # do 2 times message passing | # do 2 times message passing | ||||
| try: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| else: | |||||
| break | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| # clear node_neighbor and build it again | |||||
| node_neighbor = [] | |||||
| for n in range(node_max): | |||||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||||
| node_neighbor.append(temp_neighbor) | |||||
| # now add the last node for real | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| try: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| except: | |||||
| print('error') | |||||
| # save graph | |||||
| node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor)) | |||||
| graph = nx.from_dict_of_lists(node_neighbor_dict) | |||||
| graphs_generated.append(graph) | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| except: | |||||
| print('error') | |||||
| # save graph | |||||
| node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor)) | |||||
| graph = nx.from_dict_of_lists(node_neighbor_dict) | |||||
| graphs_generated.append(graph) | |||||
| return graphs_generated | return graphs_generated | ||||
| np.save(args.timing_save_path + args.fname, time_all) | np.save(args.timing_save_path + args.fname, time_all) | ||||
| def neigh_to_mat(neigh, size): | |||||
| ret_list = np.zeros(size) | |||||
| for i in neigh: | |||||
| ret_list[i] = 1 | |||||
| return ret_list | |||||
| def calc_lable_result(test_graphs, returned_graphs): | |||||
| labels = [] | |||||
| results = [] | |||||
| i = 0 | |||||
| for test_graph in test_graphs: | |||||
| n = len(test_graph.nodes()) | |||||
| returned_graph = returned_graphs[i] | |||||
| label = neigh_to_mat([k for k in test_graph.edge[n - 1]], n) | |||||
| try: | |||||
| result = neigh_to_mat([k for k in returned_graph.edge[n - 1]], n) | |||||
| except: | |||||
| result = np.zeros(n) | |||||
| labels.append(label) | |||||
| results.append(result) | |||||
| i += 1 | |||||
| return labels, results | |||||
| def evaluate(labels, results): | |||||
| mae_list = [] | |||||
| roc_score_list = [] | |||||
| ap_score_list = [] | |||||
| precision_list = [] | |||||
| recall_list = [] | |||||
| iter = 0 | |||||
| for result in results: | |||||
| label = labels[iter] | |||||
| iter += 1 | |||||
| part1 = label[result == 1] | |||||
| part2 = part1[part1 == 1] | |||||
| part3 = part1[part1 == 0] | |||||
| part4 = label[result == 0] | |||||
| part5 = part4[part4 == 1] | |||||
| tp = len(part2) | |||||
| fp = len(part3) | |||||
| fn = part5.sum() | |||||
| if tp + fp > 0: | |||||
| precision = tp / (tp + fp) | |||||
| else: | |||||
| precision = 0 | |||||
| recall = tp / (tp + fn) | |||||
| precision_list.append(precision) | |||||
| recall_list.append(recall) | |||||
| positive = result[label == 1] | |||||
| if len(positive) <= len(list(result[label == 0])): | |||||
| negative = random.sample(list(result[label == 0]), len(positive)) | |||||
| else: | |||||
| negative = result[label == 0] | |||||
| positive = random.sample(list(result[label == 1]), len(negative)) | |||||
| preds_all = np.hstack([positive, negative]) | |||||
| labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))]) | |||||
| if len(labels_all) > 0: | |||||
| roc_score = roc_auc_score(labels_all, preds_all) | |||||
| ap_score = average_precision_score(labels_all, preds_all) | |||||
| roc_score_list.append(roc_score) | |||||
| ap_score_list.append(ap_score) | |||||
| mae = 0 | |||||
| for x in range(len(result)): | |||||
| if result[x] != label[x]: | |||||
| mae += 1 | |||||
| mae = mae / len(label) | |||||
| mae_list.append(mae) | |||||
| mean_roc = mean(roc_score_list) | |||||
| mean_ap = mean(ap_score_list) | |||||
| mean_precision = mean(precision_list) | |||||
| mean_recall = mean(recall_list) | |||||
| mean_mae = mean(mae_list) | |||||
| print('roc_score ' + str(mean_roc)) | |||||
| print('ap_score ' + str(mean_ap)) | |||||
| print('precision ' + str(mean_precision)) | |||||
| print('recall ' + str(mean_recall)) | |||||
| print('mae ' + str(mean_mae)) | |||||
| return mean_roc, mean_ap, mean_precision, mean_recall | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| args = Args_DGMG() | args = Args_DGMG() | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) | ||||
| print('CUDA', args.cuda) | print('CUDA', args.cuda) | ||||
| for i in range(2, 3): | for i in range(2, 3): | ||||
| for j in range(2, 4): | for j in range(2, 4): | ||||
| graphs.append(nx.grid_2d_graph(i, j)) | graphs.append(nx.grid_2d_graph(i, j)) | ||||
| args.max_prev_node = 6 | |||||
| args.max_prev_node = 5 | |||||
| # remove self loops | # remove self loops | ||||
| for graph in graphs: | for graph in graphs: | ||||
| test_graph = nx.convert_node_labels_to_integers(test_graph) | test_graph = nx.convert_node_labels_to_integers(test_graph) | ||||
| test_DGMG_2(args, model, test_graph) | test_DGMG_2(args, model, test_graph) | ||||
| labels, results = calc_lable_result(test_graphs, eval_graphs) | |||||
| evaluate(labels, results) |