import numpy as np
import torch
import random
import networkx as nx
from statistics import mean
from model import sample_sigmoid
from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
from baselines.graphvae.graphvae_model import graph_show
def getitem(graph, max_num_nodes):
    """Convert a networkx graph into zero-padded adjacency training arrays.

    Args:
        graph: networkx graph with at most ``max_num_nodes`` nodes.
        max_num_nodes: size to which the adjacency matrix is zero-padded.

    Returns:
        dict with:
            'adj': (max_num_nodes, max_num_nodes) zero-padded adjacency matrix.
            'adj_decoded': upper triangle of 'adj' (diagonal included),
                flattened row-wise to max_num_nodes * (max_num_nodes + 1) // 2
                entries.
            'num_nodes': number of nodes actually present in ``graph``.
    """
    # NOTE: nx.to_numpy_matrix was removed in networkx >= 3.0;
    # nx.to_numpy_array is the drop-in replacement when upgrading.
    adj = nx.to_numpy_matrix(graph)
    num_nodes = adj.shape[0]
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj
    # Boolean upper-triangular mask vectorizes the padded adjacency.
    # (Removed dead locals `adj_decoded` and `node_idx` from the original:
    # both were assigned and never used.)
    adj_vectorized = adj_padded[np.triu(np.ones((max_num_nodes, max_num_nodes))) == 1]
    return {'adj': adj_padded,
            'adj_decoded': adj_vectorized,
            'num_nodes': num_nodes}
def evaluate(dataset, model, graph_show_mode=False):
    """Evaluate a GRAN-style graph model: hide one node's edges per graph,
    predict them, and score the predictions.

    Args:
        dataset: iterable of batches; each batch is a dict with 'features',
            'adj' and 'num_nodes' tensors. Assumed shapes (B, N, F),
            (B, N, N) and (B,) — TODO confirm against the data loader.
        model: trained model exposing conv1/conv2/bn1/act/gran layers plus
            ``vae_args`` and ``max_num_nodes`` attributes.
        graph_show_mode: if True, draw the incomplete vs. completed graph
            for the first sample of the first batch and return early (None).

    Returns:
        (mae, tp_div_pos, roc_score, ap_score) of the LAST batch only; the
        printed summary line reports the means over all batches.

    NOTE(review): ``incomplete_adj`` and ``random_index_list`` are only
    assigned inside the ``model.vae_args.GRAN`` branch, so this function
    raises NameError when GRAN is disabled — presumably it is only called
    with GRAN models; confirm with callers. Likewise the final ``return``
    raises NameError if ``dataset`` is empty.
    """
    # print("********************   TEST *****************************")
    model.eval()
    mae_list = []
    tp_div_pos_list = []
    roc_score_list = []
    ap_score_list = []
    for batch_idx, data in enumerate(dataset):
        input_features = data['features'].float()
        # Moves features to GPU; hard-codes 'cuda:0', so CPU-only runs fail.
        input_features = torch.tensor(input_features, device='cuda:0')
        adj_input = data['adj'].float()
        batch_num_nodes = data['num_nodes'].int().numpy()
        numpy_adj = adj_input.cpu().numpy()
        # print("**** numpy_adj")
        # print(numpy_adj)
        # print("**** input_features")
        # print(input_features)
        batch_size = numpy_adj.shape[0]
        if model.vae_args.GRAN:
            # gran_lable (sic) holds, per graph, the ground-truth edges of the
            # node whose connectivity is hidden from the model's input.
            gran_lable = np.zeros((batch_size, model.max_num_nodes - 1))
            random_index_list = []
            for i in range(batch_size):
                num_nodes = batch_num_nodes[i]
                if model.vae_args.GRAN_arbitrary_node:
                    # Hide a random node: the label is its adjacency column
                    # with the diagonal entry skipped (split into the parts
                    # before and after the chosen index).
                    random_index_list.append(np.random.randint(num_nodes))
                    gran_lable[i, :random_index_list[i]] = \
                        numpy_adj[i, :random_index_list[i], random_index_list[i]]
                    gran_lable[i, random_index_list[i]:num_nodes - 1] = \
                        numpy_adj[i, random_index_list[i] + 1:num_nodes, random_index_list[i]]
                    # Fully connect the hidden node in the input so the model
                    # cannot read the answer off the adjacency matrix.
                    numpy_adj[i, :num_nodes, random_index_list[i]] = 1
                    numpy_adj[i, random_index_list[i], :num_nodes] = 1
                    # print("***  random_index_list")
                    # print(random_index_list)
                    # print("***  gran_lable")
                    # print(gran_lable)
                    # print("***  numpy_adj after fully connection")
                    # print(numpy_adj)
                else:
                    # Hide the last node (index num_nodes - 1) instead.
                    gran_lable[i, :num_nodes - 1] = numpy_adj[i, :num_nodes - 1, num_nodes - 1]
                    numpy_adj[i, :num_nodes, num_nodes - 1] = 1
                    numpy_adj[i, num_nodes - 1, :num_nodes] = 1
            gran_lable = torch.tensor(gran_lable).float()
            incomplete_adj = torch.tensor(numpy_adj, device='cuda:0')
            # print("***  gran_lable")
            # print(gran_lable)
            # print("***  numpy_adj after change")
            # print(numpy_adj)
        # Two GCN layers encode the (masked) graph into node embeddings.
        x = model.conv1(input_features, incomplete_adj)
        x = model.act(x)
        x = model.bn1(x)
        x = model.conv2(x, incomplete_adj)
        # x = model.bn2(x)
        # x = model.act(x)
        # print("***   embedding : ")
        # print(x)
        # Edge-probability logits/scores for the hidden node's connections.
        preds = model.gran(x, batch_num_nodes, random_index_list, model.vae_args.GRAN_arbitrary_node).squeeze()
        # print("***  preds")
        # print(preds)
        # Threshold at 0.5 (sample=False) to get hard 0/1 edge predictions.
        sample = sample_sigmoid(preds.unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
        sample = sample.squeeze()
        # ind = 1
        # print("***  sample")
        # print(sample)
        # print("***  gran_lable")
        # print(gran_lable)
        if graph_show_mode:
            # Visualize sample 0: strip the hidden node's edges, show the
            # incomplete graph, then splice in the predictions and show again.
            # NOTE(review): always removes the LAST node's edges — looks
            # inconsistent with the GRAN_arbitrary_node branch above; verify.
            index = 0
            colors = [(0.7509279299037631, 0.021203049355839054, 0.24561203044115132)]
            incomplete = numpy_adj.copy()
            print("********    1) ")
            print(incomplete[index])
            incomplete[index, :batch_num_nodes[index], batch_num_nodes[index] - 1] = 0
            incomplete[index, batch_num_nodes[index] - 1, :batch_num_nodes[index]] = 0
            print("********    2) ")
            print(incomplete[index])
            graph_show(nx.from_numpy_matrix(incomplete[index]), "incomplete graph", colors)
            incomplete[index, :batch_num_nodes[index] - 1, batch_num_nodes[index] - 1] = \
                sample.cpu()[index, :batch_num_nodes[index] - 1]
            incomplete[index, batch_num_nodes[index] - 1, :batch_num_nodes[index] - 1] = \
                sample.cpu()[index, :batch_num_nodes[index] - 1]
            graph_show(nx.from_numpy_matrix(incomplete[index]), "complete graph", colors)
            return
        mae = mean_absolute_error(gran_lable.numpy(), sample.cpu())
        mae_list.append(mae)
        # Per-graph recall: true positives / actual positives, averaged over
        # the batch. NOTE(review): divides by label[i].sum(), which is zero
        # for an isolated hidden node — would yield nan; confirm data rules
        # this out.
        tp_div_pos = 0
        label = gran_lable.numpy()
        result = sample.cpu().numpy()
        for i in range(gran_lable.size()[0]):
            part1 = label[i][result[i] == 1]
            part2 = part1[part1 == 1]
            tp_div_pos += part2.sum() / label[i].sum()
        tp_div_pos /= gran_lable.size()[0]
        tp_div_pos_list.append(tp_div_pos)
        # #######################################
        # ROC-AUC / average precision on class-balanced score sets: the larger
        # of {positive, negative} score pools is down-sampled to match the
        # smaller, so both classes contribute equally.
        result = preds.cpu().detach().numpy()
        roc_score = 0
        ap_score = 0
        for i in range(gran_lable.size()[0]):
            positive = result[i][label[i] == 1]
            # print("****************")
            # print(len(positive))
            # print(len(list(result[i][label[i] == 0])))
            # print(positive)
            if len(positive) <= len(list(result[i][label[i] == 0])):
                negative = random.sample(list(result[i][label[i] == 0]), len(positive))
            else:
                negative = result[i][label[i] == 0]
                positive = random.sample(list(result[i][label[i] == 1]), len(negative))
            preds_all = np.hstack([positive, negative])
            # len(positive) == len(negative) here by construction, so using
            # len(positive) for the zeros is equivalent to len(negative).
            labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
            if len(labels_all) > 0:
                roc_score += roc_auc_score(labels_all, preds_all)
                ap_score += average_precision_score(labels_all, preds_all)
        roc_score /= batch_size
        ap_score /= batch_size
        roc_score_list.append(roc_score)
        ap_score_list.append(ap_score)
        # #######################################
    print("*** MAE - tp_div_pos - roc_score - ap_score : " + str(mean(mae_list)) + " _ "
          + str(mean(tp_div_pos_list)) + " _ " + str(mean(roc_score_list)) + " _ "
          + str(mean(ap_score_list)) + " _ ")
    # print(mae)
    # print("*** tp_div_pos : ")
    # print(tp_div_pos)
    # print("*** roc_score : ")
    # print(roc_score)
    # print("*** ap_score : ")
    # print(ap_score)
    # Returns the last batch's metrics, not the means printed above.
    return mae, tp_div_pos, roc_score, ap_score