You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

evaluate.py 6.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import numpy as np
  2. import torch
  3. import random
  4. import networkx as nx
  5. from statistics import mean
  6. from model import sample_sigmoid
  7. from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
  8. from baselines.graphvae.graphvae_model import graph_show
  9. def getitem(graph, max_num_nodes):
  10. adj = nx.to_numpy_matrix(graph)
  11. num_nodes = adj.shape[0]
  12. adj_padded = np.zeros((max_num_nodes, max_num_nodes))
  13. adj_padded[:num_nodes, :num_nodes] = adj
  14. adj_decoded = np.zeros(max_num_nodes * (max_num_nodes + 1) // 2)
  15. node_idx = 0
  16. adj_vectorized = adj_padded[np.triu(np.ones((max_num_nodes, max_num_nodes))) == 1]
  17. return {'adj': adj_padded,
  18. 'adj_decoded': adj_vectorized,
  19. 'num_nodes': num_nodes}
  20. def evaluate(dataset, model, graph_show_mode=False):
  21. # print("******************** TEST *****************************")
  22. model.eval()
  23. mae_list = []
  24. tp_div_pos_list = []
  25. roc_score_list = []
  26. ap_score_list = []
  27. for batch_idx, data in enumerate(dataset):
  28. input_features = data['features'].float()
  29. input_features = torch.tensor(input_features, device='cuda:0')
  30. adj_input = data['adj'].float()
  31. batch_num_nodes = data['num_nodes'].int().numpy()
  32. numpy_adj = adj_input.cpu().numpy()
  33. # print("**** numpy_adj")
  34. # print(numpy_adj)
  35. # print("**** input_features")
  36. # print(input_features)
  37. batch_size = numpy_adj.shape[0]
  38. if model.vae_args.GRAN:
  39. gran_lable = np.zeros((batch_size, model.max_num_nodes - 1))
  40. random_index_list = []
  41. for i in range(batch_size):
  42. num_nodes = batch_num_nodes[i]
  43. if model.vae_args.GRAN_arbitrary_node:
  44. random_index_list.append(np.random.randint(num_nodes))
  45. gran_lable[i, :random_index_list[i]] = \
  46. numpy_adj[i, :random_index_list[i], random_index_list[i]]
  47. gran_lable[i, random_index_list[i]:num_nodes - 1] = \
  48. numpy_adj[i, random_index_list[i] + 1:num_nodes, random_index_list[i]]
  49. numpy_adj[i, :num_nodes, random_index_list[i]] = 1
  50. numpy_adj[i, random_index_list[i], :num_nodes] = 1
  51. # print("*** random_index_list")
  52. # print(random_index_list)
  53. # print("*** gran_lable")
  54. # print(gran_lable)
  55. # print("*** numpy_adj after fully connection")
  56. # print(numpy_adj)
  57. else:
  58. gran_lable[i, :num_nodes - 1] = numpy_adj[i, :num_nodes - 1, num_nodes - 1]
  59. numpy_adj[i, :num_nodes, num_nodes - 1] = 1
  60. numpy_adj[i, num_nodes - 1, :num_nodes] = 1
  61. gran_lable = torch.tensor(gran_lable).float()
  62. incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
  63. # print("*** gran_lable")
  64. # print(gran_lable)
  65. # print("*** numpy_adj after change")
  66. # print(numpy_adj)
  67. x = model.conv1(input_features, incomplete_adj)
  68. x = model.act(x)
  69. x = model.bn1(x)
  70. x = model.conv2(x, incomplete_adj)
  71. # x = model.bn2(x)
  72. # x = model.act(x)
  73. # print("*** embedding : ")
  74. # print(x)
  75. preds = model.gran(x, batch_num_nodes, random_index_list, model.vae_args.GRAN_arbitrary_node).squeeze()
  76. # print("*** preds")
  77. # print(preds)
  78. sample = sample_sigmoid(preds.unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
  79. sample = sample.squeeze()
  80. # ind = 1
  81. # print("*** sample")
  82. # print(sample)
  83. # print("*** gran_lable")
  84. # print(gran_lable)
  85. if graph_show_mode:
  86. index = 0
  87. colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
  88. incomplete = numpy_adj.copy()
  89. print("******** 1) ")
  90. print(incomplete[index])
  91. incomplete[index, :batch_num_nodes[index], batch_num_nodes[index] - 1] = 0
  92. incomplete[index, batch_num_nodes[index] - 1, :batch_num_nodes[index]] = 0
  93. print("******** 2) ")
  94. print(incomplete[index])
  95. graph_show(nx.from_numpy_matrix(incomplete[index]), "incomplete graph", colors)
  96. incomplete[index, :batch_num_nodes[index] - 1, batch_num_nodes[index] - 1] = \
  97. sample.cpu()[index, :batch_num_nodes[index] - 1]
  98. incomplete[index, batch_num_nodes[index] - 1, :batch_num_nodes[index] - 1] = \
  99. sample.cpu()[index, :batch_num_nodes[index] - 1]
  100. graph_show(nx.from_numpy_matrix(incomplete[index]), "complete graph", colors)
  101. return
  102. mae = mean_absolute_error(gran_lable.numpy(), sample.cpu())
  103. mae_list.append(mae)
  104. tp_div_pos = 0
  105. label = gran_lable.numpy()
  106. result = sample.cpu().numpy()
  107. for i in range(gran_lable.size()[0]):
  108. part1 = label[i][result[i] == 1]
  109. part2 = part1[part1 == 1]
  110. tp_div_pos += part2.sum() / label[i].sum()
  111. tp_div_pos /= gran_lable.size()[0]
  112. tp_div_pos_list.append(tp_div_pos)
  113. # #######################################
  114. result = preds.cpu().detach().numpy()
  115. roc_score = 0
  116. ap_score = 0
  117. for i in range(gran_lable.size()[0]):
  118. positive = result[i][label[i] == 1]
  119. # print("****************")
  120. # print(len(positive))
  121. # print(len(list(result[i][label[i] == 0])))
  122. # print(positive)
  123. if len(positive) <= len(list(result[i][label[i] == 0])):
  124. negative = random.sample(list(result[i][label[i] == 0]), len(positive))
  125. else:
  126. negative = result[i][label[i] == 0]
  127. positive = random.sample(list(result[i][label[i] == 1]), len(negative))
  128. preds_all = np.hstack([positive, negative])
  129. labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
  130. if len(labels_all) > 0:
  131. roc_score += roc_auc_score(labels_all, preds_all)
  132. ap_score += average_precision_score(labels_all, preds_all)
  133. roc_score /= batch_size
  134. ap_score /= batch_size
  135. roc_score_list.append(roc_score)
  136. ap_score_list.append(ap_score)
  137. # #######################################
  138. print("*** MAE - tp_div_pos - roc_score - ap_score : " + str(mean(mae_list)) + " _ "
  139. + str(mean(tp_div_pos_list)) + " _ " + str(mean(roc_score_list)) + " _ "
  140. + str(mean(ap_score_list)) + " _ ")
  141. # print(mae)
  142. # print("*** tp_div_pos : ")
  143. # print(tp_div_pos)
  144. # print("*** roc_score : ")
  145. # print(roc_score)
  146. # print("*** ap_score : ")
  147. # print(ap_score)
  148. return mae, tp_div_pos, roc_score, ap_score