| @@ -21,9 +21,10 @@ class Args_DGMG(): | |||
| ### data config | |||
| # self.graph_type = 'caveman_small' | |||
| # self.graph_type = 'grid_small' | |||
| self.graph_type = 'IMDBMULTI' | |||
| # self.graph_type = 'ENZYMES' | |||
| # self.graph_type = 'ladder_small' | |||
| # self.graph_type = 'enzymes_small' | |||
| self.graph_type = 'enzymes_small' | |||
| # self.graph_type = 'protein' | |||
| # self.graph_type = 'barabasi_small' | |||
| # self.graph_type = 'citeseer_small' | |||
| @@ -445,81 +446,81 @@ def train_DGMG_nll(args, dataset_train, dataset_test, model, max_iter=1000): | |||
| def test_DGMG_2(args, model, test_graph, is_fast=False): | |||
| model.eval() | |||
| graph_num = args.test_graph_num | |||
| graphs_generated = [] | |||
| # graphs_generated = [] | |||
| # for i in range(graph_num): | |||
| # NOTE: when starting loop, we assume a node has already been generated | |||
| node_neighbor = [[]] # list of lists (first node is zero) | |||
| node_embedding = [ | |||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||
| try: | |||
| node_embedding = [ | |||
| Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||
| node_max = len(test_graph.nodes()) | |||
| node_count = 1 | |||
| while node_count <= node_max: | |||
| # 1 message passing | |||
| # do 2 times message passing | |||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||
| node_max = len(test_graph.nodes()) | |||
| node_count = 1 | |||
| while node_count <= node_max: | |||
| # 1 message passing | |||
| # do 2 times message passing | |||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||
| # 2 graph embedding and new node embedding | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||
| # 2 graph embedding and new node embedding | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||
| init_embedding = calc_init_embedding(node_embedding_cat, model) | |||
| # 3 f_addnode | |||
| p_addnode = model.f_an(graph_embedding) | |||
| a_addnode = sample_tensor(p_addnode) | |||
| # 3 f_addnode | |||
| p_addnode = model.f_an(graph_embedding) | |||
| a_addnode = sample_tensor(p_addnode) | |||
| if a_addnode.data[0][0] == 1: | |||
| # add node | |||
| node_neighbor.append([]) | |||
| node_embedding.append(init_embedding) | |||
| if is_fast: | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| else: | |||
| break | |||
| if a_addnode.data[0][0] == 1: | |||
| # add node | |||
| node_neighbor.append([]) | |||
| node_embedding.append(init_embedding) | |||
| if is_fast: | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| else: | |||
| break | |||
| edge_count = 0 | |||
| while edge_count < args.max_num_node: | |||
| if not is_fast: | |||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||
| edge_count = 0 | |||
| while edge_count < args.max_num_node: | |||
| if not is_fast: | |||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||
| node_embedding_cat = torch.cat(node_embedding, dim=0) | |||
| graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||
| # 4 f_addedge | |||
| p_addedge = model.f_ae(graph_embedding) | |||
| a_addedge = sample_tensor(p_addedge) | |||
| # 4 f_addedge | |||
| p_addedge = model.f_ae(graph_embedding) | |||
| a_addedge = sample_tensor(p_addedge) | |||
| if a_addedge.data[0][0] == 1: | |||
| # 5 f_nodes | |||
| # excluding the last node (which is the new node) | |||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||
| node_embedding_cat.size(1)) | |||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||
| p_node = F.softmax(s_node.permute(1, 0)) | |||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||
| _, a_node_id = a_node.topk(1) | |||
| a_node_id = int(a_node_id.data[0][0]) | |||
| # add edge | |||
| if a_addedge.data[0][0] == 1: | |||
| # 5 f_nodes | |||
| # excluding the last node (which is the new node) | |||
| node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, | |||
| node_embedding_cat.size(1)) | |||
| s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) | |||
| p_node = F.softmax(s_node.permute(1, 0)) | |||
| a_node = gumbel_softmax(p_node, temperature=0.01) | |||
| _, a_node_id = a_node.topk(1) | |||
| a_node_id = int(a_node_id.data[0][0]) | |||
| # add edge | |||
| node_neighbor[-1].append(a_node_id) | |||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||
| else: | |||
| break | |||
| node_neighbor[-1].append(a_node_id) | |||
| node_neighbor[a_node_id].append(len(node_neighbor) - 1) | |||
| else: | |||
| break | |||
| edge_count += 1 | |||
| node_count += 1 | |||
| edge_count += 1 | |||
| node_count += 1 | |||
| # clear node_neighbor and build it again | |||
| node_neighbor = [] | |||
| for n in range(node_max): | |||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||
| node_neighbor.append(temp_neighbor) | |||
| # clear node_neighbor and build it again | |||
| node_neighbor = [] | |||
| for n in range(node_max): | |||
| temp_neighbor = [k for k in test_graph.edge[n]] | |||
| node_neighbor.append(temp_neighbor) | |||
| # now add the last node for real | |||
| # 1 message passing | |||
| # do 2 times message passing | |||
| # now add the last node for real | |||
| # 1 message passing | |||
| # do 2 times message passing | |||
| try: | |||
| node_embedding = message_passing(node_neighbor, node_embedding, model) | |||
| # 2 graph embedding and new node embedding | |||
| @@ -703,33 +704,43 @@ if __name__ == '__main__': | |||
| args.max_prev_node = 20 | |||
| if args.graph_type == 'enzymes_small': | |||
| graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||
| graphs_raw = Graph_load_batch(min_num_nodes=1, name='ENZYMES') | |||
| graphs = [] | |||
| for G in graphs_raw: | |||
| if G.number_of_nodes() <= 20: | |||
| graphs.append(G) | |||
| args.max_prev_node = 15 | |||
| args.max_prev_node = 10 | |||
| if args.graph_type == 'citeseer_small': | |||
| _, _, G = Graph_load(dataset='citeseer') | |||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||
| G = nx.convert_node_labels_to_integers(G) | |||
| graphs = [] | |||
| for i in range(G.number_of_nodes()): | |||
| G_ego = nx.ego_graph(G, i, radius=1) | |||
| if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||
| graphs.append(G_ego) | |||
| shuffle(graphs) | |||
| graphs = graphs[0:200] | |||
| args.max_prev_node = 15 | |||
| else: | |||
| graphs, num_classes = load_data(args.graph_type, True) | |||
| small_graphs = [] | |||
| for i in range(len(graphs)): | |||
| if graphs[i].number_of_nodes() < 13: | |||
| small_graphs.append(graphs[i]) | |||
| graphs = small_graphs | |||
| args.max_prev_node = 12 | |||
| if args.graph_type == 'protein': | |||
| graphs_raw = Graph_load_batch(min_num_nodes=1, name='PROTEINS_full') | |||
| graphs = [] | |||
| for G in graphs_raw: | |||
| if G.number_of_nodes() <= 15: | |||
| graphs.append(G) | |||
| args.max_prev_node = 10 | |||
| else: | |||
| if args.graph_type == 'citeseer_small': | |||
| _, _, G = Graph_load(dataset='citeseer') | |||
| G = max(nx.connected_component_subgraphs(G), key=len) | |||
| G = nx.convert_node_labels_to_integers(G) | |||
| graphs = [] | |||
| for i in range(G.number_of_nodes()): | |||
| G_ego = nx.ego_graph(G, i, radius=1) | |||
| if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||
| graphs.append(G_ego) | |||
| shuffle(graphs) | |||
| graphs = graphs[0:200] | |||
| args.max_prev_node = 15 | |||
| else: | |||
| graphs, num_classes = load_data(args.graph_type, True) | |||
| small_graphs = [] | |||
| for i in range(len(graphs)): | |||
| if graphs[i].number_of_nodes() < 16: | |||
| small_graphs.append(graphs[i]) | |||
| graphs = small_graphs | |||
| args.max_prev_node = 21 | |||
| # remove self loops | |||
| for graph in graphs: | |||
| @@ -741,6 +752,9 @@ if __name__ == '__main__': | |||
| random.seed(123) | |||
| shuffle(graphs) | |||
| graphs_len = len(graphs) | |||
| graph_statistics(graphs) | |||
| graphs_test = graphs[int(0.8 * graphs_len):] | |||
| graphs_train = graphs[0:int(0.8 * graphs_len)] | |||
| @@ -751,19 +765,22 @@ if __name__ == '__main__': | |||
| print('max previous node: {}'.format(args.max_prev_node)) | |||
| ### train | |||
| # train_DGMG(args, graphs_train, model) | |||
| fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | |||
| model.load_state_dict(torch.load(fname)) | |||
| all_tests = list() | |||
| all_ret_test = list() | |||
| for test_graph in graphs_test: | |||
| test_graph = nx.convert_node_labels_to_integers(test_graph) | |||
| original_graph = test_graph.copy() | |||
| test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | |||
| ret_test = test_DGMG_2(args, model, test_graph) | |||
| all_tests.append(original_graph) | |||
| all_ret_test.append(ret_test) | |||
| labels, results = calc_lable_result(original_graph, ret_test) | |||
| evaluate(labels, results) | |||
| train_DGMG(args, graphs_train, model) | |||
| # fname = args.model_save_path + args.fname + 'model_' + str(9) + '.dat' | |||
| # model.load_state_dict(torch.load(fname)) | |||
| # | |||
| # all_tests = list() | |||
| # all_ret_test = list() | |||
| # iter_count = 0 | |||
| # for test_graph in graphs_test: | |||
| # test_graph = nx.convert_node_labels_to_integers(test_graph) | |||
| # original_graph = test_graph.copy() | |||
| # test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | |||
| # ret_test = test_DGMG_2(args, model, test_graph) | |||
| # all_tests.append(original_graph) | |||
| # all_ret_test.append(ret_test) | |||
| # iter_count += 1 | |||
| # print(iter_count) | |||
| # labels, results = calc_lable_result(all_tests, all_ret_test) | |||
| # evaluate(labels, results) | |||