| @@ -751,7 +751,7 @@ if __name__ == '__main__': | |||
| print('max previous node: {}'.format(args.max_prev_node)) | |||
| ### train | |||
| train_DGMG(args, graphs_train, model) | |||
| # train_DGMG(args, graphs_train, model) | |||
| fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | |||
| model.load_state_dict(torch.load(fname)) | |||
| @@ -760,9 +760,10 @@ if __name__ == '__main__': | |||
| all_ret_test = list() | |||
| for test_graph in graphs_test: | |||
| test_graph = nx.convert_node_labels_to_integers(test_graph) | |||
| original_graph = test_graph.copy() | |||
| test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) | |||
| ret_test = test_DGMG_2(args, model, test_graph) | |||
| all_tests.append(test_graph) | |||
| all_tests.append(original_graph) | |||
| all_ret_test.append(ret_test) | |||
| labels, results = calc_lable_result(test_graph, ret_test) | |||
| labels, results = calc_lable_result(original_graph, ret_test) | |||
| evaluate(labels, results) | |||