Browse Source

fix bug in evaluation

master
Ali Amiri 4 years ago
parent
commit
e0b6d40a17
1 changed files with 4 additions and 3 deletions
  1. 4
    3
      main_DeepGMG.py

+ 4
- 3
main_DeepGMG.py View File

@@ -751,7 +751,7 @@ if __name__ == '__main__':
print('max previous node: {}'.format(args.max_prev_node))

### train
train_DGMG(args, graphs_train, model)
# train_DGMG(args, graphs_train, model)

fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
model.load_state_dict(torch.load(fname))
@@ -760,9 +760,10 @@ if __name__ == '__main__':
all_ret_test = list()
for test_graph in graphs_test:
test_graph = nx.convert_node_labels_to_integers(test_graph)
original_graph = test_graph.copy()
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1])
ret_test = test_DGMG_2(args, model, test_graph)
all_tests.append(test_graph)
all_tests.append(original_graph)
all_ret_test.append(ret_test)
labels, results = calc_lable_result(test_graph, ret_test)
labels, results = calc_lable_result(original_graph, ret_test)
evaluate(labels, results)

Loading…
Cancel
Save