123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- import numpy as np
- import torch
- import random
- import networkx as nx
- from statistics import mean
-
- from model import sample_sigmoid
- from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
- from baselines.graphvae.graphvae_model import graph_show
-
-
def getitem(graph, max_num_nodes):
    """Build a fixed-size adjacency representation of *graph*.

    Args:
        graph: a NetworkX graph.
        max_num_nodes: side length of the padded adjacency matrix.

    Returns:
        dict with
          'adj': ``(max_num_nodes, max_num_nodes)`` float array holding the
              graph's adjacency matrix in its top-left corner, zeros elsewhere;
          'adj_decoded': 1-D array of length
              ``max_num_nodes * (max_num_nodes + 1) // 2`` with the upper
              triangle of ``'adj'`` (diagonal included), read row by row;
          'num_nodes': number of nodes in *graph*.

    Raises:
        ValueError: if *graph* has more than *max_num_nodes* nodes.
    """
    # to_numpy_array replaces to_numpy_matrix, which NetworkX 3.0 removed.
    adj = nx.to_numpy_array(graph)
    num_nodes = adj.shape[0]
    if num_nodes > max_num_nodes:
        raise ValueError(
            f"graph has {num_nodes} nodes but max_num_nodes is {max_num_nodes}")
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj
    # Row-major (row, col) pairs with col >= row: same order as a boolean triu mask.
    rows, cols = np.triu_indices(max_num_nodes)
    return {'adj': adj_padded,
            'adj_decoded': adj_padded[rows, cols],
            'num_nodes': num_nodes}
-
-
def _nanmean(values):
    """Return the mean of the non-NaN entries of *values*, or NaN if there are none."""
    kept = [v for v in values if not np.isnan(v)]
    return mean(kept) if kept else float('nan')


def _hide_target_nodes(numpy_adj, batch_num_nodes, max_num_nodes, arbitrary_node):
    """Pick one target node per graph, record its true edges, then fully connect it.

    ``numpy_adj`` is modified in place: the target's row and column, restricted
    to the graph's real nodes (diagonal included), are set to 1 so the model
    cannot see the target's true neighbourhood.

    Returns:
        labels: float array of shape ``(batch_size, max_num_nodes - 1)``. Row
            ``i`` holds the target's original adjacency to every *other* real
            node of graph ``i`` in increasing node order; columns
            ``>= num_nodes - 1`` are zero padding.
        random_index_list: the chosen target of each graph when
            ``arbitrary_node`` is true, otherwise an empty list (the target is
            then always the last real node).
    """
    batch_size = numpy_adj.shape[0]
    labels = np.zeros((batch_size, max_num_nodes - 1))
    random_index_list = []
    for i in range(batch_size):
        num_nodes = batch_num_nodes[i]
        if arbitrary_node:
            target = np.random.randint(num_nodes)
            random_index_list.append(target)
        else:
            target = num_nodes - 1
        # Nodes before the target keep their column index; nodes after it shift
        # left by one because the target's own (diagonal) entry is skipped.
        labels[i, :target] = numpy_adj[i, :target, target]
        labels[i, target:num_nodes - 1] = numpy_adj[i, target + 1:num_nodes, target]
        numpy_adj[i, :num_nodes, target] = 1
        numpy_adj[i, target, :num_nodes] = 1
    return labels, random_index_list


def _show_completion(numpy_adj, sample, batch_num_nodes, random_index_list, arbitrary_node):
    """Draw the first graph of the batch with its target isolated, then completed.

    The completed view writes the thresholded predictions in *sample* back as
    the target's edges, using the column layout of ``_hide_target_nodes``.
    """
    index = 0
    colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
    num_nodes = batch_num_nodes[index]
    target = random_index_list[index] if arbitrary_node else num_nodes - 1
    # Prediction column k refers to the k-th real node other than the target.
    others = [node for node in range(num_nodes) if node != target]
    incomplete = numpy_adj.copy()
    print("******** 1) ")
    print(incomplete[index])
    incomplete[index, :num_nodes, target] = 0
    incomplete[index, target, :num_nodes] = 0
    print("******** 2) ")
    print(incomplete[index])
    # from_numpy_array replaces from_numpy_matrix, which NetworkX 3.0 removed.
    graph_show(nx.from_numpy_array(incomplete[index]), "incomplete graph", colors)
    predicted = sample[index, :max(num_nodes - 1, 0)]
    incomplete[index, others, target] = predicted
    incomplete[index, target, others] = predicted
    graph_show(nx.from_numpy_array(incomplete[index]), "complete graph", colors)


def _completion_metrics(labels, hard_preds, scores, batch_num_nodes):
    """Score one batch of completions over real (non-padded) candidate edges only.

    Args:
        labels: ``(batch_size, max_num_nodes - 1)`` true target edges.
        hard_preds: same shape, thresholded 0/1 predictions.
        scores: same shape, raw model outputs used for ranking.
        batch_num_nodes: real node count of each graph.

    Returns:
        ``(mae, recall, roc_auc, ap)`` for the batch. ``recall`` is averaged over
        graphs whose target has at least one true edge; ``roc_auc``/``ap`` over
        graphs with at least one true edge and one non-edge. A metric that no
        graph in the batch can define is NaN.
    """
    label_rows, hard_rows = [], []
    recalls, roc_scores, ap_scores = [], [], []
    for i, num_nodes in enumerate(batch_num_nodes):
        n_valid = max(num_nodes - 1, 0)  # every real node except the target
        row_labels = labels[i, :n_valid]
        row_hard = hard_preds[i, :n_valid]
        row_scores = scores[i, :n_valid]
        label_rows.append(row_labels)
        hard_rows.append(row_hard)

        is_edge = row_labels == 1
        n_edges = np.count_nonzero(is_edge)
        if n_edges:
            recalls.append(np.count_nonzero(is_edge & (row_hard == 1)) / n_edges)

        # Balance the classes by subsampling the larger one to the smaller's size.
        positive = list(row_scores[is_edge])
        negative = list(row_scores[row_labels == 0])
        if len(positive) <= len(negative):
            negative = random.sample(negative, len(positive))
        else:
            positive = random.sample(positive, len(negative))
        if positive:  # both classes present, equally sized
            preds_all = np.hstack([positive, negative])
            labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
            roc_scores.append(roc_auc_score(labels_all, preds_all))
            ap_scores.append(average_precision_score(labels_all, preds_all))

    flat_labels = np.concatenate(label_rows) if label_rows else np.empty(0)
    flat_hard = np.concatenate(hard_rows) if hard_rows else np.empty(0)
    mae = mean_absolute_error(flat_labels, flat_hard) if flat_labels.size else float('nan')
    return mae, _nanmean(recalls), _nanmean(roc_scores), _nanmean(ap_scores)


def evaluate(dataset, model, graph_show_mode=False, device='cuda:0'):
    """Evaluate one-node graph completion (GRAN mode) over *dataset*.

    For every graph one target node is hidden by fully connecting it in the
    adjacency fed to the model. The target is the last real node, or a random
    one when ``model.vae_args.GRAN_arbitrary_node`` is set. Its true edges
    become the labels. The model (``conv1 -> act -> bn1 -> conv2 -> gran``)
    scores every candidate edge, and ``sample_sigmoid`` thresholds the scores
    at 0.5 (the scores are presumably logits; confirm against ``sample_sigmoid``).

    Per batch, over real (non-padded) candidates only, it computes:
      * MAE between thresholded predictions and labels;
      * ``tp_div_pos``: recall of the thresholded predictions on true edges;
      * ROC AUC and average precision of the raw scores on a class-balanced
        random subsample (drawn with the ``random`` module's global state).
    The NaN-ignoring mean of each metric across batches is printed.

    Args:
        dataset: iterable of batch dicts with tensor entries ``'features'``,
            ``'adj'`` (3-D, indexed ``[graph, row, col]``) and ``'num_nodes'``.
        model: network exposing ``conv1``, ``act``, ``bn1``, ``conv2``,
            ``gran``, ``max_num_nodes`` and ``vae_args`` (``GRAN``,
            ``GRAN_arbitrary_node``).
        graph_show_mode: if true, draw the first graph of the first batch
            before and after completion and return ``None`` without metrics.
        device: torch device that the model inputs are moved to.

    Returns:
        ``(mae, tp_div_pos, roc_score, ap_score)`` of the LAST batch (the printed
        values are dataset-level means), or ``None`` in ``graph_show_mode``.

    Raises:
        ValueError: if ``model.vae_args.GRAN`` is not set (only GRAN completion
            is supported).
        statistics.StatisticsError: if *dataset* yields no batches.
    """
    if not model.vae_args.GRAN:
        raise ValueError("evaluate() requires model.vae_args.GRAN to be enabled")
    arbitrary_node = model.vae_args.GRAN_arbitrary_node
    model.eval()
    mae_list = []
    tp_div_pos_list = []
    roc_score_list = []
    ap_score_list = []
    # Evaluation only: do not build the autograd graph.
    with torch.no_grad():
        for data in dataset:
            input_features = data['features'].float().to(device)
            # Copy so the in-place masking below never writes through to data['adj'].
            numpy_adj = data['adj'].float().cpu().numpy().copy()
            batch_num_nodes = data['num_nodes'].int().numpy()
            batch_size = numpy_adj.shape[0]

            gran_label, random_index_list = _hide_target_nodes(
                numpy_adj, batch_num_nodes, model.max_num_nodes, arbitrary_node)
            incomplete_adj = torch.tensor(numpy_adj, device=device)

            x = model.conv1(input_features, incomplete_adj)
            x = model.act(x)
            x = model.bn1(x)
            x = model.conv2(x, incomplete_adj)
            # reshape (not squeeze) keeps the batch axis when batch_size == 1.
            preds = model.gran(x, batch_num_nodes, random_index_list, arbitrary_node)
            preds = preds.reshape(batch_size, -1)
            sample = sample_sigmoid(preds.unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
            # Assumes sample_sigmoid keeps its input's element count (the
            # original compared it element-wise with the labels).
            sample = sample.reshape(batch_size, -1).detach().cpu().numpy()

            if graph_show_mode:
                _show_completion(numpy_adj, sample, batch_num_nodes,
                                 random_index_list, arbitrary_node)
                return None

            mae, tp_div_pos, roc_score, ap_score = _completion_metrics(
                gran_label, sample, preds.detach().cpu().numpy(), batch_num_nodes)
            mae_list.append(mae)
            tp_div_pos_list.append(tp_div_pos)
            roc_score_list.append(roc_score)
            ap_score_list.append(ap_score)

    if not mae_list:
        from statistics import StatisticsError
        raise StatisticsError("evaluate() received an empty dataset")
    print("*** MAE - tp_div_pos - roc_score - ap_score : " + str(_nanmean(mae_list)) + " _ "
          + str(_nanmean(tp_div_pos_list)) + " _ " + str(_nanmean(roc_score_list)) + " _ "
          + str(_nanmean(ap_score_list)) + " _ ")
    return mae, tp_div_pos, roc_score, ap_score
|