# Evaluation utilities for GRAN-style graph generation models.
- import numpy as np
- import torch
- import random
- import networkx as nx
- from statistics import mean
-
- from model import sample_sigmoid
- from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
-
-
def getitem(graph, max_num_nodes):
    """Convert a networkx graph into a fixed-size model input dict.

    Args:
        graph: networkx graph with at most ``max_num_nodes`` nodes.
        max_num_nodes: size the adjacency matrix is zero-padded to.

    Returns:
        dict with:
            'adj': (max_num_nodes, max_num_nodes) adjacency matrix,
                zero-padded beyond the graph's real nodes.
            'num_nodes': the graph's actual node count.
            'features': identity matrix used as one-hot node features.
    """
    # nx.to_numpy_matrix was removed in networkx 3.0; to_numpy_array is the
    # drop-in replacement (returns an ndarray instead of np.matrix, which is
    # fine here since the result is only copied into adj_padded).
    adj = nx.to_numpy_array(graph)
    num_nodes = adj.shape[0]
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj

    return {'adj': adj_padded,
            'num_nodes': num_nodes,
            'features': np.identity(max_num_nodes)}
-
-
def evaluate(test_graph_list, model, graph_show_mode=False):
    """Evaluate a GRAN-mode model's edge predictions on held-out graphs.

    For each graph, one node's incident edges are hidden (a random node when
    ``model.vae_args.GRAN_arbitrary_node`` is set, otherwise the last node),
    the model predicts that node's edges, and the predictions are scored
    against the true adjacency row/column.

    Args:
        test_graph_list: iterable of networkx graphs.
        model: trained model exposing ``max_num_nodes``, ``vae_args``,
            ``conv1``/``conv2``/``bn1``/``act``/``gran``; must live on cuda:0.
        graph_show_mode: unused; kept for interface compatibility.

    Returns:
        Tuple of per-graph means: (MAE, ROC-AUC, average precision,
        precision, recall).

    NOTE(review): ``gran_label``, ``random_index`` and the masked adjacency
    are only produced inside the ``model.vae_args.GRAN`` branch, so this
    function assumes GRAN mode is enabled; with ``GRAN=False`` it raises
    NameError — confirm intended usage. ``statistics.mean`` also raises on
    empty lists, so ``test_graph_list`` must be non-empty.
    """
    model.eval()
    mae_list = []
    roc_score_list = []
    ap_score_list = []
    precision_list = []
    recall_list = []
    for graph in test_graph_list:
        data = getitem(graph, model.max_num_nodes)
        input_features = torch.tensor(
            data['features'], device='cuda:0').unsqueeze(0).float()
        num_nodes = data['num_nodes']
        numpy_adj = data['adj']
        if model.vae_args.GRAN:
            # Ground-truth edges of the hidden node, padded to max size - 1
            # (the node itself is excluded, hence the -1).
            gran_label = np.zeros((model.max_num_nodes - 1))
            if model.vae_args.GRAN_arbitrary_node:
                random_index = np.random.randint(num_nodes)
                gran_label[:random_index] = \
                    numpy_adj[:random_index, random_index]
                gran_label[random_index:num_nodes - 1] = \
                    numpy_adj[random_index + 1:num_nodes, random_index]
                # Mask the chosen node: set all its edges to 1 so the model
                # cannot read the answer from the input adjacency.
                numpy_adj[:num_nodes, random_index] = 1
                numpy_adj[random_index, :num_nodes] = 1
            else:
                random_index = num_nodes - 1
                gran_label[:num_nodes - 1] = \
                    numpy_adj[:num_nodes - 1, num_nodes - 1]
                numpy_adj[:num_nodes, num_nodes - 1] = 1
                numpy_adj[num_nodes - 1, :num_nodes] = 1
        gran_label = torch.tensor(gran_label).float()
        incomplete_adj = torch.tensor(
            numpy_adj, device='cuda:0').unsqueeze(0).float()

        # Encode the masked graph and predict the hidden node's edges.
        x = model.conv1(input_features, incomplete_adj)
        x = model.act(x)
        x = model.bn1(x)
        x = model.conv2(x, incomplete_adj)
        num_nodes_t = torch.tensor(num_nodes, device='cuda:0').unsqueeze(0)
        preds = model.gran(x, num_nodes_t, [random_index],
                           model.vae_args.GRAN_arbitrary_node).squeeze()
        # Hard 0/1 edge decisions from the predicted probabilities.
        sample = sample_sigmoid(preds.unsqueeze(0).unsqueeze(0),
                                sample=False, thresh=0.5,
                                sample_time=10).squeeze()

        label = gran_label.numpy()
        result = sample.cpu().numpy()
        mae_list.append(mean_absolute_error(label, result))

        # Confusion counts over the binarized predictions.
        tp = int(((result == 1) & (label == 1)).sum())
        fp = int(((result == 1) & (label == 0)).sum())
        fn = ((result == 0) & (label == 1)).sum()
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        # Guard tp + fn == 0 (hidden node has no true edges); the original
        # divided unconditionally and crashed with ZeroDivisionError.
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        precision_list.append(precision)
        recall_list.append(recall)

        # ROC-AUC / AP on the raw scores, over a class-balanced subsample:
        # downsample whichever class is larger to the size of the smaller.
        scores = preds.cpu().detach().numpy()
        positive = scores[label == 1]
        if len(positive) <= len(list(scores[label == 0])):
            negative = random.sample(list(scores[label == 0]), len(positive))
        else:
            negative = scores[label == 0]
            positive = random.sample(list(scores[label == 1]), len(negative))
        preds_all = np.hstack([positive, negative])
        # len(positive) == len(negative) after balancing; the original used
        # len(positive) for the zeros, which only worked by that coincidence.
        labels_all = np.hstack([np.ones(len(positive)),
                                np.zeros(len(negative))])
        if len(labels_all) > 0:
            roc_score_list.append(roc_auc_score(labels_all, preds_all))
            ap_score_list.append(average_precision_score(labels_all,
                                                         preds_all))

    return (mean(mae_list), mean(roc_score_list), mean(ap_score_list),
            mean(precision_list), mean(recall_list))
|