| ### data config | ### data config | ||||
| # self.graph_type = 'caveman_small' | # self.graph_type = 'caveman_small' | ||||
| # self.graph_type = 'grid_small' | # self.graph_type = 'grid_small' | ||||
| self.graph_type = 'IMDBMULTI' | |||||
| # self.graph_type = 'ENZYMES' | |||||
| # self.graph_type = 'ladder_small' | # self.graph_type = 'ladder_small' | ||||
| # self.graph_type = 'enzymes_small' | |||||
| self.graph_type = 'enzymes_small' | |||||
| # self.graph_type = 'protein' | |||||
| # self.graph_type = 'barabasi_small' | # self.graph_type = 'barabasi_small' | ||||
| # self.graph_type = 'citeseer_small' | # self.graph_type = 'citeseer_small' | ||||
| def test_DGMG_2(args, model, test_graph, is_fast=False): | def test_DGMG_2(args, model, test_graph, is_fast=False): | ||||
| model.eval() | model.eval() | ||||
| graph_num = args.test_graph_num | |||||
| graphs_generated = [] | |||||
| # graphs_generated = [] | |||||
| # for i in range(graph_num): | # for i in range(graph_num): | ||||
| # NOTE: when starting loop, we assume a node has already been generated | # NOTE: when starting loop, we assume a node has already been generated | ||||
| node_neighbor = [[]] # list of lists (first node is zero) | node_neighbor = [[]] # list of lists (first node is zero) | ||||
| node_embedding = [ | |||||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
| try: | |||||
| node_embedding = [ | |||||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
| node_max = len(test_graph.nodes()) | |||||
| node_count = 1 | |||||
| while node_count <= node_max: | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_max = len(test_graph.nodes()) | |||||
| node_count = 1 | |||||
| while node_count <= node_max: | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 2 graph embedding and new node embedding | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| # 3 f_addnode | |||||
| p_addnode = model.f_an(graph_embedding) | |||||
| a_addnode = sample_tensor(p_addnode) | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| else: | |||||
| break | |||||
| if a_addnode.data[0][0] == 1: | |||||
| # add node | |||||
| node_neighbor.append([]) | |||||
| node_embedding.append(init_embedding) | |||||
| if is_fast: | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| else: | |||||
| break | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| edge_count = 0 | |||||
| while edge_count < args.max_num_node: | |||||
| if not is_fast: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| # 4 f_addedge | |||||
| p_addedge = model.f_ae(graph_embedding) | |||||
| a_addedge = sample_tensor(p_addedge) | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| if a_addedge.data[0][0] == 1: | |||||
| # 5 f_nodes | |||||
| # excluding the last node (which is the new node) | |||||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||||
| node_embedding_cat.size(1)) | |||||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||||
| p_node = F.softmax(s_node.permute(1, 0)) | |||||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
| _, a_node_id = a_node.topk(1) | |||||
| a_node_id = int(a_node_id.data[0][0]) | |||||
| # add edge | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| node_neighbor[-1].append(a_node_id) | |||||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||||
| else: | |||||
| break | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| edge_count += 1 | |||||
| node_count += 1 | |||||
| # clear node_neighbor and build it again | |||||
| node_neighbor = [] | |||||
| for n in range(node_max): | |||||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||||
| node_neighbor.append(temp_neighbor) | |||||
| # clear node_neighbor and build it again | |||||
| node_neighbor = [] | |||||
| for n in range(node_max): | |||||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||||
| node_neighbor.append(temp_neighbor) | |||||
| # now add the last node for real | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| # now add the last node for real | |||||
| # 1 message passing | |||||
| # do 2 times message passing | |||||
| try: | |||||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | node_embedding = message_passing(node_neighbor, node_embedding, model) | ||||
| # 2 graph embedding and new node embedding | # 2 graph embedding and new node embedding | ||||
| args.max_prev_node = 20 | args.max_prev_node = 20 | ||||
| if args.graph_type == 'enzymes_small': | if args.graph_type == 'enzymes_small': | ||||
| graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
| graphs_raw = Graph_load_batch(min_num_nodes=1, name='ENZYMES') | |||||
| graphs = [] | graphs = [] | ||||
| for G in graphs_raw: | for G in graphs_raw: | ||||
| if G.number_of_nodes() <= 20: | if G.number_of_nodes() <= 20: | ||||
| graphs.append(G) | graphs.append(G) | ||||
| args.max_prev_node = 15 | |||||
| args.max_prev_node = 10 | |||||
| if args.graph_type == 'citeseer_small': | |||||
| _, _, G = Graph_load(dataset='citeseer') | |||||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||||
| G = nx.convert_node_labels_to_integers(G) | |||||
| graphs = [] | |||||
| for i in range(G.number_of_nodes()): | |||||
| G_ego = nx.ego_graph(G, i, radius=1) | |||||
| if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||||
| graphs.append(G_ego) | |||||
| shuffle(graphs) | |||||
| graphs = graphs[0:200] | |||||
| args.max_prev_node = 15 | |||||
| else: | else: | ||||
| graphs, num_classes = load_data(args.graph_type, True) | |||||
| small_graphs = [] | |||||
| for i in range(len(graphs)): | |||||
| if graphs[i].number_of_nodes() < 13: | |||||
| small_graphs.append(graphs[i]) | |||||
| graphs = small_graphs | |||||
| args.max_prev_node = 12 | |||||
| if args.graph_type == 'protein': | |||||
| graphs_raw = Graph_load_batch(min_num_nodes=1, name='PROTEINS_full') | |||||
| graphs = [] | |||||
| for G in graphs_raw: | |||||
| if G.number_of_nodes() <= 15: | |||||
| graphs.append(G) | |||||
| args.max_prev_node = 10 | |||||
| else: | |||||
| if args.graph_type == 'citeseer_small': | |||||
| _, _, G = Graph_load(dataset='citeseer') | |||||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||||
| G = nx.convert_node_labels_to_integers(G) | |||||
| graphs = [] | |||||
| for i in range(G.number_of_nodes()): | |||||
| G_ego = nx.ego_graph(G, i, radius=1) | |||||
| if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||||
| graphs.append(G_ego) | |||||
| shuffle(graphs) | |||||
| graphs = graphs[0:200] | |||||
| args.max_prev_node = 15 | |||||
| else: | |||||
| graphs, num_classes = load_data(args.graph_type, True) | |||||
| small_graphs = [] | |||||
| for i in range(len(graphs)): | |||||
| if graphs[i].number_of_nodes() < 16: | |||||
| small_graphs.append(graphs[i]) | |||||
| graphs = small_graphs | |||||
| args.max_prev_node = 21 | |||||
| # remove self loops | # remove self loops | ||||
| for graph in graphs: | for graph in graphs: | ||||
| random.seed(123) | random.seed(123) | ||||
| shuffle(graphs) | shuffle(graphs) | ||||
| graphs_len = len(graphs) | graphs_len = len(graphs) | ||||
| graph_statistics(graphs) | |||||
| graphs_test = graphs[int(0.8 * graphs_len):] | graphs_test = graphs[int(0.8 * graphs_len):] | ||||
| graphs_train = graphs[0:int(0.8 * graphs_len)] | graphs_train = graphs[0:int(0.8 * graphs_len)] | ||||
| print('max previous node: {}'.format(args.max_prev_node)) | print('max previous node: {}'.format(args.max_prev_node)) | ||||
| ### train | ### train | ||||
| # train_DGMG(args, graphs_train, model) | |||||
| fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | |||||
| model.load_state_dict(torch.load(fname)) | |||||
| all_tests = list() | |||||
| all_ret_test = list() | |||||
| for test_graph in graphs_test: | |||||
| test_graph = nx.convert_node_labels_to_integers(test_graph) | |||||
| original_graph = test_graph.copy() | |||||
| test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | |||||
| ret_test = test_DGMG_2(args, model, test_graph) | |||||
| all_tests.append(original_graph) | |||||
| all_ret_test.append(ret_test) | |||||
| labels, results = calc_lable_result(original_graph, ret_test) | |||||
| evaluate(labels, results) | |||||
| train_DGMG(args, graphs_train, model) | |||||
| # fname = args.model_save_path + args.fname + 'model_' + str(9) + '.dat' | |||||
| # model.load_state_dict(torch.load(fname)) | |||||
| # | |||||
| # all_tests = list() | |||||
| # all_ret_test = list() | |||||
| # iter_count = 0 | |||||
| # for test_graph in graphs_test: | |||||
| # test_graph = nx.convert_node_labels_to_integers(test_graph) | |||||
| # original_graph = test_graph.copy() | |||||
| # test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | |||||
| # ret_test = test_DGMG_2(args, model, test_graph) | |||||
| # all_tests.append(original_graph) | |||||
| # all_ret_test.append(ret_test) | |||||
| # iter_count += 1 | |||||
| # print(iter_count) | |||||
| # labels, results = calc_lable_result(all_tests, all_ret_test) | |||||
| # evaluate(labels, results) |