123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- from statistics import mean
- from train import *
- from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
- from baselines.graphvae.graphvae_train import graph_statistics
- from baselines.graphvae.util import *
-
-
- def getitem(graph, max_prev_node, arbitrary_node_deletion_flag):
- # print("*** input graph size = " + str(graph.number_of_nodes()))
- input_graph_size = graph.number_of_nodes()
- adj = nx.to_numpy_matrix(graph)
- if arbitrary_node_deletion_flag:
- adj = move_random_node_to_the_last_index(adj)
- adj_copy = adj.copy()
-
- len = adj_copy.shape[0]
- x = np.zeros((graph.number_of_nodes() - 1, max_prev_node)) # here zeros are padded for small graph
- x[0, :] = 1 # the first input token is all ones
- y = np.zeros((graph.number_of_nodes() - 1, max_prev_node)) # here zeros are padded for small graph
- # print("%%%%%%%%% in getitem")
- # print(adj_copy)
- column_vector = adj_copy[:, adj_copy.shape[0] - 1]
- # print(column_vector)
- # print("%%% end")
- incomplete_adj = adj_copy.copy()
- incomplete_adj = incomplete_adj[:, :incomplete_adj.shape[0] - 1]
- incomplete_adj = incomplete_adj[:incomplete_adj.shape[0] - 1, :]
- x_idx = np.random.permutation(incomplete_adj.shape[0])
-
- x_idx_prime = np.concatenate((x_idx, [adj.shape[0] - 1]), axis=0)
- column_vector = column_vector[np.ix_(x_idx_prime)]
-
- incomplete_adj = incomplete_adj[np.ix_(x_idx, x_idx)]
- #
- incomplete_matrix = np.asmatrix(incomplete_adj)
- G = nx.from_numpy_matrix(incomplete_matrix)
- # then do bfs in the permuted G
- start_idx = np.random.randint(incomplete_adj.shape[0])
- x_idx = np.array(bfs_seq(G, start_idx))
- incomplete_adj = incomplete_adj[np.ix_(x_idx, x_idx)]
- # print("*** graph size after BFS = " + str(incomplete_adj.shape[0]))
- # print("##########################################################################")
- graph_size_after_BFS = incomplete_adj.shape[0]
- if graph_size_after_BFS != input_graph_size - 1:
- return
- adj_encoded = encode_adj(incomplete_adj.copy(), max_prev_node=max_prev_node)
- #
- x_idx_prime = np.concatenate((x_idx, [adj.shape[0] - 1]), axis=0)
- column_vector = column_vector[np.ix_(x_idx_prime)]
- y[0:adj_encoded.shape[0], :] = adj_encoded
- x[1:adj_encoded.shape[0] + 1, :] = adj_encoded
- # print("*** incomplete")
- # print(incomplete_adj)
- return {'x': x, 'y': y, 'len': len - 1, 'label': column_vector}
-
-
- def evaluate(test_graph_list, args, arbitrary_node_deletion_flag):
- if 'GraphRNN_RNN' in args.note:
- rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
- hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
- has_output=True, output_size=args.hidden_size_rnn_output).cuda()
- output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
- hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True,
- has_output=True, output_size=1).cuda()
- elif 'GraphRNN_MLP' in args.note:
- rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
- hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
- has_output=False).cuda()
- output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
- y_size=args.max_prev_node).cuda()
- fname = args.model_save_path + args.fname_GraphRNN + 'lstm_' + str(args.load_epoch) + '.dat'
- print("*** fname : " + fname)
- rnn.load_state_dict(torch.load(fname))
- fname = args.model_save_path + args.fname_GraphRNN + 'output_' + str(args.load_epoch) + '.dat'
- output.load_state_dict(torch.load(fname))
-
- rnn.hidden = rnn.init_hidden(1)
- rnn.eval()
- output.eval()
-
- mae_list = []
- roc_score_list = []
- ap_score_list = []
- precision_list = []
- recall_list = []
- number_of_removed_data = 0
- for graph in test_graph_list:
- data = getitem(graph, args.max_prev_node, arbitrary_node_deletion_flag)
- if data == None:
- number_of_removed_data += 1
- continue
- x_unsorted = data['x']
- y_unsorted = data['y']
- y_len_unsorted = data['len']
- x_step = Variable(torch.ones(1, 1, args.max_prev_node)).cuda()
- for i in range(y_len_unsorted):
- h = rnn(x_step)
- if (i < y_len_unsorted - 1):
- x_step = (torch.Tensor(x_unsorted[i + 1:i + 2, :]).unsqueeze(0)).cuda()
- else: # the last step prediction
- if 'GraphRNN_RNN' in args.note:
- hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda()
- output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null),
- dim=0) # num_layers, batch_size, hidden_size
- x_step = Variable(torch.zeros(1, 1, args.max_prev_node)).cuda()
- output_x_step = Variable(torch.ones(1, 1, 1)).cuda()
- y_pred = x_step
- for j in range(min(args.max_prev_node, i + 1)):
- output_y_pred_step = output(output_x_step)
- output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1)
- y_pred[:, :, j:j + 1] = F.sigmoid(output_y_pred_step)
- x_step[:, :, j:j + 1] = output_x_step
- output.hidden = Variable(output.hidden.data).cuda()
- result = x_step.detach()
- elif 'GraphRNN_MLP' in args.note:
- y_pred_step = output(h)
- y_pred = F.sigmoid(y_pred_step)
- result = sample_sigmoid(y_pred_step, sample=True, sample_time=1)
- # result = sample_sigmoid(y_pred_step, sample=False, thresh=0.5, sample_time=10)
- rnn.hidden = Variable(rnn.hidden.data).cuda()
- # print("*** before")
- # print(y_unsorted)
- result = result.squeeze(0).cpu().numpy()
- y_pred = y_pred.squeeze(0).cpu().detach().numpy()
- y_unsorted = np.concatenate((y_unsorted[:y_unsorted.shape[0] - 1, :], result), axis=0)
- probabilistic_y_unsorted = np.concatenate((y_unsorted[:y_unsorted.shape[0] - 1, :], y_pred), axis=0)
- # print("*** result")
- # print(result)
- # print("*** after")
- # print(y_unsorted)
- # print(probabilistic_y_unsorted)
- adj_decoded = decode_adj(y_unsorted)
- adj_decoded_prob = decode_adj(probabilistic_y_unsorted)
- # print(adj_decoded.shape)
- # print("*** decoded")
- # print(adj_decoded)
- # print(adj_decoded_prob)
- result = adj_decoded[:, adj_decoded.shape[0] - 1]
- y_pred = adj_decoded_prob[:, adj_decoded_prob.shape[0] - 1]
- label = np.transpose(data['label'])
- label = np.asarray(label)[0, :]
- # print("*** result")
- #
- # print(np.shape(result))
- # print(result)
- # print("*** label")
- # print(label)
- mae = mean_absolute_error(label, result)
- mae_list.append(mae)
-
- # label = gran_lable.numpy()
- # result = sample.cpu().numpy()
- part1 = label[result == 1]
- part2 = part1[part1 == 1]
- part3 = part1[part1 == 0]
- part4 = label[result == 0]
- part5 = part4[part4 == 1]
- tp = len(part2)
- fp = len(part3)
- fn = part5.sum()
- if tp + fp > 0:
- precision = tp / (tp + fp)
- else:
- precision = 0
- recall = tp / (tp + fn)
- precision_list.append(precision)
- recall_list.append(recall)
-
- # result = y_pred.cpu().detach().numpy()
- # print(y_pred)
- positive = result[label == 1]
- if len(positive) <= len(list(result[label == 0])):
- negative = random.sample(list(result[label == 0]), len(positive))
- else:
- negative = result[label == 0]
- positive = random.sample(list(result[label == 1]), len(negative))
- preds_all = np.hstack([positive, negative])
- labels_all = np.hstack([np.ones(len(positive)), np.zeros(len(positive))])
- if len(labels_all) > 0:
- roc_score = roc_auc_score(labels_all, preds_all)
- ap_score = average_precision_score(labels_all, preds_all)
- roc_score_list.append(roc_score)
- print("************ list length : " + str(len(precision_list)))
- return mean(mae_list), mean(roc_score_list), mean(precision_list), mean(recall_list), number_of_removed_data
-
-
- if __name__ == '__main__':
- # graphs_test = []
- # graph = nx.grid_2d_graph(2, 3)
- # graphs_test.append(graph)
- # # graph = nx.grid_2d_graph(2, 4)
- # graphs_test.append(graph)
- print("salam")
- args = Args()
- args.max_prev_node = 40
- graphs = create_graphs.create(args)
-
- training_on_KronEM_data = True
-
- if training_on_KronEM_data:
- small_graphs = []
- for i in range(len(graphs)):
- if graphs[i].number_of_nodes() == 8 or graphs[i].number_of_nodes() == 16 or graphs[
- i].number_of_nodes() == 32 or \
- graphs[i].number_of_nodes() == 64 or graphs[i].number_of_nodes() == 128 or graphs[
- i].number_of_nodes() == 256:
- small_graphs.append(graphs[i])
- graphs = small_graphs
- else:
- if args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
- small_graphs = []
- for i in range(len(graphs)):
- if graphs[i].number_of_nodes() < 41:
- small_graphs.append(graphs[i])
- graphs = small_graphs
- elif args.graph_type == 'COLLAB':
- small_graphs = []
- for i in range(len(graphs)):
- if graphs[i].number_of_nodes() < 52 and graphs[i].number_of_nodes() > 41:
- small_graphs.append(graphs[i])
- graphs = small_graphs
- graph_statistics(graphs)
- # split datasets
- random.seed(123)
- shuffle(graphs)
- graphs_len = len(graphs)
- graphs_test = graphs[int(0.8 * graphs_len):]
- print("**** test graph size : " + str(len(graphs_test)))
- graphs_train = graphs[0:int(0.8 * graphs_len)]
- graph_statistics(graphs_test)
- iteration = 20
- mae_list = []
- roc_score_list = []
- precision_list = []
- recall_list = []
- F_Measure_list = []
- number_of_removed_data_list = []
- arbitrary_node_deletion_flag = True
- for i in range(iteration):
- print("########################################################################## " + str(i))
- mae, roc_score, precision, recall, number_of_removed_data = evaluate(graphs_test, args,
- arbitrary_node_deletion_flag)
- F_Measure = 2 * precision * recall / (precision + recall)
- mae_list.append(mae)
- roc_score_list.append(roc_score)
- precision_list.append(precision)
- recall_list.append(recall)
- F_Measure_list.append(F_Measure)
- number_of_removed_data_list.append(number_of_removed_data)
- print("Mean number of removed data : " + str(mean(number_of_removed_data_list)))
- print(
- "In Test: *** MAE - roc_score - precision - recall - F_Measure : " + str(
- mean(mae_list)) + " _ "
- + str(mean(roc_score_list)) + " _ "
- + str(mean(precision_list)) + " _ " + str(mean(recall_list)) + " _ "
- + str(mean(F_Measure_list)))
|