123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- from statistics import mean
-
- from main_baselines.graphvae.Main_VAE_Args import Main_VAE_Args
- from train import *
- from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
- from baselines.graphvae.graphvae_train import graph_statistics
- from baselines.graphvae.util import *
- from main_baselines.graphvae.graphvae_train import build_model
- from main_baselines.graphvae.graphvae_train import arg_parse
-
-
def upper_triangular_to_symmetric_converter(A):
    """Copy the strict upper triangle of a square matrix onto its lower triangle, in place.

    The previous implementation walked every (i, j) pair in a Python double
    loop. It assigned each off-diagonal pair twice and even rewrote the
    diagonal. A single vectorised fancy-index assignment gives the same result.

    Args:
        A: square 2-D array. Its lower triangle is overwritten; the diagonal
            and the upper triangle are left unchanged.

    Returns:
        The same object ``A``, now symmetric.
    """
    rows, cols = np.triu_indices(A.shape[0], k=1)
    # Fancy indexing on the right-hand side makes a copy of the upper-triangle
    # values before they are written to the mirrored (lower) positions.
    A[cols, rows] = A[rows, cols]
    return A
-
-
def vector_to_matrix_converter(input_vector, max_num_nodes):
    """Expand a flattened upper triangle (diagonal included) into a symmetric matrix.

    ``input_vector`` is laid out in the order of
    ``np.triu_indices(max_num_nodes)``.

    Returns:
        A new ``(max_num_nodes, max_num_nodes)`` float NumPy array.
    """
    n = max_num_nodes
    upper = np.zeros((n, n))
    upper[np.triu_indices(n)] = input_vector
    # Mirror the strictly-upper part into the lower triangle; the diagonal is
    # excluded from the mirrored term so it is not doubled.
    return upper + np.triu(upper, k=1).T
-
-
def getitem(graph, max_num_nodes):
    """Build the padded model inputs for a single graph.

    Args:
        graph: a networkx graph with at most ``max_num_nodes`` nodes.
        max_num_nodes: side length of the padded adjacency/feature matrices.

    Returns:
        dict with keys
          'adj': ``(max_num_nodes, max_num_nodes)`` ndarray holding the
                 graph's adjacency in the top-left corner, zeros elsewhere;
          'num_nodes': number of nodes in ``graph``;
          'features': identity matrix of size ``max_num_nodes``.
    """
    # nx.to_numpy_matrix was removed in networkx 3.0; to_numpy_array (available
    # since networkx 2.0) yields the same values as a plain ndarray.
    adj = nx.to_numpy_array(graph)
    num_nodes = adj.shape[0]
    adj_padded = np.zeros((max_num_nodes, max_num_nodes))
    adj_padded[:num_nodes, :num_nodes] = adj

    return {'adj': adj_padded,
            'num_nodes': num_nodes,
            'features': np.identity(max_num_nodes)}
-
-
def test(model, input_features, adj, verbose=True):
    """Run one GraphVAE forward pass on ``adj`` and return the reconstruction + KL loss.

    Pipeline:
      1. 2-layer GCN encoder.
      2. VAE encode/decode.
      3. Graph matching of the decoded adjacency against ``adj``.
      4. Reconstruction loss on the matched upper triangle, plus the KL term.

    Args:
        model: GraphVAE-style model. Its semantics are defined outside this
            file. It must expose conv1/act/bn1/conv2, vae, recover_adj_lower,
            recover_full_adj_from_lower, edge_similarity_matrix, mpm,
            permute_adj, adj_recon_loss, max_num_nodes, hidden_dim and
            deg_feature_similarity.
        input_features: node-feature tensor fed to ``model.conv1``.
        adj: adjacency tensor. The caller in this file passes a 2-D padded
            ``(max_num_nodes, max_num_nodes)`` matrix.
        verbose: print the raw decoder output and its sigmoid. This is the
            original behaviour; the default ``True`` keeps it.

    Returns:
        Scalar loss tensor: ``adj_recon_loss + loss_kl``.
    """
    n = model.max_num_nodes

    # 2-layer GCN encoder
    x = model.conv1(input_features, adj)
    x = model.act(x)
    x = model.bn1(x)
    x = model.conv2(x, adj)

    # Flatten node embeddings into a single graph vector, then VAE encode/decode.
    graph_h = x.view(-1, n * model.hidden_dim)
    h_decode, z_mu, z_lsgms = model.vae(graph_h)
    # F.sigmoid is deprecated; torch.sigmoid is the same function.
    out = torch.sigmoid(h_decode)
    out_tensor = out.cpu().data
    if verbose:
        print("*** h_decode")
        print(h_decode)
        print("*** out_tensor")
        print(out_tensor)
    recon_adj_lower = model.recover_adj_lower(out_tensor)
    recon_adj_tensor = model.recover_full_adj_from_lower(recon_adj_lower)

    # Node degrees serve as the matching features.
    out_features = torch.sum(recon_adj_tensor, 1)
    adj_data = adj.cpu().data
    adj_features = torch.sum(adj_data, 1)

    S = model.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
                                     model.deg_feature_similarity)

    # Uniform initial assignment, refined by model.mpm.
    init_assignment = torch.ones(n, n) * (1 / n)
    assignment = model.mpm(init_assignment, S)
    # linear_sum_assignment minimises cost, so negate the scores to maximise them.
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
    adj_permuted = model.permute_adj(adj_data, row_ind, col_ind)
    upper_mask = torch.triu(torch.ones(n, n)) == 1
    adj_vectorized = adj_permuted[upper_mask].squeeze_()
    # Put the target on the prediction's device. This replaces the deprecated
    # Variable wrapper and the hard-coded .cuda().
    adj_vectorized_var = adj_vectorized.to(out.device)

    adj_recon_loss = model.adj_recon_loss(adj_vectorized_var, out[0])

    loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
    loss_kl /= n * n  # normalize

    return adj_recon_loss + loss_kl
-
-
def VAE_test(model, input_features, adj_with_missing_node, adj_input):
    """Predict the adjacency row of a node that was deleted from the input graph.

    The encoder sees the graph with one node's row/column zeroed. The decoded
    graph is matched against the full ground-truth adjacency. The matched row
    of the deleted node is then returned for both ground truth and prediction.

    Args:
        model: GraphVAE-style model (see the ``model.*`` calls below).
        input_features: node-feature tensor fed to ``model.conv1``.
        adj_with_missing_node: adjacency tensor with one node removed (its row
            and column zeroed). Presumably 2-D ``(max_num_nodes, max_num_nodes)``
            or a batch of one; it is reshaped to 2-D for node detection.
        adj_input: full padded ground-truth adjacency as a NumPy array.

    Returns:
        ``(label_row, predicted_row)``: NumPy vectors of length
        ``max_num_nodes``.
          label_row: the deleted node's row of ``adj_input`` after
              ``model.permute_adj``.
          predicted_row: the same row index of the thresholded decoder output.
    """
    model.eval()
    n = model.max_num_nodes

    # 2-layer GCN encoder over the node-removed graph.
    x = model.conv1(input_features, adj_with_missing_node)
    x = model.act(x)
    x = model.bn1(x)
    x = model.conv2(x, adj_with_missing_node)
    graph_h = x.view(-1, n * model.hidden_dim)

    # VAE encode/decode.
    h_decode, z_mu, z_lsgms = model.vae(graph_h)

    # Hard (thresholded) prediction of the upper triangle, expanded to a full matrix.
    predicted_vector = sample_sigmoid(h_decode.unsqueeze(0).unsqueeze(0), sample=False, thresh=0.5, sample_time=10)
    predicted_matrix = vector_to_matrix_converter(predicted_vector.cpu(), n)

    # Soft reconstruction used for graph matching; degrees are the matching features.
    out_tensor = torch.sigmoid(h_decode).cpu().data
    recon_adj_lower = model.recover_adj_lower(out_tensor)
    recon_adj_tensor = model.recover_full_adj_from_lower(recon_adj_lower)
    out_features = torch.sum(recon_adj_tensor, 1)

    # Match the reconstruction against the full ground-truth graph.
    adj_data = torch.from_numpy(adj_input)
    adj_features = torch.sum(adj_data, 1)
    S = model.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
                                     model.deg_feature_similarity)

    init_assignment = torch.ones(n, n) * (1 / n)
    assignment = model.mpm(init_assignment, S)

    try:
        # linear_sum_assignment minimises cost, so negate the scores to maximise them.
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
    except ValueError as err:
        # SciPy raises ValueError for invalid (e.g. NaN) cost entries.
        # The previous fallback had two problems:
        #   - it returned a predicted row as BOTH label and prediction, which
        #     silently scored as a perfect match;
        #   - it hard-coded row 2, which fails when n < 3.
        # Use the identity matching so the comparison stays honest.
        print("linear_sum_assignment failed (" + str(err) + "); using identity matching")
        row_ind, col_ind = np.arange(n), np.arange(n)

    # Apply the same permutation to the full and the node-removed adjacency.
    adj_permuted = model.permute_adj(adj_data, row_ind, col_ind)
    adj_missing = adj_with_missing_node.detach().cpu().reshape(n, n)
    adj_missing_permuted = model.permute_adj(adj_missing, row_ind, col_ind)
    # No .cuda() round-trip: the result is converted to NumPy below anyway.
    main_adj_permuted = model.permute_adj(adj_data.float(), row_ind, col_ind)

    full_has_edges = adj_permuted.cpu().numpy().any(axis=1)
    missing_has_edges = adj_missing_permuted.cpu().numpy().any(axis=1)
    # The deleted node has edges in adj_input but none in adj_with_missing_node.
    # The old rule (first all-zero row of the FULL adjacency) could only find
    # a padding row or an isolated node, never the deleted node.
    candidates = np.flatnonzero(full_has_edges & ~missing_has_edges)
    if candidates.size == 0:
        # Deleted node was isolated: its label row is all zeros anyway.
        candidates = np.flatnonzero(~full_has_edges)
    missing_node_index = candidates[0]

    return main_adj_permuted[missing_node_index, :].cpu().numpy(), predicted_matrix[missing_node_index, :]
-
-
def _precision_recall(label, result):
    """Precision and recall of binary ``result`` against binary ``label``; 0 when undefined."""
    predicted_pos = result == 1
    true_pos = int(np.sum(label[predicted_pos] == 1))
    false_pos = int(np.sum(label[predicted_pos] == 0))
    false_neg = int(np.sum(label[result == 0] == 1))
    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos > 0 else 0
    # Previously divided by zero (yielding NaN) when the label had no positives.
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg > 0 else 0
    return precision, recall


def _balanced_roc_auc(label, result):
    """ROC-AUC of ``result`` on a class-balanced random subsample; None if a class is empty."""
    positive = result[label == 1]
    negative_pool = list(result[label == 0])
    # Down-sample the larger class (same random.sample calls as before, so
    # seeded runs draw identical subsamples).
    if len(positive) <= len(negative_pool):
        negative = random.sample(negative_pool, len(positive))
    else:
        negative = negative_pool
        positive = random.sample(list(positive), len(negative))
    if len(positive) == 0:
        return None
    preds_all = np.hstack([positive, negative])
    labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
    return roc_auc_score(labels_all, preds_all)


def _mean_or_nan(values):
    """``statistics.mean`` of ``values``, or NaN when there is nothing to average."""
    return mean(values) if values else float('nan')


def evaluate(model, test_graph_list):
    """Score missing-node edge prediction over a list of graphs.

    For each graph, one random node is deleted (its row and column zeroed).
    ``VAE_test`` predicts that node's adjacency row, which is compared with
    the ground-truth row.

    Args:
        model: GraphVAE-style ``torch.nn.Module`` exposing ``max_num_nodes``.
        test_graph_list: networkx graphs with at most ``model.max_num_nodes``
            nodes.

    Returns:
        ``(mean MAE, mean ROC-AUC, mean precision, mean recall)``. A metric
        with no samples is NaN instead of raising ``StatisticsError``.
    """
    model.eval()
    # Follow the model's device instead of hard-coding 'cuda:0'.
    device = next(model.parameters()).device

    mae_list = []
    roc_score_list = []
    precision_list = []
    recall_list = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for graph in test_graph_list:
            data = getitem(graph, model.max_num_nodes)
            input_features = torch.tensor(data['features'], device=device).unsqueeze(0).float()
            adj_input = data['adj']
            num_nodes = data['num_nodes']

            # Remove one random node.
            adj_with_missing_node = adj_input.copy()
            random_index = np.random.randint(num_nodes)
            adj_with_missing_node[:, random_index] = 0
            adj_with_missing_node[random_index, :] = 0

            # The original called the debug helper `test(...)` here and
            # discarded its result. It then used `label`/`result`, which were
            # never defined (NameError).
            label, result = VAE_test(model, input_features,
                                     torch.from_numpy(adj_with_missing_node).to(device).float(),
                                     adj_input)

            mae_list.append(mean_absolute_error(label, result))
            print("**** label")
            print(label)
            print("**** result")
            print(result)

            precision, recall = _precision_recall(label, result)
            precision_list.append(precision)
            recall_list.append(recall)

            roc_score = _balanced_roc_auc(label, result)
            if roc_score is not None:
                roc_score_list.append(roc_score)

    return (_mean_or_nan(mae_list), _mean_or_nan(roc_score_list),
            _mean_or_nan(precision_list), _mean_or_nan(recall_list))
-
-
if __name__ == '__main__':
    import os  # only the script entry point needs it

    random.seed(123)
    np.random.seed(123)  # evaluate() picks the deleted node with np.random
    torch.manual_seed(1234)
    prog_args = arg_parse()
    args = Main_VAE_Args()

    graphs = create_graphs.create(args)

    training_on_KronEM_data = False

    # Keep only graphs in the size range used for each dataset.
    if training_on_KronEM_data:
        graphs = [g for g in graphs if g.number_of_nodes() in (8, 16, 32, 64, 128, 256)]
    elif args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
        graphs = [g for g in graphs if g.number_of_nodes() < 13]
    elif args.graph_type == 'COLLAB':
        graphs = [g for g in graphs if 41 < g.number_of_nodes() < 52]
    max_num_nodes = max(g.number_of_nodes() for g in graphs)
    fname = args.model_save_path + "GraphVAE" + str(args.load_epoch) + '.dat'
    model = build_model(prog_args, max_num_nodes).cuda()
    # The load was commented out, so a freshly initialised network was being
    # evaluated. Load the checkpoint when present; otherwise warn loudly.
    if os.path.isfile(fname):
        model.load_state_dict(torch.load(fname))
    else:
        print("**** WARNING: checkpoint " + fname + " not found; evaluating an untrained model")
    graph_statistics(graphs)

    # split datasets
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    split_at = int(0.8 * graphs_len)
    graphs_train = graphs[:split_at]
    graphs_test = graphs[split_at:]
    print("**** test graph size : " + str(len(graphs_test)))
    graph_statistics(graphs_test)

    iteration = 1
    mae_list = []
    roc_score_list = []
    precision_list = []
    recall_list = []
    F_Measure_list = []
    for i in range(iteration):
        print("########################################################################## " + str(i))
        # Evaluate on the held-out split. The original passed graphs_train
        # while reporting the numbers as "In Test".
        mae, roc_score, precision, recall = evaluate(model, graphs_test)
        if precision + recall > 0:
            F_Measure = 2 * precision * recall / (precision + recall)
        else:
            F_Measure = 0
        mae_list.append(mae)
        roc_score_list.append(roc_score)
        precision_list.append(precision)
        recall_list.append(recall)
        F_Measure_list.append(F_Measure)
    print(
        "In Test: *** MAE - roc_score - precision - recall - F_Measure : " + str(
            mean(mae_list)) + " _ "
        + str(mean(roc_score_list)) + " _ "
        + str(mean(precision_list)) + " _ " + str(mean(recall_list)) + " _ "
        + str(mean(F_Measure_list)))
|