|
|
- import argparse
- from time import strftime, gmtime
- import statistics
-
- import matplotlib.pyplot as plt
- import networkx as nx
- import numpy as np
- from numpy import count_nonzero
- import os
- import shutil
- from random import shuffle
- import torch
- import torch.nn as nn
- import torch.nn.init as init
- from tensorboard_logger import configure, log_value
- from torch.autograd import Variable
- import torch.nn.functional as F
- from torch import optim
- from torch.optim.lr_scheduler import MultiStepLR
- from baselines.graphvae.evaluate import evaluate
- from baselines.graphvae.evaluate2 import evaluate as evaluate2
- from statistics import mean
-
- import data
- # from baselines.graphvae.graphvae_args import Graph_VAE_Args
- from baselines.graphvae.graphvae_model import GraphVAE
- from baselines.graphvae.graphvae_data import GraphAdjSampler
- from baselines.graphvae.args import GraphVAE_Args
- from baselines.graphvae.util import *
-
# GPU index; main() exports it through CUDA_VISIBLE_DEVICES.
CUDA = 0
# Shared GraphVAE configuration object read throughout this script.
vae_args = GraphVAE_Args()
# MultiStepLR milestones used by train(): 100, then every 200 up to 2000.
# train() steps the scheduler once per batch, so these count batches.
LR_milestones = [100] + list(range(200, 2001, 200))
-
-
def graph_statistics(graphs):
    """Print summary statistics about a collection of graphs.

    For every graph a dense adjacency matrix is built. From those matrices the
    function reports:

    * the smallest, largest and mean node count;
    * the sample standard deviation of the node counts (0.0 for one graph);
    * the average sparsity, i.e. the fraction of zero adjacency entries;
    * how many graphs have exactly n*(n-1) non-zero entries. This is, e.g.,
      every complete graph without self-loops.

    Args:
        graphs: non-empty sequence of networkx graphs.

    Returns:
        None. The statistics are only printed.

    Raises:
        ValueError: if ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("graph_statistics() needs at least one graph")

    # to_numpy_matrix was removed in networkx 3.0. Prefer to_numpy_array and
    # fall back for very old networkx releases that only ship the matrix form.
    to_dense = getattr(nx, "to_numpy_array", None) or nx.to_numpy_matrix

    sparsity = 0.0
    matrix_size_list = []
    counter = 0
    for graph in graphs:
        adjacency = to_dense(graph)
        num_rows, num_cols = adjacency.shape
        non_zero = np.count_nonzero(adjacency)
        # Every off-diagonal entry could be set; count graphs that hit that total.
        if non_zero == num_rows * num_cols - num_cols:
            counter += 1
        sparsity += 1.0 - (non_zero / float(adjacency.size))
        matrix_size_list.append(num_rows)

    smallest_graph_size = min(matrix_size_list)
    largest_graph_size = max(matrix_size_list)
    # statistics.stdev requires at least two data points.
    if len(matrix_size_list) > 1:
        graph_size_std = statistics.stdev(matrix_size_list)
    else:
        graph_size_std = 0.0
    sparsity /= len(graphs)
    print("*** smallest_graph_size = " + str(smallest_graph_size) +
          " *** largest_graph_size = " + str(largest_graph_size) +
          " *** mean_graph_size = " + str(statistics.mean(matrix_size_list)) +
          " *** graph_size_std = " + str(graph_size_std) +
          " *** average_graph_sparsity = " + str(sparsity))
    print("*** counter")
    print(counter)
    return
-
-
def build_model(args, max_num_nodes):
    """Construct a GraphVAE for graphs with at most ``max_num_nodes`` nodes.

    The encoder input width is chosen from ``args.feature_type``:

    * ``'id'``: ``max_num_nodes``. When
      ``vae_args.completion_mode_small_parameter_size`` is set, it is
      ``max_num_nodes - vae_args.number_of_missing_nodes`` instead.
    * ``'deg'``: 1.
    * ``'struct'``: 2.

    Args:
        args: parsed command-line arguments. Only ``feature_type`` is read.
        max_num_nodes: maximum number of nodes per graph.

    Returns:
        ``GraphVAE(input_dim, 16, 256, max_num_nodes, vae_args)``.

    Raises:
        ValueError: if ``args.feature_type`` is none of the values above.
    """
    feature_type = args.feature_type
    if feature_type == 'id':
        if vae_args.completion_mode_small_parameter_size:
            input_dim = max_num_nodes - vae_args.number_of_missing_nodes
        else:
            input_dim = max_num_nodes
    elif feature_type == 'deg':
        input_dim = 1
    elif feature_type == 'struct':
        input_dim = 2
    else:
        # Previously this fell through to an UnboundLocalError on input_dim.
        raise ValueError(
            "unknown feature_type {!r}; expected 'id', 'deg' or 'struct'".format(feature_type))
    return GraphVAE(input_dim, 16, 256, max_num_nodes, vae_args)
-
-
def train(args, dataloader, test_dataset_loader, graphs_test, model, num_epochs=211):
    """Train ``model`` and periodically evaluate it on ``graphs_test``.

    Optimisation uses Adam with learning rate ``args.lr``. A ``MultiStepLR``
    schedule multiplies the learning rate by 0.1 at each entry of
    ``LR_milestones``. The scheduler is stepped once per batch, so the
    milestones count batches, not epochs.

    Evaluation and checkpointing happen on these epochs:

    * Every multiple of 10: after each batch, the training loss is logged.
      ``evaluate2(graphs_test, model)`` is then run and its MAE / ROC / AP
      scores are logged. At the end of the epoch, the averaged test metrics
      and the F-measure are printed.
    * Every non-zero multiple of 50: ``model.state_dict()`` is saved to
      ``vae_args.model_save_path + "GraphVAE<epoch>.dat"``.

    Args:
        args: parsed command-line arguments. ``lr`` and ``dataset`` are read.
        dataloader: iterable of batches. Each batch is a mapping with
            ``'features'``, ``'num_nodes'`` and ``'adj'`` tensors.
        test_dataset_loader: currently unused; kept for interface compatibility.
        graphs_test: test graphs handed to ``evaluate2``.
        model: module whose forward pass returns a scalar loss. Inputs are
            moved with ``.cuda()``, so the model is expected on the GPU.
        num_epochs: number of epochs to run. The default matches the
            original hard-coded 211.
    """
    optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=0.1)
    # Use the already-parsed arguments instead of re-parsing sys.argv per log call.
    dataset_name = args.dataset

    model.train()
    for epoch in range(num_epochs):
        evaluate_this_epoch = epoch % 10 == 0
        test_mae_list = []
        test_roc_score_list = []
        test_ap_score_list = []
        test_precision_list = []
        test_recall_list = []
        # `batch`, not `data`: avoid shadowing the imported `data` module.
        for batch_idx, batch in enumerate(dataloader):
            model.zero_grad()
            features = Variable(batch['features'].float()).cuda()
            batch_num_nodes = batch['num_nodes'].int().numpy()
            adj_input = Variable(batch['adj'].float()).cuda()

            loss = model(features, adj_input, batch_num_nodes)
            print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if not evaluate_this_epoch:
                continue

            # NOTE: assumes fewer than 32 batches per epoch; otherwise the
            # logging steps of consecutive epochs overlap.
            step = epoch * 32 + batch_idx
            log_value("Training Loss, dataset: " + dataset_name, loss.item(), step)
            test_mae, test_roc_score, test_ap_score, test_precision, test_recall = evaluate2(
                graphs_test, model)
            test_mae_list.append(test_mae)
            test_roc_score_list.append(test_roc_score)
            test_ap_score_list.append(test_ap_score)
            test_precision_list.append(test_precision)
            test_recall_list.append(test_recall)
            log_value("Test MAE, dataset: " + dataset_name, test_mae, step)
            log_value("Test test_roc_score, dataset: " + dataset_name, test_roc_score, step)
            log_value("Test test_ap_score, dataset: " + dataset_name, test_ap_score, step)

        # Checkpoint once, after the epoch's last optimizer step. The saved
        # state is the same as re-saving after every batch.
        if epoch % 50 == 0 and epoch != 0:
            fname = vae_args.model_save_path + "GraphVAE" + str(epoch) + '.dat'
            torch.save(model.state_dict(), fname)

        if test_mae_list:
            precision = mean(test_precision_list)
            recall = mean(test_recall_list)
            # Avoid 0/0 when precision and recall are both zero.
            denominator = precision + recall
            test_F_Measure = 2 * precision * recall / denominator if denominator else 0.0
            metrics = [mean(test_mae_list), mean(test_roc_score_list), mean(test_ap_score_list),
                       precision, recall, test_F_Measure]
            print("In Train: *** MAE - roc_score - ap_score - precision - recall - F_Measure : "
                  + " _ ".join(str(value) for value in metrics))
-
-
def arg_parse(argv=None):
    """Parse the GraphVAE command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``dataset``, ``lr``, ``batch_size``,
        ``batch_ratio``, ``num_workers``, ``max_num_nodes`` and
        ``feature_type``.
    """
    parser = argparse.ArgumentParser(description='GraphVAE arguments.')
    parser.add_argument('--dataset', dest='dataset',
                        help='Input dataset.')
    parser.add_argument('--lr', dest='lr', type=float,
                        help='Learning rate.')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        help='Batch size.')
    parser.add_argument('--batch_ratio', dest='batch_ratio', type=int,
                        help='Batch ratio.')
    parser.add_argument('--num_workers', dest='num_workers', type=int,
                        help='Number of workers to load data.')
    parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
                        help='Predefined maximum number of nodes in train/test graphs. '
                             '-1 if determined by training data.')
    # build_model() only understands these three feature types; reject others up front.
    parser.add_argument('--feature', dest='feature_type', choices=('id', 'deg', 'struct'),
                        help='Feature used for encoder. Can be: id, deg, struct')

    parser.set_defaults(dataset='REDDITMULTI5K',
                        feature_type='id',
                        lr=0.01,
                        batch_size=32,
                        batch_ratio=10,
                        num_workers=4,
                        max_num_nodes=-1)
    return parser.parse_args(argv)
-
-
def _load_graphs(dataset):
    """Return the list of networkx graphs for the named ``dataset``.

    Names not handled explicitly here are passed to ``load_data(dataset, True)``.
    ``load_data`` presumably comes from ``baselines.graphvae.util``; verify.
    """
    if dataset == 'enzymes':
        return data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
    if dataset == 'dd':
        return data.Graph_load_batch(min_num_nodes=10, name='DD')
    if dataset == 'ladder':
        return [nx.ladder_graph(i) for i in range(100, 201)]
    if dataset == 'barabasi':
        # Five Barabasi-Albert graphs (m=4) for every size in [100, 200).
        return [nx.barabasi_albert_graph(i, 4) for i in range(100, 200) for _ in range(5)]
    if dataset == 'citeseer':
        _, _, G = data.Graph_load(dataset='citeseer')
        # nx.connected_component_subgraphs was removed in networkx 2.4.
        # Take the largest connected component directly instead.
        G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
        G = nx.convert_node_labels_to_integers(G)
        graphs = []
        for node in range(G.number_of_nodes()):
            ego = nx.ego_graph(G, node, radius=3)
            if 50 <= ego.number_of_nodes() <= 400:
                graphs.append(ego)
        return graphs
    if dataset == 'grid':
        # Hand-picked grid shapes; order matters for the seeded shuffle in main().
        shapes = [(2, 3), (3, 4), (6, 2), (4, 6), (6, 4), (8, 3), (12, 2)]
        return [nx.grid_2d_graph(rows, cols) for rows, cols in shapes]
    if dataset == 'grid_big':
        return [nx.grid_2d_graph(i, j) for i in range(36, 46) for j in range(36, 46)]
    if dataset == 'grid_small':
        return [nx.grid_2d_graph(i, j) for i in range(2, 5) for j in range(2, 5)]
    graphs, _num_classes = load_data(dataset, True)
    return graphs


def _make_loader(dataset, args):
    """Wrap ``dataset`` in a DataLoader that samples uniformly with replacement.

    Each pass draws ``args.batch_size * args.batch_ratio`` samples in batches
    of ``args.batch_size``.
    """
    size = len(dataset)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / size for _ in range(size)],
        num_samples=args.batch_size * args.batch_ratio,
        replacement=True)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sampler=sampler)


def main():
    """Load a dataset, split it 80/20 after a seeded shuffle, and train a GraphVAE.

    When ``--max_num_nodes`` is -1, graphs with 41 or more nodes are dropped
    for every dataset except ``grid`` and ``grid_small``. The maximum size is
    then taken from the remaining graphs. Otherwise, graphs larger than
    ``--max_num_nodes`` are dropped.

    Raises:
        ValueError: if no graphs remain after filtering.
    """
    # Only `shuffle` is imported at file level; random.seed needs the module itself.
    import random

    prog_args = arg_parse()

    # Must be set before CUDA is first initialised to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
    print('CUDA', CUDA)
    torch.manual_seed(1234)

    graphs = _load_graphs(prog_args.dataset)

    if prog_args.max_num_nodes == -1:
        if prog_args.dataset not in ('grid_small', 'grid'):
            graphs = [g for g in graphs if g.number_of_nodes() < 41]
        print(len(graphs))
        if not graphs:
            raise ValueError("no graphs left for dataset {!r}".format(prog_args.dataset))
        max_num_nodes = max(g.number_of_nodes() for g in graphs)
        graph_statistics(graphs)
    else:
        max_num_nodes = prog_args.max_num_nodes
        # Remove graphs with more nodes than the configured maximum.
        graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]
        if not graphs:
            raise ValueError("no graphs with at most {} nodes".format(max_num_nodes))

    # Reproducible shuffle, then an 80/20 train/test split.
    random.seed(123)
    shuffle(graphs)
    split = int(0.8 * len(graphs))
    graphs_train = graphs[:split]
    graphs_test = graphs[split:]

    # Test graphs whose size is a power of two between 8 and 128 go to prepare_kronEM_data.
    kronEM_graphs = [g for g in graphs_test if g.number_of_nodes() in (8, 16, 32, 64, 128)]
    prepare_kronEM_data(kronEM_graphs, prog_args.dataset, True)
    save_graphs_as_mat(graphs_test)
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))

    dataset = GraphAdjSampler(graphs_train, max_num_nodes, vae_args.permutation_mode, vae_args.bfs_mode,
                              vae_args.bfs_mode_with_arbitrary_node_deleted,
                              features=prog_args.feature_type)
    test_dataset = GraphAdjSampler(graphs_test, max_num_nodes, vae_args.permutation_mode, vae_args.bfs_mode,
                                   vae_args.bfs_mode_with_arbitrary_node_deleted,
                                   features=prog_args.feature_type)
    dataset_loader = _make_loader(dataset, prog_args)
    test_dataset_loader = _make_loader(test_dataset, prog_args)

    model = build_model(prog_args, max_num_nodes).cuda()
    train(prog_args, dataset_loader, test_dataset_loader, graphs_test, model)
-
-
if __name__ == '__main__':
    # Make sure the checkpoint directory exists before any training happens.
    os.makedirs(vae_args.model_save_path, exist_ok=True)
    # UTC timestamp used to name this run's tensorboard directory.
    run_started = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # Optionally wipe earlier tensorboard runs before starting a fresh one.
    if vae_args.clean_tensorboard and os.path.isdir("tensorboard"):
        shutil.rmtree("tensorboard")
    configure("tensorboard/run" + run_started, flush_secs=5)
    main()
|