import random
from statistics import mean

import networkx as nx
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score

from model import sample_sigmoid


def getitem(graph, max_num_nodes):
    """Build a zero-padded adjacency matrix and identity features for ``graph``.

    Args:
        graph: a networkx graph.
        max_num_nodes: size the adjacency and feature matrices are padded to.

    Returns:
        dict with keys
          - ``'adj'``: ``(max_num_nodes, max_num_nodes)`` float array holding the
            graph's adjacency in its top-left ``num_nodes x num_nodes`` corner,
            zeros elsewhere;
          - ``'num_nodes'``: number of nodes in ``graph``;
          - ``'features'``: ``np.identity(max_num_nodes)``.

    Raises:
        ValueError: if ``graph`` has more than ``max_num_nodes`` nodes.
    """
    # nx.to_numpy_matrix was removed in networkx 3.0; to_numpy_array is the
    # drop-in replacement (same node ordering, returns an ndarray).
    adj = nx.to_numpy_array(graph)
    num_nodes = adj.shape[0]
    if num_nodes > max_num_nodes:
        raise ValueError(
            f'graph has {num_nodes} nodes, more than max_num_nodes={max_num_nodes}')
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj
    return {'adj': adj_padded, 'num_nodes': num_nodes,
            'features': np.identity(max_num_nodes)}


def _mask_target_node(adj, num_nodes, max_num_nodes, arbitrary_node):
    """Hide one node's connections in ``adj`` and return them as the target.

    The target node is a uniformly random node when ``arbitrary_node`` is true,
    otherwise the last node (``num_nodes - 1``). Its original column entries
    towards the other real nodes are copied, in node order with the target
    itself skipped, into a zero vector of length ``max_num_nodes - 1``. Then the
    target's row and column (restricted to the first ``num_nodes`` entries) are
    overwritten with 1 **in place** in ``adj``.

    Returns:
        ``(label, target_index)`` where ``label`` is a float32 array.
    """
    label = np.zeros(max_num_nodes - 1)
    if arbitrary_node:
        target = np.random.randint(num_nodes)
        label[:target] = adj[:target, target]
        label[target:num_nodes - 1] = adj[target + 1:num_nodes, target]
    else:
        target = num_nodes - 1
        label[:num_nodes - 1] = adj[:num_nodes - 1, num_nodes - 1]
    adj[:num_nodes, target] = 1
    adj[target, :num_nodes] = 1
    # float32 to match the precision the metrics were computed at before.
    return label.astype(np.float32), target


def _precision_recall(label, prediction):
    """Return ``(precision, recall)`` of a 0/1 ``prediction`` against ``label``.

    Only entries exactly equal to 1 count as positives. Either score is 0 when
    its denominator is 0 (no predicted positives / no true positives), instead
    of dividing by zero.
    """
    predicted_pos = prediction == 1
    tp = int(np.sum(predicted_pos & (label == 1)))
    fp = int(np.sum(predicted_pos & (label == 0)))
    fn = int(np.sum(~predicted_pos & (prediction == 0) & (label == 1)))
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    return precision, recall


def _balanced_ranking_scores(label, scores):
    """ROC-AUC and average precision of ``scores`` on a class-balanced sample.

    Scores at ``label == 1`` are positives and at ``label == 0`` negatives; the
    larger class is randomly down-sampled (``random.sample``) to the size of
    the smaller one.

    Returns:
        ``(roc_auc, average_precision)``, or ``None`` when either class is
        empty (the metrics are undefined then).
    """
    positive = scores[label == 1]
    negative = scores[label == 0]
    if len(positive) <= len(negative):
        negative = random.sample(list(negative), len(positive))
    else:
        positive = random.sample(list(positive), len(negative))
    if len(positive) == 0:
        return None
    preds_all = np.hstack([positive, negative])
    labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
    return (roc_auc_score(labels_all, preds_all),
            average_precision_score(labels_all, preds_all))


def _mean_or_nan(values):
    """``statistics.mean`` that returns ``nan`` for an empty list instead of raising."""
    return mean(values) if values else float('nan')


def evaluate(test_graph_list, model, graph_show_mode=False, device='cuda:0'):
    """Score the model's prediction of one hidden node's edges per test graph.

    For each graph, one node is chosen (see ``_mask_target_node``). Its true
    connections become the label vector and its row/column in the input
    adjacency are set to 1. The adjacency is passed through
    ``model.conv1 -> act -> bn1 -> conv2 -> gran`` and the result is compared
    with the label.

    Args:
        test_graph_list: iterable of networkx graphs.
        model: object exposing ``max_num_nodes``, ``vae_args.GRAN``,
            ``vae_args.GRAN_arbitrary_node``, ``conv1``, ``act``, ``bn1``,
            ``conv2`` and ``gran``. ``model.eval()`` is called and the model is
            left in eval mode.
        graph_show_mode: unused; kept for backward compatibility.
        device: torch device the input tensors are created on (previously
            hard-coded ``'cuda:0'``); presumably must match the model's device.

    Returns:
        ``(mae, roc_auc, average_precision, precision, recall)``, each averaged
        over the graphs that produced a value; ``nan`` if none did.

    Raises:
        ValueError: if ``model.vae_args.GRAN`` is falsy; only the GRAN
            evaluation path is implemented here.
    """
    if not model.vae_args.GRAN:
        raise ValueError('evaluate() requires model.vae_args.GRAN to be enabled')

    model.eval()
    arbitrary_node = model.vae_args.GRAN_arbitrary_node
    mae_list = []
    roc_score_list = []
    ap_score_list = []
    precision_list = []
    recall_list = []

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        for graph in test_graph_list:
            data = getitem(graph, model.max_num_nodes)
            num_nodes = data['num_nodes']
            if num_nodes == 0:
                continue  # no node to hide in an empty graph

            adj = data['adj']
            label, target = _mask_target_node(
                adj, num_nodes, model.max_num_nodes, arbitrary_node)

            input_features = torch.tensor(data['features'], device=device).unsqueeze(0).float()
            incomplete_adj = torch.tensor(adj, device=device).unsqueeze(0).float()

            x = model.conv1(input_features, incomplete_adj)
            x = model.act(x)
            x = model.bn1(x)
            x = model.conv2(x, incomplete_adj)
            num_nodes_t = torch.tensor(num_nodes, device=device).unsqueeze(0)
            preds = model.gran(x, num_nodes_t, [target], arbitrary_node).squeeze()

            sample = sample_sigmoid(preds.unsqueeze(0).unsqueeze(0),
                                    sample=False, thresh=0.5, sample_time=10)
            prediction = sample.squeeze().cpu().numpy()

            mae_list.append(mean_absolute_error(label, prediction))

            precision, recall = _precision_recall(label, prediction)
            precision_list.append(precision)
            recall_list.append(recall)

            scores = _balanced_ranking_scores(label, preds.cpu().detach().numpy())
            if scores is not None:
                roc_score, ap_score = scores
                roc_score_list.append(roc_score)
                ap_score_list.append(ap_score)

    return (_mean_or_nan(mae_list), _mean_or_nan(roc_score_list),
            _mean_or_nan(ap_score_list), _mean_or_nan(precision_list),
            _mean_or_nan(recall_list))