evaluate2.py
import numpy as np
import torch
import random
import networkx as nx
from statistics import mean
from model import sample_sigmoid
from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score


def getitem(graph, max_num_nodes):
    # Convert a networkx graph into a zero-padded adjacency matrix and
    # identity (one-hot) node features.
    adj = nx.to_numpy_matrix(graph)
    num_nodes = adj.shape[0]
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj
    return {'adj': adj_padded,
            'num_nodes': num_nodes,
            'features': np.identity(max_num_nodes)}


def evaluate(test_graph_list, model, graph_show_mode=False):
    model.eval()
    mae_list = []
    roc_score_list = []
    ap_score_list = []
    precision_list = []
    recall_list = []
    for graph in test_graph_list:
        data = getitem(graph, model.max_num_nodes)
        input_features = data['features']
        input_features = torch.tensor(input_features, device='cuda:0').unsqueeze(0).float()
        adj_input = data['adj']
        num_nodes = data['num_nodes']
        numpy_adj = adj_input

        if model.vae_args.GRAN:
            # Build the GRAN target: record the true edges of one node as the
            # label, then overwrite that node's row/column in the input
            # adjacency with ones so the model has to predict them.
            gran_lable = np.zeros((model.max_num_nodes - 1))
            if model.vae_args.GRAN_arbitrary_node:
                random_index = np.random.randint(num_nodes)
                gran_lable[:random_index] = \
                    numpy_adj[:random_index, random_index]
                gran_lable[random_index:num_nodes - 1] = \
                    numpy_adj[random_index + 1:num_nodes, random_index]
                numpy_adj[:num_nodes, random_index] = 1
                numpy_adj[random_index, :num_nodes] = 1
            else:
                # Always hide the edges of the last node.
                random_index = num_nodes - 1
                gran_lable[:num_nodes - 1] = numpy_adj[:num_nodes - 1, num_nodes - 1]
                numpy_adj[:num_nodes, num_nodes - 1] = 1
                numpy_adj[num_nodes - 1, :num_nodes] = 1
            gran_lable = torch.tensor(gran_lable).float()

        incomplete_adj = torch.tensor(numpy_adj, device='cuda:0').unsqueeze(0).float()

        # Encode the incomplete graph and predict the hidden node's edges.
        x = model.conv1(input_features, incomplete_adj)
        x = model.act(x)
        x = model.bn1(x)
        x = model.conv2(x, incomplete_adj)
        num_nodes = torch.tensor(num_nodes, device='cuda:0').unsqueeze(0)
        preds = model.gran(x, num_nodes, [random_index], model.vae_args.GRAN_arbitrary_node).squeeze()
        sample = sample_sigmoid(preds.unsqueeze(0).unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
        sample = sample.squeeze()

        mae = mean_absolute_error(gran_lable.numpy(), sample.cpu())
        mae_list.append(mae)

        # Confusion counts between the thresholded prediction and the label.
        label = gran_lable.numpy()
        result = sample.cpu().numpy()
        predicted_positives = label[result == 1]
        tp = len(predicted_positives[predicted_positives == 1])
        fp = len(predicted_positives[predicted_positives == 0])
        fn = label[result == 0].sum()
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        # Guard added to mirror the precision check and avoid a division by
        # zero when the selected node has no edges.
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        precision_list.append(precision)
        recall_list.append(recall)

        # Ranking metrics on a class-balanced subsample of the raw scores:
        # downsample whichever class is larger so positives and negatives
        # contribute equally to ROC-AUC and average precision.
        result = preds.cpu().detach().numpy()
        positive = result[label == 1]
        if len(positive) <= len(list(result[label == 0])):
            negative = random.sample(list(result[label == 0]), len(positive))
        else:
            negative = result[label == 0]
            positive = random.sample(list(result[label == 1]), len(negative))
        preds_all = np.hstack([positive, negative])
        labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
        if len(labels_all) > 0:
            roc_score_list.append(roc_auc_score(labels_all, preds_all))
            ap_score_list.append(average_precision_score(labels_all, preds_all))

    return mean(mae_list), mean(roc_score_list), mean(ap_score_list), mean(precision_list), mean(recall_list)
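

# --- Illustrative sketch (not part of the original file) ---------------------
# A minimal, self-contained example of the per-graph metric computation used in
# evaluate(): threshold the predicted edge scores, derive precision and recall
# from the confusion counts, and compute ROC-AUC / average precision on a
# class-balanced subsample of the raw scores. The toy `label` / `preds` arrays
# are made up purely for demonstration; the snippet only relies on the numpy,
# random, and sklearn imports at the top of this file.
if __name__ == '__main__':
    label = np.array([1., 0., 1., 0., 0., 1., 0., 0.])
    preds = np.array([0.9, 0.2, 0.6, 0.4, 0.1, 0.8, 0.7, 0.3])
    sample = (preds > 0.5).astype(float)  # hard decisions at the 0.5 threshold

    tp = np.sum((sample == 1) & (label == 1))
    fp = np.sum((sample == 1) & (label == 0))
    fn = np.sum((sample == 0) & (label == 1))
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0

    # Balance the classes before the ranking metrics, mirroring evaluate().
    positive = list(preds[label == 1])
    negative = random.sample(list(preds[label == 0]), len(positive))
    scores = np.hstack([positive, negative])
    targets = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])

    print('MAE      :', mean_absolute_error(label, sample))
    print('precision:', precision, 'recall:', recall)
    print('ROC-AUC  :', roc_auc_score(targets, scores))
    print('AP       :', average_precision_score(targets, scores))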