| print('max previous node: {}'.format(args.max_prev_node)) | print('max previous node: {}'.format(args.max_prev_node)) | ||||
| ### train | ### train | ||||
| train_DGMG(args, graphs_train, model) | |||||
| # train_DGMG(args, graphs_train, model) | |||||
| fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | ||||
| model.load_state_dict(torch.load(fname)) | model.load_state_dict(torch.load(fname)) | ||||
| all_ret_test = list() | all_ret_test = list() | ||||
| for test_graph in graphs_test: | for test_graph in graphs_test: | ||||
| test_graph = nx.convert_node_labels_to_integers(test_graph) | test_graph = nx.convert_node_labels_to_integers(test_graph) | ||||
| original_graph = test_graph.copy() | |||||
| test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | ||||
| ret_test = test_DGMG_2(args, model, test_graph) | ret_test = test_DGMG_2(args, model, test_graph) | ||||
| all_tests.append(test_graph) | |||||
| all_tests.append(original_graph) | |||||
| all_ret_test.append(ret_test) | all_ret_test.append(ret_test) | ||||
| labels, results = calc_lable_result(test_graph, ret_test) | |||||
| labels, results = calc_lable_result(original_graph, ret_test) | |||||
| evaluate(labels, results) | evaluate(labels, results) |