Browse Source

fix bug in evaluation

master
Ali Amiri 4 years ago
parent
commit
e0b6d40a17
1 changed files with 4 additions and 3 deletions
  1. 4
    3
      main_DeepGMG.py

+ 4
- 3
main_DeepGMG.py View File

print('max previous node: {}'.format(args.max_prev_node)) print('max previous node: {}'.format(args.max_prev_node))


### train ### train
train_DGMG(args, graphs_train, model)
# train_DGMG(args, graphs_train, model)


fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
model.load_state_dict(torch.load(fname)) model.load_state_dict(torch.load(fname))
all_ret_test = list() all_ret_test = list()
for test_graph in graphs_test: for test_graph in graphs_test:
test_graph = nx.convert_node_labels_to_integers(test_graph) test_graph = nx.convert_node_labels_to_integers(test_graph)
original_graph = test_graph.copy()
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1])
ret_test = test_DGMG_2(args, model, test_graph) ret_test = test_DGMG_2(args, model, test_graph)
all_tests.append(test_graph)
all_tests.append(original_graph)
all_ret_test.append(ret_test) all_ret_test.append(ret_test)
labels, results = calc_lable_result(test_graph, ret_test)
labels, results = calc_lable_result(original_graph, ret_test)
evaluate(labels, results) evaluate(labels, results)

Loading…
Cancel
Save