|
|
|
|
|
|
|
|
print('max previous node: {}'.format(args.max_prev_node)) |
|
|
print('max previous node: {}'.format(args.max_prev_node)) |
|
|
|
|
|
|
|
|
### train |
|
|
### train |
|
|
train_DGMG(args, graphs_train, model) |
|
|
|
|
|
|
|
|
# train_DGMG(args, graphs_train, model) |
|
|
|
|
|
|
|
|
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' |
|
|
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' |
|
|
model.load_state_dict(torch.load(fname)) |
|
|
model.load_state_dict(torch.load(fname)) |
|
|
|
|
|
|
|
|
all_ret_test = list() |
|
|
all_ret_test = list() |
|
|
for test_graph in graphs_test: |
|
|
for test_graph in graphs_test: |
|
|
test_graph = nx.convert_node_labels_to_integers(test_graph) |
|
|
test_graph = nx.convert_node_labels_to_integers(test_graph) |
|
|
|
|
|
original_graph = test_graph.copy() |
|
|
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) |
|
|
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) |
|
|
ret_test = test_DGMG_2(args, model, test_graph) |
|
|
ret_test = test_DGMG_2(args, model, test_graph) |
|
|
all_tests.append(test_graph) |
|
|
|
|
|
|
|
|
all_tests.append(original_graph) |
|
|
all_ret_test.append(ret_test) |
|
|
all_ret_test.append(ret_test) |
|
|
labels, results = calc_lable_result(test_graph, ret_test) |
|
|
|
|
|
|
|
|
labels, results = calc_lable_result(original_graph, ret_test) |
|
|
evaluate(labels, results) |
|
|
evaluate(labels, results) |