MIT License

Copyright (c) 2017 Jiaxuan You, Rex Ying

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# GraphRNN: Generating Realistic Graphs with Deep Auto-regressive Models

This repository is the official PyTorch implementation of GraphRNN, a graph generative model using an auto-regressive model.

[Jiaxuan You](https://cs.stanford.edu/~jiaxuan/)\*, [Rex Ying](https://cs.stanford.edu/people/rexy/)\*, [Xiang Ren](http://www-bcf.usc.edu/~xiangren/), [William L. Hamilton](https://stanford.edu/~wleif/), [Jure Leskovec](https://cs.stanford.edu/people/jure/index.html), [GraphRNN: Generating Realistic Graphs with Deep Auto-regressive Models](https://arxiv.org/abs/1802.08773) (ICML 2018)

## Installation

Install PyTorch following the instructions on the [official website](https://pytorch.org/). The code has been tested with PyTorch 0.2.0 and 0.4.0.
```bash
conda install pytorch torchvision cuda90 -c pytorch
```

Then install the other dependencies.

```bash
pip install -r requirements.txt
```
## Test run

```bash
python main.py
```
## Code description

For the GraphRNN model:
`main.py` is the main executable file, and specific arguments are set in `args.py`.
`train.py` includes the training loop and calls `model.py` and `data.py`.
`create_graphs.py` is where we prepare the target graph datasets.

For baseline models:
* B-A and E-R models are implemented in `baselines/baseline_simple.py` (see the sketch after this list for the closed-form parameter fit).
* The [Kronecker graph model](https://cs.stanford.edu/~jure/pubs/kronecker-jmlr10.pdf) is implemented in the SNAP software, which can be found at https://github.com/snap-stanford/snap/tree/master/examples/krongen (for generating Kronecker graphs) and https://github.com/snap-stanford/snap/tree/master/examples/kronfit (for learning the parameters of the model).
* MMSB is implemented using the [Edward](http://edwardlib.org/) library and is located in `baselines`.
* We implemented the DeepGMG model based on the instructions in their [paper](https://arxiv.org/abs/1803.03324) in `main_DeepGMG.py`.
* We implemented the GraphVAE model based on the instructions in their [paper](https://arxiv.org/abs/1802.03480) in `baselines/graphvae`.
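As a reference for the rule-based B-A / E-R baselines, the per-graph parameters can be recovered in closed form from the node and edge counts alone. The following is a minimal sketch of the same logic as the rule-based fit in `baselines/baseline_simple.py` (the function name here is illustrative):

```python
import numpy as np

def fit_baseline_parameters(n_nodes, n_edges, generator='BA'):
    # B-A: edges ~= (n - m) * m, so m is the smaller root of m^2 - n*m + edges = 0.
    if generator == 'BA':
        m = int((n_nodes - np.sqrt(n_nodes ** 2 - 4 * n_edges)) / 2)
        return n_nodes, m
    # E-R (Gnp): edges ~= n * (n - 1) / 2 * p.
    if generator == 'Gnp':
        p = float(n_edges) / (n_nodes * (n_nodes - 1) / 2)
        return n_nodes, p
```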
Parameter setting:
To adjust the hyperparameters and input arguments of the model, modify the fields of `args.py`
accordingly.
For example, `args.cuda` controls which GPU is used to train the model, and `args.graph_type`
specifies which dataset is used to train the generative model. See the documentation in `args.py`
for more detailed descriptions of all fields.
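For instance, the fields most runs touch look like this (the values shown are the defaults from `args.py` in this repository; see that file for the full list):

```python
# Excerpt-style sketch of commonly edited fields in args.py (defaults shown).
class Args:
    def __init__(self):
        self.cuda = 1                 # which CUDA GPU device is used for training
        self.note = 'GraphRNN_RNN'    # model variant: 'GraphRNN_RNN' or 'GraphRNN_MLP'
        self.graph_type = 'grid'      # training dataset, e.g. 'grid', 'enzymes', 'DD'
        self.num_layers = 4           # number of RNN layers
        self.epochs = 3000            # one epoch = args.batch_ratio batches of args.batch_size graphs
        self.lr = 0.003               # initial learning rate
        self.dir_input = './'         # path prefix for all outputs
```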
## Outputs

There are several different types of outputs, each saved into a different directory under a path prefix. The path prefix is set at `args.dir_input`. Suppose that this field is set to `./`:
* `./graphs` contains the pickle files of training, test and generated graphs. Each contains a list
of networkx objects.
* `./eval_results` contains the evaluation of MMD scores in txt format.
* `./model_save` stores the model checkpoints.
* `./nll` saves the log-likelihood for generated graphs as sequences.
* `./figures` is used to save visualizations (see the "Visualization of graphs" section).
## Evaluation

The evaluation is done in `evaluate.py`, where the user can choose which settings to evaluate.
To evaluate how close the generated graphs are to the ground truth set, we use MMD (maximum mean discrepancy) to calculate the divergence between two _sets of distributions_ related to
the ground truth and generated graphs.
Two types of distributions are chosen: the degree distribution and the clustering coefficient distribution.
Both are implemented in `eval/stats.py`, using the Python multiprocessing
module. One can easily extend the evaluation to compute MMD for other graph statistics.
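To make the idea concrete, here is a minimal, self-contained sketch of an MMD computation over degree histograms. It is not the exact estimator or kernel used in `eval/stats.py` (which is EMD-based); a plain Gaussian kernel on fixed-length histograms is assumed purely for illustration:

```python
import networkx as nx
import numpy as np

def degree_histogram(G, max_degree=50):
    # Normalized degree histogram, padded/truncated to a fixed length.
    hist = np.array(nx.degree_histogram(G), dtype=float)[:max_degree]
    hist = np.pad(hist, (0, max_degree - len(hist)), 'constant')
    return hist / max(hist.sum(), 1.0)

def gaussian_kernel(x, y, sigma=1.0):
    return np.exp(-np.sum((x - y) ** 2) / (2 * sigma ** 2))

def mmd_squared(real_hists, pred_hists, kernel=gaussian_kernel):
    # Biased MMD^2 estimate between two sets of histograms.
    k_rr = np.mean([kernel(x, y) for x in real_hists for y in real_hists])
    k_pp = np.mean([kernel(x, y) for x in pred_hists for y in pred_hists])
    k_rp = np.mean([kernel(x, y) for x in real_hists for y in pred_hists])
    return k_rr + k_pp - 2 * k_rp

# mmd_squared([degree_histogram(G) for G in graphs_real],
#             [degree_histogram(G) for G in graphs_pred])
```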
We also compute the orbit counts for each graph, represented as a high-dimensional data point. We then compute the MMD
between the two _sets of sampled points_ using ORCA (see http://www.biolab.si/supp/orca/orca.html), located at `eval/orca`.
One first needs to compile ORCA by running

```bash
g++ -O2 -std=c++11 -o orca orca.cpp
```

in the directory `eval/orca` (the precompiled binary already in the repo works on Ubuntu).
To evaluate, run

```bash
python evaluate.py
```

Arguments specific to evaluation are specified in the class
`evaluate.Args_evaluate`. Note that the field `Args_evaluate.dataset_name_all` must only contain
datasets that have already been trained, by setting `args.graph_type` to each of the datasets and running
`python main.py`.
## Visualization of graphs

The training, test and generated graphs are saved at `graphs/`.
One can visualize the generated graphs using the function `utils.load_graph_list`, which loads the
list of graphs from the pickle file, and `utils.draw_graph_list`, which plots the graphs using
networkx.
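A minimal usage sketch, assuming both helpers live in `utils.py` as described above (the filename below is illustrative; the actual name is built from the `args` fields used during training, e.g. `args.fname_pred` plus the epoch and sample time):

```python
from utils import load_graph_list, draw_graph_list

# Load a list of networkx graphs produced during training/inference.
graphs = load_graph_list('graphs/GraphRNN_RNN_grid_4_128_pred_3000_1.dat')

# Plot the first 16 graphs in a 4x4 grid and save the figure under figures/.
draw_graph_list(graphs[:16], row=4, col=4, fname='figures/grid_pred_3000_1')
```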
## Misc

Jesse Bettencourt and Harris Chan have made great [slides](https://duvenaud.github.io/learn-discrete/slides/graphrnn.pdf) introducing GraphRNN in Prof. David Duvenaud’s seminar course [Learning Discrete Latent Structure](https://duvenaud.github.io/learn-discrete/).
# This script plots the real and generated graphs saved during training.
from main import * | |||||
args = Args() | |||||
print(args.graph_type, args.note) | |||||
# epoch = 16000 | |||||
epoch = 3000 | |||||
sample_time = 3 | |||||
def find_nearest_idx(array,value): | |||||
idx = (np.abs(array-value)).argmin() | |||||
return idx | |||||
# for baseline model | |||||
for num_layers in range(4,5): | |||||
# give file name and figure name | |||||
fname_real = args.graph_save_path + args.fname_real + str(0) | |||||
fname_pred = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) | |||||
figname = args.figure_save_path + args.fname + str(epoch) +'_'+str(sample_time) | |||||
# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + '_real_' + str(True) + '_' + str(num_layers) | |||||
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + '_pred_' + str(True) + '_' + str(num_layers) | |||||
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + '_' + str(num_layers) | |||||
print(fname_real) | |||||
print(fname_pred) | |||||
# load data | |||||
graph_real_list = load_graph_list(fname_real + '.dat') | |||||
shuffle(graph_real_list) | |||||
graph_pred_list_raw = load_graph_list(fname_pred + '.dat') | |||||
graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))]) | |||||
graph_pred_len_list_raw = np.array([len(graph_pred_list_raw[i]) for i in range(len(graph_pred_list_raw))]) | |||||
graph_pred_list = graph_pred_list_raw | |||||
graph_pred_len_list = graph_pred_len_list_raw | |||||
# # select samples | |||||
# graph_pred_list = [] | |||||
# graph_pred_len_list = [] | |||||
# for value in graph_real_len_list: | |||||
# pred_idx = find_nearest_idx(graph_pred_len_list_raw, value) | |||||
# graph_pred_list.append(graph_pred_list_raw[pred_idx]) | |||||
# graph_pred_len_list.append(graph_pred_len_list_raw[pred_idx]) | |||||
# # delete | |||||
# graph_pred_len_list_raw=np.delete(graph_pred_len_list_raw, pred_idx) | |||||
# del graph_pred_list_raw[pred_idx] | |||||
# if len(graph_pred_list)==200: | |||||
# break | |||||
# graph_pred_len_list = np.array(graph_pred_len_list) | |||||
# # select pred data within certain range | |||||
# len_min = np.amin(graph_real_len_list) | |||||
# len_max = np.amax(graph_real_len_list) | |||||
# pred_index = np.where((graph_pred_len_list>=len_min)&(graph_pred_len_list<=len_max)) | |||||
# # print(pred_index[0]) | |||||
# graph_pred_list = [graph_pred_list[i] for i in pred_index[0]] | |||||
# graph_pred_len_list = graph_pred_len_list[pred_index[0]] | |||||
# real_order = np.argsort(graph_real_len_list) | |||||
# pred_order = np.argsort(graph_pred_len_list) | |||||
real_order = np.argsort(graph_real_len_list)[::-1] | |||||
pred_order = np.argsort(graph_pred_len_list)[::-1] | |||||
# print(real_order) | |||||
# print(pred_order) | |||||
graph_real_list = [graph_real_list[i] for i in real_order] | |||||
graph_pred_list = [graph_pred_list[i] for i in pred_order] | |||||
# shuffle(graph_real_list) | |||||
# shuffle(graph_pred_list) | |||||
print('real average nodes', sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))])/len(graph_real_list)) | |||||
print('pred average nodes', sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))])/len(graph_pred_list)) | |||||
print('num of real graphs', len(graph_real_list)) | |||||
print('num of pred graphs', len(graph_pred_list)) | |||||
# # draw all graphs | |||||
# for iter in range(8): | |||||
# print('iter', iter) | |||||
# graph_list = [] | |||||
# for i in range(8): | |||||
# index = 8 * iter + i | |||||
# # graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index]))) | |||||
# # graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index]))) | |||||
# graph_list.append(graph_real_list[index]) | |||||
# graph_list.append(graph_pred_list[index]) | |||||
# print('real', graph_real_list[index].number_of_nodes()) | |||||
# print('pred', graph_pred_list[index].number_of_nodes()) | |||||
# | |||||
# draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)) | |||||
# draw all graphs | |||||
for iter in range(8): | |||||
print('iter', iter) | |||||
graph_list = [] | |||||
for i in range(8): | |||||
index = 32 * iter + i | |||||
# graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index]))) | |||||
# graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index]))) | |||||
# graph_list.append(graph_real_list[index]) | |||||
graph_list.append(graph_pred_list[index]) | |||||
# print('real', graph_real_list[index].number_of_nodes()) | |||||
print('pred', graph_pred_list[index].number_of_nodes()) | |||||
draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_pred') | |||||
# draw all graphs | |||||
for iter in range(8): | |||||
print('iter', iter) | |||||
graph_list = [] | |||||
for i in range(8): | |||||
index = 16 * iter + i | |||||
# graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index]))) | |||||
# graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index]))) | |||||
graph_list.append(graph_real_list[index]) | |||||
# graph_list.append(graph_pred_list[index]) | |||||
print('real', graph_real_list[index].number_of_nodes()) | |||||
# print('pred', graph_pred_list[index].number_of_nodes()) | |||||
draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_real') | |||||
# | |||||
# # for new model | |||||
# elif args.note == 'GraphRNN_structure' and args.is_flexible==False: | |||||
# for num_layers in range(4,5): | |||||
# # give file name and figure name | |||||
# # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# # str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr) | |||||
# # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# # str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr) | |||||
# | |||||
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size) | |||||
# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size) | |||||
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size) | |||||
# print(fname_real) | |||||
# # load data | |||||
# graph_real_list = load_graph_list(fname_real+'.dat') | |||||
# graph_pred_list = load_graph_list(fname_pred+'.dat') | |||||
# | |||||
# graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))]) | |||||
# graph_pred_len_list = np.array([len(graph_pred_list[i]) for i in range(len(graph_pred_list))]) | |||||
# real_order = np.argsort(graph_real_len_list)[::-1] | |||||
# pred_order = np.argsort(graph_pred_len_list)[::-1] | |||||
# # print(real_order) | |||||
# # print(pred_order) | |||||
# graph_real_list = [graph_real_list[i] for i in real_order] | |||||
# graph_pred_list = [graph_pred_list[i] for i in pred_order] | |||||
# | |||||
# shuffle(graph_pred_list) | |||||
# | |||||
# | |||||
# print('real average nodes', | |||||
# sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))]) / len(graph_real_list)) | |||||
# print('pred average nodes', | |||||
# sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))]) / len(graph_pred_list)) | |||||
# print('num of graphs', len(graph_real_list)) | |||||
# | |||||
# # draw all graphs | |||||
# for iter in range(2): | |||||
# print('iter', iter) | |||||
# graph_list = [] | |||||
# for i in range(8): | |||||
# index = 8*iter + i | |||||
# graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index])) | |||||
# graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index])) | |||||
# graph_list.append(graph_real_list[index]) | |||||
# graph_list.append(graph_pred_list[index]) | |||||
# print('real', graph_real_list[index].number_of_nodes()) | |||||
# print('pred', graph_pred_list[index].number_of_nodes()) | |||||
# draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter)) | |||||
# | |||||
# | |||||
# # for new model | |||||
# elif args.note == 'GraphRNN_structure' and args.is_flexible==True: | |||||
# for num_layers in range(4,5): | |||||
# graph_real_list = [] | |||||
# graph_pred_list = [] | |||||
# epoch_end = 30000 | |||||
# for epoch in [epoch_end-500*(8-i) for i in range(8)]: | |||||
# # give file name and figure name | |||||
# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr) | |||||
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr) | |||||
# | |||||
# # load data | |||||
# graph_real_list += load_graph_list(fname_real+'.dat') | |||||
# graph_pred_list += load_graph_list(fname_pred+'.dat') | |||||
# print('num of graphs', len(graph_real_list)) | |||||
# | |||||
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \ | |||||
# str(epoch) + str(args.sample_when_validate) + '_' + str(num_layers) + '_dilation_' + str(args.is_dilation) + '_flexible_' + str(args.is_flexible) + '_bn_' + str(args.is_bn) + '_lr_' + str(args.lr) | |||||
# | |||||
# # draw all graphs | |||||
# for iter in range(1): | |||||
# print('iter', iter) | |||||
# graph_list = [] | |||||
# for i in range(8): | |||||
# index = 8*iter + i | |||||
# graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index])) | |||||
# graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index])) | |||||
# graph_list.append(graph_real_list[index]) | |||||
# graph_list.append(graph_pred_list[index]) | |||||
# draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter)) |
### program configuration | |||||
class Args(): | |||||
def __init__(self): | |||||
### whether to clean the tensorboard log directory before training
self.clean_tensorboard = False | |||||
### Which CUDA GPU device is used for training | |||||
self.cuda = 1 | |||||
### Which GraphRNN model variant is used. | |||||
# The simple version of Graph RNN | |||||
# self.note = 'GraphRNN_MLP' | |||||
# The dependent Bernoulli sequence version of GraphRNN | |||||
self.note = 'GraphRNN_RNN' | |||||
## for comparison, removing the BFS component
# self.note = 'GraphRNN_MLP_nobfs' | |||||
# self.note = 'GraphRNN_RNN_nobfs' | |||||
### Which dataset is used to train the model | |||||
# self.graph_type = 'DD' | |||||
# self.graph_type = 'caveman' | |||||
# self.graph_type = 'caveman_small' | |||||
# self.graph_type = 'caveman_small_single' | |||||
# self.graph_type = 'community4' | |||||
self.graph_type = 'grid' | |||||
# self.graph_type = 'grid_small' | |||||
# self.graph_type = 'ladder_small' | |||||
# self.graph_type = 'enzymes' | |||||
# self.graph_type = 'enzymes_small' | |||||
# self.graph_type = 'barabasi' | |||||
# self.graph_type = 'barabasi_small' | |||||
# self.graph_type = 'citeseer' | |||||
# self.graph_type = 'citeseer_small' | |||||
# self.graph_type = 'barabasi_noise' | |||||
# self.noise = 10 | |||||
# | |||||
# if self.graph_type == 'barabasi_noise': | |||||
# self.graph_type = self.graph_type+str(self.noise) | |||||
# if none, then auto calculate | |||||
self.max_num_node = None # max number of nodes in a graph | |||||
self.max_prev_node = None # max previous node that looks back | |||||
### network config | |||||
## GraphRNN | |||||
if 'small' in self.graph_type: | |||||
self.parameter_shrink = 2 | |||||
else: | |||||
self.parameter_shrink = 1 | |||||
self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN | |||||
self.hidden_size_rnn_output = 16 # hidden size for output RNN | |||||
self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input | |||||
self.embedding_size_rnn_output = 8 # the embedding size for output rnn | |||||
self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP) | |||||
self.batch_size = 32 # normal: 32, and the rest should be changed accordingly | |||||
self.test_batch_size = 32 | |||||
self.test_total_size = 1000 | |||||
self.num_layers = 4 | |||||
### training config | |||||
self.num_workers = 4 # num workers to load data, default 4 | |||||
self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches | |||||
self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size | |||||
self.epochs_test_start = 100 | |||||
self.epochs_test = 100 | |||||
self.epochs_log = 100 | |||||
self.epochs_save = 100 | |||||
self.lr = 0.003 | |||||
self.milestones = [400, 1000] | |||||
self.lr_rate = 0.3 | |||||
self.sample_time = 2 # sample time in each time step, when validating | |||||
### output config | |||||
# self.dir_input = "/dfs/scratch0/jiaxuany0/" | |||||
self.dir_input = "./" | |||||
self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation | |||||
self.graph_save_path = self.dir_input+'graphs/' | |||||
self.figure_save_path = self.dir_input+'figures/' | |||||
self.timing_save_path = self.dir_input+'timing/' | |||||
self.figure_prediction_save_path = self.dir_input+'figures_prediction/' | |||||
self.nll_save_path = self.dir_input+'nll/' | |||||
self.load = False # if load model, default lr is very low | |||||
self.load_epoch = 3000 | |||||
self.save = True | |||||
### baseline config | |||||
# self.generator_baseline = 'Gnp' | |||||
self.generator_baseline = 'BA' | |||||
# self.metric_baseline = 'general' | |||||
# self.metric_baseline = 'degree' | |||||
self.metric_baseline = 'clustering' | |||||
### filenames to save intermediate and final outputs
self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_' | |||||
self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_' | |||||
self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_' | |||||
self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_' | |||||
self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline | |||||
from main import * | |||||
from scipy.linalg import toeplitz | |||||
import pyemd | |||||
import scipy.optimize as opt | |||||
def Graph_generator_baseline_train_rulebased(graphs,generator='BA'): | |||||
graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))] | |||||
graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))] | |||||
parameter = {} | |||||
for i in range(len(graph_nodes)): | |||||
nodes = graph_nodes[i] | |||||
edges = graph_edges[i] | |||||
# based on rule, calculate optimal parameter | |||||
if generator=='BA': | |||||
# BA optimal: nodes = n; edges = (n-m)*m | |||||
n = nodes | |||||
m = (n - np.sqrt(n**2-4*edges))/2 | |||||
parameter_temp = [n,m,1] | |||||
if generator=='Gnp': | |||||
# Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p | |||||
n = nodes | |||||
p = float(edges)/((n-1)*n/2) | |||||
parameter_temp = [n,p,1] | |||||
# update parameter list | |||||
if nodes not in parameter.keys(): | |||||
parameter[nodes] = parameter_temp | |||||
else: | |||||
count = parameter[nodes][-1] | |||||
parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))] | |||||
parameter[nodes][-1] = count+1 | |||||
# print(parameter) | |||||
return parameter | |||||
def Graph_generator_baseline(graph_train, pred_num=1000, generator='BA'): | |||||
graph_nodes = [graph_train[i].number_of_nodes() for i in range(len(graph_train))] | |||||
graph_edges = [graph_train[i].number_of_edges() for i in range(len(graph_train))] | |||||
repeat = pred_num//len(graph_train) | |||||
graph_pred = [] | |||||
for i in range(len(graph_nodes)): | |||||
nodes = graph_nodes[i] | |||||
edges = graph_edges[i] | |||||
# based on rule, calculate optimal parameter | |||||
if generator=='BA': | |||||
# BA optimal: nodes = n; edges = (n-m)*m | |||||
n = nodes | |||||
m = int((n - np.sqrt(n**2-4*edges))/2) | |||||
for j in range(repeat): | |||||
graph_pred.append(nx.barabasi_albert_graph(n,m)) | |||||
if generator=='Gnp': | |||||
# Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p | |||||
n = nodes | |||||
p = float(edges)/((n-1)*n/2) | |||||
for j in range(repeat): | |||||
graph_pred.append(nx.fast_gnp_random_graph(n, p)) | |||||
return graph_pred | |||||
def emd_distance(x, y, distance_scaling=1.0): | |||||
support_size = max(len(x), len(y)) | |||||
d_mat = toeplitz(range(support_size)).astype(float)
distance_mat = d_mat / distance_scaling | |||||
# convert histogram values x and y to float, and make them equal len | |||||
x = x.astype(float)
y = y.astype(float)
if len(x) < len(y): | |||||
x = np.hstack((x, [0.0] * (support_size - len(x)))) | |||||
elif len(y) < len(x): | |||||
y = np.hstack((y, [0.0] * (support_size - len(y)))) | |||||
emd = pyemd.emd(x, y, distance_mat) | |||||
return emd | |||||
# def Loss(x,args): | |||||
# ''' | |||||
# | |||||
# :param x: 1-D array, parameters to be optimized | |||||
# :param args: tuple (n, G, generator, metric). | |||||
# n: n for pred graph; | |||||
# G: real graph in networkx format; | |||||
# generator: 'BA', 'Gnp', 'Powerlaw'; | |||||
# metric: 'degree', 'clustering' | |||||
# :return: Loss: emd distance | |||||
# ''' | |||||
# # get argument | |||||
# generator = args[2] | |||||
# metric = args[3] | |||||
# | |||||
# # get real and pred graphs | |||||
# G_real = args[1] | |||||
# if generator=='BA': | |||||
# G_pred = nx.barabasi_albert_graph(args[0],int(np.rint(x))) | |||||
# if generator=='Gnp': | |||||
# G_pred = nx.fast_gnp_random_graph(args[0],x) | |||||
# | |||||
# # define metric | |||||
# if metric == 'degree': | |||||
# G_real_hist = np.array(nx.degree_histogram(G_real)) | |||||
# G_real_hist = G_real_hist / np.sum(G_real_hist) | |||||
# G_pred_hist = np.array(nx.degree_histogram(G_pred)) | |||||
# G_pred_hist = G_pred_hist/np.sum(G_pred_hist) | |||||
# if metric == 'clustering': | |||||
# G_real_hist, _ = np.histogram( | |||||
# np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False) | |||||
# G_real_hist = G_real_hist / np.sum(G_real_hist) | |||||
# G_pred_hist, _ = np.histogram( | |||||
# np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False) | |||||
# G_pred_hist = G_pred_hist / np.sum(G_pred_hist) | |||||
# | |||||
# loss = emd_distance(G_real_hist,G_pred_hist) | |||||
# return loss | |||||
def Loss(x,n,G_real,generator,metric): | |||||
''' | |||||
:param x: 1-D array, parameters to be optimized | |||||
:param n: number of nodes for the predicted graph
:param G_real: real graph in networkx format
:param generator: 'BA', 'Gnp', 'Powerlaw'
:param metric: 'degree', 'clustering'
:return: loss, the EMD distance between the real and predicted histograms
''' | |||||
# get argument | |||||
# get real and pred graphs | |||||
if generator=='BA': | |||||
G_pred = nx.barabasi_albert_graph(n,int(np.rint(x))) | |||||
if generator=='Gnp': | |||||
G_pred = nx.fast_gnp_random_graph(n,x) | |||||
# define metric | |||||
if metric == 'degree': | |||||
G_real_hist = np.array(nx.degree_histogram(G_real)) | |||||
G_real_hist = G_real_hist / np.sum(G_real_hist) | |||||
G_pred_hist = np.array(nx.degree_histogram(G_pred)) | |||||
G_pred_hist = G_pred_hist/np.sum(G_pred_hist) | |||||
if metric == 'clustering': | |||||
G_real_hist, _ = np.histogram( | |||||
np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False) | |||||
G_real_hist = G_real_hist / np.sum(G_real_hist) | |||||
G_pred_hist, _ = np.histogram( | |||||
np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False) | |||||
G_pred_hist = G_pred_hist / np.sum(G_pred_hist) | |||||
loss = emd_distance(G_real_hist,G_pred_hist) | |||||
return loss | |||||
def optimizer_brute(x_min, x_max, x_step, n, G_real, generator, metric): | |||||
loss_all = [] | |||||
x_list = np.arange(x_min,x_max,x_step) | |||||
for x_test in x_list: | |||||
loss_all.append(Loss(x_test,n,G_real,generator,metric)) | |||||
x_optim = x_list[np.argmin(np.array(loss_all))] | |||||
return x_optim | |||||
def Graph_generator_baseline_train_optimizationbased(graphs,generator='BA',metric='degree'): | |||||
graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))] | |||||
parameter = {} | |||||
for i in range(len(graph_nodes)): | |||||
print('graph ',i) | |||||
nodes = graph_nodes[i] | |||||
if generator=='BA': | |||||
n = nodes | |||||
m = optimizer_brute(1,10,1, nodes, graphs[i], generator, metric) | |||||
parameter_temp = [n,m,1] | |||||
elif generator=='Gnp': | |||||
n = nodes | |||||
p = optimizer_brute(1e-6,1,0.01, nodes, graphs[i], generator, metric) | |||||
## if use evolution | |||||
# result = opt.differential_evolution(Loss,bounds=[(0,1)],args=(nodes, graphs[i], generator, metric),maxiter=1000) | |||||
# p = result.x | |||||
parameter_temp = [n, p, 1] | |||||
# update parameter list | |||||
if nodes not in parameter.keys(): | |||||
parameter[nodes] = parameter_temp | |||||
else: | |||||
count = parameter[nodes][2] | |||||
parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))] | |||||
parameter[nodes][2] = count+1 | |||||
print(parameter) | |||||
return parameter | |||||
def Graph_generator_baseline_test(graph_nodes, parameter, generator='BA'): | |||||
graphs = [] | |||||
for i in range(len(graph_nodes)): | |||||
nodes = graph_nodes[i] | |||||
if not nodes in parameter.keys(): | |||||
nodes = min(parameter.keys(), key=lambda k: abs(k - nodes)) | |||||
if generator=='BA': | |||||
n = int(parameter[nodes][0]) | |||||
m = int(np.rint(parameter[nodes][1])) | |||||
print(n,m) | |||||
graph = nx.barabasi_albert_graph(n,m) | |||||
if generator=='Gnp': | |||||
n = int(parameter[nodes][0]) | |||||
p = parameter[nodes][1] | |||||
print(n,p) | |||||
graph = nx.fast_gnp_random_graph(n,p) | |||||
graphs.append(graph) | |||||
return graphs | |||||
if __name__ == '__main__': | |||||
args = Args() | |||||
print('File name prefix', args.fname) | |||||
### load datasets | |||||
graphs = [] | |||||
# synthetic graphs | |||||
if args.graph_type=='ladder': | |||||
graphs = [] | |||||
for i in range(100, 201): | |||||
graphs.append(nx.ladder_graph(i)) | |||||
args.max_prev_node = 10 | |||||
if args.graph_type=='tree': | |||||
graphs = [] | |||||
for i in range(2,5): | |||||
for j in range(3,5): | |||||
graphs.append(nx.balanced_tree(i,j)) | |||||
args.max_prev_node = 256 | |||||
if args.graph_type=='caveman': | |||||
graphs = [] | |||||
for i in range(5,10): | |||||
for j in range(5,25): | |||||
graphs.append(nx.connected_caveman_graph(i, j)) | |||||
args.max_prev_node = 50 | |||||
if args.graph_type=='grid': | |||||
graphs = [] | |||||
for i in range(10,20): | |||||
for j in range(10,20): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
args.max_prev_node = 40 | |||||
if args.graph_type=='barabasi': | |||||
graphs = [] | |||||
for i in range(100,200): | |||||
graphs.append(nx.barabasi_albert_graph(i,2)) | |||||
args.max_prev_node = 130 | |||||
# real graphs | |||||
if args.graph_type == 'enzymes': | |||||
graphs= Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
args.max_prev_node = 25 | |||||
if args.graph_type == 'protein': | |||||
graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full') | |||||
args.max_prev_node = 80 | |||||
if args.graph_type == 'DD': | |||||
graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',node_attributes=False,graph_labels=True) | |||||
args.max_prev_node = 230 | |||||
graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))] | |||||
graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))] | |||||
args.max_num_node = max(graph_nodes) | |||||
# show graphs statistics | |||||
print('total graph num: {}'.format(len(graphs))) | |||||
print('max number node: {}'.format(args.max_num_node)) | |||||
print('max previous node: {}'.format(args.max_prev_node)) | |||||
# start baseline generation method | |||||
generator = args.generator_baseline | |||||
metric = args.metric_baseline | |||||
print(args.fname_baseline + '.dat') | |||||
if metric=='general': | |||||
parameter = Graph_generator_baseline_train_rulebased(graphs,generator=generator) | |||||
else: | |||||
parameter = Graph_generator_baseline_train_optimizationbased(graphs,generator=generator,metric=metric) | |||||
graphs_generated = Graph_generator_baseline_test(graph_nodes, parameter,generator) | |||||
save_graph_list(graphs_generated,args.fname_baseline + '.dat') |
import networkx as nx | |||||
import numpy as np | |||||
import torch | |||||
class GraphAdjSampler(torch.utils.data.Dataset): | |||||
def __init__(self, G_list, max_num_nodes, features='id'): | |||||
self.max_num_nodes = max_num_nodes | |||||
self.adj_all = [] | |||||
self.len_all = [] | |||||
self.feature_all = [] | |||||
for G in G_list: | |||||
adj = nx.to_numpy_matrix(G) | |||||
# the diagonal entries are 1 since they denote node probability | |||||
self.adj_all.append( | |||||
np.asarray(adj) + np.identity(G.number_of_nodes())) | |||||
self.len_all.append(G.number_of_nodes()) | |||||
if features == 'id': | |||||
self.feature_all.append(np.identity(max_num_nodes)) | |||||
elif features == 'deg': | |||||
degs = np.sum(np.array(adj), 1) | |||||
degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], 'constant'),
axis=1) | |||||
self.feature_all.append(degs) | |||||
elif features == 'struct': | |||||
degs = np.sum(np.array(adj), 1) | |||||
degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], | |||||
'constant'), | |||||
axis=1) | |||||
clusterings = np.array(list(nx.clustering(G).values())) | |||||
clusterings = np.expand_dims(np.pad(clusterings, | |||||
[0, max_num_nodes - G.number_of_nodes()], | |||||
'constant'), | |||||
axis=1) | |||||
self.feature_all.append(np.hstack([degs, clusterings])) | |||||
def __len__(self): | |||||
return len(self.adj_all) | |||||
def __getitem__(self, idx): | |||||
adj = self.adj_all[idx] | |||||
num_nodes = adj.shape[0] | |||||
adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes)) | |||||
adj_padded[:num_nodes, :num_nodes] = adj | |||||
adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2) | |||||
node_idx = 0 | |||||
adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes,self.max_num_nodes)) ) == 1] | |||||
# the following 2 lines recover the upper triangle of the adj matrix | |||||
#recovered = np.zeros((self.max_num_nodes, self.max_num_nodes)) | |||||
#recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes)) ) == 1] = adj_vectorized | |||||
#print(recovered) | |||||
return {'adj':adj_padded, | |||||
'adj_decoded':adj_vectorized, | |||||
'features':self.feature_all[idx].copy()} | |||||
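# Illustrative usage of GraphAdjSampler (assumption for documentation purposes:
# `graphs` is a list of networkx graphs, e.g. from Graph_load_batch or nx generators):
#   dataset = GraphAdjSampler(graphs, max_num_nodes=20, features='id')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1)
#   batch = next(iter(loader))  # dict with keys 'adj', 'adj_decoded', 'features'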
import numpy as np | |||||
import scipy.optimize | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.autograd import Variable | |||||
from torch import optim | |||||
import torch.nn.functional as F | |||||
import torch.nn.init as init | |||||
import model | |||||
class GraphVAE(nn.Module): | |||||
def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, pool='sum'): | |||||
''' | |||||
Args: | |||||
input_dim: input feature dimension for node. | |||||
hidden_dim: hidden dim for 2-layer gcn. | |||||
latent_dim: dimension of the latent representation of graph. | |||||
''' | |||||
super(GraphVAE, self).__init__() | |||||
self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=hidden_dim) | |||||
self.bn1 = nn.BatchNorm1d(input_dim) | |||||
self.conv2 = model.GraphConv(input_dim=hidden_dim, output_dim=hidden_dim) | |||||
self.bn2 = nn.BatchNorm1d(input_dim) | |||||
self.act = nn.ReLU() | |||||
output_dim = max_num_nodes * (max_num_nodes + 1) // 2 | |||||
#self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim) | |||||
self.vae = model.MLP_VAE_plain(input_dim * input_dim, latent_dim, output_dim) | |||||
#self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim) | |||||
self.max_num_nodes = max_num_nodes | |||||
for m in self.modules(): | |||||
if isinstance(m, model.GraphConv): | |||||
m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu')) | |||||
elif isinstance(m, nn.BatchNorm1d): | |||||
m.weight.data.fill_(1) | |||||
m.bias.data.zero_() | |||||
self.pool = pool | |||||
def recover_adj_lower(self, l): | |||||
# NOTE: Assumes 1 per minibatch | |||||
adj = torch.zeros(self.max_num_nodes, self.max_num_nodes) | |||||
adj[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l | |||||
return adj | |||||
def recover_full_adj_from_lower(self, lower): | |||||
diag = torch.diag(torch.diag(lower, 0)) | |||||
return lower + torch.transpose(lower, 0, 1) - diag | |||||
def edge_similarity_matrix(self, adj, adj_recon, matching_features, | |||||
matching_features_recon, sim_func): | |||||
S = torch.zeros(self.max_num_nodes, self.max_num_nodes, | |||||
self.max_num_nodes, self.max_num_nodes) | |||||
for i in range(self.max_num_nodes): | |||||
for j in range(self.max_num_nodes): | |||||
if i == j: | |||||
for a in range(self.max_num_nodes): | |||||
S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \ | |||||
sim_func(matching_features[i], matching_features_recon[a]) | |||||
# with feature not implemented | |||||
# if input_features is not None: | |||||
else: | |||||
for a in range(self.max_num_nodes): | |||||
for b in range(self.max_num_nodes): | |||||
if b == a: | |||||
continue | |||||
S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \ | |||||
adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b] | |||||
return S | |||||
def mpm(self, x_init, S, max_iters=50): | |||||
x = x_init | |||||
for it in range(max_iters): | |||||
x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes) | |||||
for i in range(self.max_num_nodes): | |||||
for a in range(self.max_num_nodes): | |||||
x_new[i, a] = x[i, a] * S[i, i, a, a] | |||||
pooled = [torch.max(x[j, :] * S[i, j, a, :]) | |||||
for j in range(self.max_num_nodes) if j != i] | |||||
neigh_sim = sum(pooled) | |||||
x_new[i, a] += neigh_sim | |||||
norm = torch.norm(x_new) | |||||
x = x_new / norm | |||||
return x | |||||
def deg_feature_similarity(self, f1, f2): | |||||
return 1 / (abs(f1 - f2) + 1) | |||||
def permute_adj(self, adj, curr_ind, target_ind): | |||||
''' Permute adjacency matrix. | |||||
The target_ind (connectivity) should be permuted to the curr_ind position. | |||||
''' | |||||
# order curr_ind according to target ind | |||||
ind = np.zeros(self.max_num_nodes, dtype=np.int) | |||||
ind[target_ind] = curr_ind | |||||
adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes)) | |||||
adj_permuted[:, :] = adj[ind, :] | |||||
adj_permuted[:, :] = adj_permuted[:, ind] | |||||
return adj_permuted | |||||
def pool_graph(self, x): | |||||
if self.pool == 'max': | |||||
out, _ = torch.max(x, dim=1, keepdim=False) | |||||
elif self.pool == 'sum': | |||||
out = torch.sum(x, dim=1, keepdim=False) | |||||
return out | |||||
def forward(self, input_features, adj): | |||||
#x = self.conv1(input_features, adj) | |||||
#x = self.bn1(x) | |||||
#x = self.act(x) | |||||
#x = self.conv2(x, adj) | |||||
#x = self.bn2(x) | |||||
# pool over all nodes | |||||
#graph_h = self.pool_graph(x) | |||||
graph_h = input_features.view(-1, self.max_num_nodes * self.max_num_nodes) | |||||
# vae | |||||
h_decode, z_mu, z_lsgms = self.vae(graph_h) | |||||
out = F.sigmoid(h_decode) | |||||
out_tensor = out.cpu().data | |||||
recon_adj_lower = self.recover_adj_lower(out_tensor) | |||||
recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower) | |||||
# set matching features be degree | |||||
out_features = torch.sum(recon_adj_tensor, 1) | |||||
adj_data = adj.cpu().data[0] | |||||
adj_features = torch.sum(adj_data, 1) | |||||
S = self.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features, | |||||
self.deg_feature_similarity) | |||||
# initialization strategies | |||||
init_corr = 1 / self.max_num_nodes | |||||
init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr | |||||
#init_assignment = torch.FloatTensor(4, 4) | |||||
#init.uniform(init_assignment) | |||||
assignment = self.mpm(init_assignment, S) | |||||
#print('Assignment: ', assignment) | |||||
# matching | |||||
# use negative of the assignment score since the alg finds min cost flow | |||||
row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy()) | |||||
print('row: ', row_ind) | |||||
print('col: ', col_ind) | |||||
# order row index according to col index | |||||
#adj_permuted = self.permute_adj(adj_data, row_ind, col_ind) | |||||
adj_permuted = adj_data | |||||
adj_vectorized = adj_permuted[torch.triu(torch.ones(self.max_num_nodes,self.max_num_nodes) )== 1].squeeze_() | |||||
adj_vectorized_var = Variable(adj_vectorized).cuda() | |||||
#print(adj) | |||||
#print('permuted: ', adj_permuted) | |||||
#print('recon: ', recon_adj_tensor) | |||||
adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out[0]) | |||||
print('recon: ', adj_recon_loss) | |||||
print(adj_vectorized_var) | |||||
print(out[0]) | |||||
loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp()) | |||||
loss_kl /= self.max_num_nodes * self.max_num_nodes # normalize | |||||
print('kl: ', loss_kl) | |||||
loss = adj_recon_loss + loss_kl | |||||
return loss | |||||
def forward_test(self, input_features, adj): | |||||
self.max_num_nodes = 4 | |||||
adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes) | |||||
adj_data[:4, :4] = torch.FloatTensor([[1,1,0,0], [1,1,1,0], [0,1,1,1], [0,0,1,1]]) | |||||
adj_features = torch.Tensor([2,3,3,2]) | |||||
adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes) | |||||
adj_data1 = torch.FloatTensor([[1,1,1,0], [1,1,0,1], [1,0,1,0], [0,1,0,1]]) | |||||
adj_features1 = torch.Tensor([3,3,2,2]) | |||||
S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1, | |||||
self.deg_feature_similarity) | |||||
# initialization strategies | |||||
init_corr = 1 / self.max_num_nodes | |||||
init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr | |||||
#init_assignment = torch.FloatTensor(4, 4) | |||||
#init.uniform(init_assignment) | |||||
assignment = self.mpm(init_assignment, S) | |||||
#print('Assignment: ', assignment) | |||||
# matching | |||||
row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy()) | |||||
print('row: ', row_ind) | |||||
print('col: ', col_ind) | |||||
permuted_adj = self.permute_adj(adj_data, row_ind, col_ind) | |||||
print('permuted: ', permuted_adj) | |||||
adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1) | |||||
print(adj_data1) | |||||
print('diff: ', adj_recon_loss) | |||||
def adj_recon_loss(self, adj_truth, adj_pred): | |||||
return F.binary_cross_entropy(adj_truth, adj_pred) | |||||
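# Illustrative usage (mirrors build_model/train in baselines/graphvae/train.py; shapes
# assume batch_size=1 and 'id' node features):
#   model = GraphVAE(input_dim=max_num_nodes, hidden_dim=64, latent_dim=256,
#                    max_num_nodes=max_num_nodes).cuda()
#   loss = model(features, adj_input)  # features/adj_input come from GraphAdjSampler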
import argparse | |||||
import matplotlib.pyplot as plt | |||||
import networkx as nx | |||||
import numpy as np | |||||
import os | |||||
from random import shuffle | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.init as init | |||||
from torch.autograd import Variable | |||||
import torch.nn.functional as F | |||||
from torch import optim | |||||
from torch.optim.lr_scheduler import MultiStepLR | |||||
import data | |||||
from baselines.graphvae.model import GraphVAE | |||||
from baselines.graphvae.data import GraphAdjSampler | |||||
CUDA = 2 | |||||
LR_milestones = [500, 1000] | |||||
def build_model(args, max_num_nodes): | |||||
out_dim = max_num_nodes * (max_num_nodes + 1) // 2 | |||||
if args.feature_type == 'id': | |||||
input_dim = max_num_nodes | |||||
elif args.feature_type == 'deg': | |||||
input_dim = 1 | |||||
elif args.feature_type == 'struct': | |||||
input_dim = 2 | |||||
model = GraphVAE(input_dim, 64, 256, max_num_nodes) | |||||
return model | |||||
def train(args, dataloader, model): | |||||
epoch = 1 | |||||
optimizer = optim.Adam(list(model.parameters()), lr=args.lr) | |||||
scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=args.lr) | |||||
model.train() | |||||
for epoch in range(5000): | |||||
for batch_idx, data in enumerate(dataloader): | |||||
model.zero_grad() | |||||
features = data['features'].float() | |||||
adj_input = data['adj'].float() | |||||
features = Variable(features).cuda() | |||||
adj_input = Variable(adj_input).cuda() | |||||
loss = model(features, adj_input) | |||||
print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss) | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step() | |||||
break | |||||
def arg_parse(): | |||||
parser = argparse.ArgumentParser(description='GraphVAE arguments.') | |||||
io_parser = parser.add_mutually_exclusive_group(required=False) | |||||
io_parser.add_argument('--dataset', dest='dataset', | |||||
help='Input dataset.') | |||||
parser.add_argument('--lr', dest='lr', type=float, | |||||
help='Learning rate.') | |||||
parser.add_argument('--batch_size', dest='batch_size', type=int, | |||||
help='Batch size.') | |||||
parser.add_argument('--num_workers', dest='num_workers', type=int, | |||||
help='Number of workers to load data.') | |||||
parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int, | |||||
help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \ | |||||
training data.') | |||||
parser.add_argument('--feature', dest='feature_type', | |||||
help='Feature used for encoder. Can be: id, deg') | |||||
parser.set_defaults(dataset='grid', | |||||
feature_type='id', | |||||
lr=0.001, | |||||
batch_size=1, | |||||
num_workers=1, | |||||
max_num_nodes=-1) | |||||
return parser.parse_args() | |||||
def main(): | |||||
prog_args = arg_parse() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA) | |||||
print('CUDA', CUDA) | |||||
### running log | |||||
if prog_args.dataset == 'enzymes': | |||||
graphs= data.Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
num_graphs_raw = len(graphs) | |||||
elif prog_args.dataset == 'grid': | |||||
graphs = [] | |||||
for i in range(2,3): | |||||
for j in range(2,3): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
num_graphs_raw = len(graphs) | |||||
if prog_args.max_num_nodes == -1: | |||||
max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) | |||||
else: | |||||
max_num_nodes = prog_args.max_num_nodes | |||||
# remove graphs with number of nodes greater than max_num_nodes | |||||
graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes] | |||||
graphs_len = len(graphs) | |||||
print('Number of graphs removed due to upper-limit of number of nodes: ', | |||||
num_graphs_raw - graphs_len) | |||||
graphs_test = graphs[int(0.8 * graphs_len):] | |||||
#graphs_train = graphs[0:int(0.8*graphs_len)] | |||||
graphs_train = graphs | |||||
print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train))) | |||||
print('max number node: {}'.format(max_num_nodes)) | |||||
dataset = GraphAdjSampler(graphs_train, max_num_nodes, features=prog_args.feature_type) | |||||
#sample_strategy = torch.utils.data.sampler.WeightedRandomSampler( | |||||
# [1.0 / len(dataset) for i in range(len(dataset))], | |||||
# num_samples=prog_args.batch_size, | |||||
# replacement=False) | |||||
dataset_loader = torch.utils.data.DataLoader( | |||||
dataset, | |||||
batch_size=prog_args.batch_size, | |||||
num_workers=prog_args.num_workers) | |||||
model = build_model(prog_args, max_num_nodes).cuda() | |||||
train(prog_args, dataset_loader, model) | |||||
if __name__ == '__main__': | |||||
main() |
"""Stochastic block model.""" | |||||
import argparse | |||||
import os | |||||
from time import time | |||||
import edward as ed | |||||
import networkx as nx | |||||
import numpy as np | |||||
import tensorflow as tf | |||||
from edward.models import Bernoulli, Multinomial, Beta, Dirichlet, PointMass, Normal | |||||
from observations import karate | |||||
from sklearn.metrics.cluster import adjusted_rand_score | |||||
import utils | |||||
CUDA = 2 | |||||
ed.set_seed(int(time())) | |||||
#ed.set_seed(42) | |||||
# DATA | |||||
#X_data, Z_true = karate("data") | |||||
def disjoint_cliques_test_graph(num_cliques, clique_size): | |||||
G = nx.disjoint_union_all([nx.complete_graph(clique_size) for _ in range(num_cliques)]) | |||||
return nx.to_numpy_matrix(G) | |||||
def mmsb(N, K, data): | |||||
# sparsity | |||||
rho = 0.3 | |||||
# MODEL | |||||
# probability of belonging to each of K blocks for each node | |||||
gamma = Dirichlet(concentration=tf.ones([K])) | |||||
# block connectivity | |||||
Pi = Beta(concentration0=tf.ones([K, K]), concentration1=tf.ones([K, K])) | |||||
# probability of belonging to each of K blocks for all nodes | |||||
Z = Multinomial(total_count=1.0, probs=gamma, sample_shape=N) | |||||
# adjacency | |||||
X = Bernoulli(probs=(1 - rho) * tf.matmul(Z, tf.matmul(Pi, tf.transpose(Z)))) | |||||
# INFERENCE (EM algorithm) | |||||
qgamma = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([K])))) | |||||
qPi = PointMass(params=tf.nn.sigmoid(tf.Variable(tf.random_normal([K, K])))) | |||||
qZ = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([N, K])))) | |||||
#qgamma = Normal(loc=tf.get_variable("qgamma/loc", [K]), | |||||
# scale=tf.nn.softplus( | |||||
# tf.get_variable("qgamma/scale", [K]))) | |||||
#qPi = Normal(loc=tf.get_variable("qPi/loc", [K, K]), | |||||
# scale=tf.nn.softplus( | |||||
# tf.get_variable("qPi/scale", [K, K]))) | |||||
#qZ = Normal(loc=tf.get_variable("qZ/loc", [N, K]), | |||||
# scale=tf.nn.softplus( | |||||
# tf.get_variable("qZ/scale", [N, K]))) | |||||
#inference = ed.KLqp({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data}) | |||||
inference = ed.MAP({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data}) | |||||
#inference.run() | |||||
n_iter = 6000 | |||||
inference.initialize(optimizer=tf.train.AdamOptimizer(learning_rate=0.01), n_iter=n_iter) | |||||
tf.global_variables_initializer().run() | |||||
for _ in range(inference.n_iter): | |||||
info_dict = inference.update() | |||||
inference.print_progress(info_dict) | |||||
inference.finalize() | |||||
print('qgamma after: ', qgamma.mean().eval()) | |||||
return qZ.mean().eval(), qPi.eval() | |||||
def arg_parse(): | |||||
parser = argparse.ArgumentParser(description='MMSB arguments.') | |||||
parser.add_argument('--dataset', dest='dataset', | |||||
help='Input dataset.') | |||||
parser.add_argument('--K', dest='K', type=int, | |||||
help='Number of blocks.') | |||||
parser.add_argument('--samples-per-G', dest='samples', type=int, | |||||
help='Number of samples for every graph.') | |||||
parser.set_defaults(dataset='community', | |||||
K=4, | |||||
samples=1) | |||||
return parser.parse_args() | |||||
def graph_gen_from_blockmodel(B, Z): | |||||
n_blocks = len(B) | |||||
B = np.array(B) | |||||
Z = np.array(Z) | |||||
adj_prob = np.dot(Z, np.dot(B, np.transpose(Z))) | |||||
adj = np.random.binomial(1, adj_prob * 0.3) | |||||
return nx.from_numpy_matrix(adj) | |||||
if __name__ == '__main__': | |||||
prog_args = arg_parse() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA) | |||||
print('CUDA', CUDA) | |||||
X_dataset = [] | |||||
#X_data = nx.to_numpy_matrix(nx.connected_caveman_graph(4, 7)) | |||||
if prog_args.dataset == 'clique_test': | |||||
X_data = disjoint_cliques_test_graph(4, 7) | |||||
X_dataset.append(X_data) | |||||
elif prog_args.dataset == 'citeseer': | |||||
graphs = utils.citeseer_ego() | |||||
X_dataset = [nx.to_numpy_matrix(g) for g in graphs] | |||||
elif prog_args.dataset == 'community': | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(30, 81): | |||||
for k in range(10): | |||||
graphs.append(utils.caveman_special(i,j, p_edge=0.3)) | |||||
X_dataset = [nx.to_numpy_matrix(g) for g in graphs] | |||||
elif prog_args.dataset == 'grid': | |||||
graphs = [] | |||||
for i in range(10,20): | |||||
for j in range(10,20): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
X_dataset = [nx.to_numpy_matrix(g) for g in graphs] | |||||
elif prog_args.dataset.startswith('community'): | |||||
graphs = [] | |||||
num_communities = int(prog_args.dataset[-1]) | |||||
print('Creating dataset with ', num_communities, ' communities') | |||||
c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities) | |||||
for k in range(3000): | |||||
graphs.append(utils.n_community(c_sizes, p_inter=0.01)) | |||||
X_dataset = [nx.to_numpy_matrix(g) for g in graphs] | |||||
print('Number of graphs: ', len(X_dataset)) | |||||
K = prog_args.K # number of clusters | |||||
gen_graphs = [] | |||||
for i in range(len(X_dataset)): | |||||
if i % 5 == 0: | |||||
print(i) | |||||
X_data = X_dataset[i] | |||||
N = X_data.shape[0] # number of vertices | |||||
Zp, B = mmsb(N, K, X_data) | |||||
#print("Block: ", B) | |||||
Z_pred = Zp.argmax(axis=1) | |||||
print("Result (label flip can happen):") | |||||
#print("prob: ", Zp) | |||||
print("Predicted") | |||||
print(Z_pred) | |||||
#print(Z_true) | |||||
#print("Adjusted Rand Index =", adjusted_rand_score(Z_pred, Z_true)) | |||||
for j in range(prog_args.samples): | |||||
gen_graphs.append(graph_gen_from_blockmodel(B, Zp)) | |||||
save_path = '/lfs/local/0/rexy/graph-generation/eval_results/mmsb/' | |||||
utils.save_graph_list(gen_graphs, os.path.join(save_path, prog_args.dataset + '.dat')) | |||||
import networkx as nx | |||||
import numpy as np | |||||
from utils import * | |||||
from data import * | |||||
def create(args): | |||||
### load datasets | |||||
graphs=[] | |||||
# synthetic graphs | |||||
if args.graph_type=='ladder': | |||||
graphs = [] | |||||
for i in range(100, 201): | |||||
graphs.append(nx.ladder_graph(i)) | |||||
args.max_prev_node = 10 | |||||
elif args.graph_type=='ladder_small': | |||||
graphs = [] | |||||
for i in range(2, 11): | |||||
graphs.append(nx.ladder_graph(i)) | |||||
args.max_prev_node = 10 | |||||
elif args.graph_type=='tree': | |||||
graphs = [] | |||||
for i in range(2,5): | |||||
for j in range(3,5): | |||||
graphs.append(nx.balanced_tree(i,j)) | |||||
args.max_prev_node = 256 | |||||
elif args.graph_type=='caveman': | |||||
# graphs = [] | |||||
# for i in range(5,10): | |||||
# for j in range(5,25): | |||||
# for k in range(5): | |||||
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1)) | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(30, 81): | |||||
for k in range(10): | |||||
graphs.append(caveman_special(i,j, p_edge=0.3)) | |||||
args.max_prev_node = 100 | |||||
elif args.graph_type=='caveman_small': | |||||
# graphs = [] | |||||
# for i in range(2,5): | |||||
# for j in range(2,6): | |||||
# for k in range(10): | |||||
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1)) | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(6, 11): | |||||
for k in range(20): | |||||
graphs.append(caveman_special(i, j, p_edge=0.8)) # default 0.8 | |||||
args.max_prev_node = 20 | |||||
elif args.graph_type=='caveman_small_single': | |||||
# graphs = [] | |||||
# for i in range(2,5): | |||||
# for j in range(2,6): | |||||
# for k in range(10): | |||||
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1)) | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(8, 9): | |||||
for k in range(100): | |||||
graphs.append(caveman_special(i, j, p_edge=0.5)) | |||||
args.max_prev_node = 20 | |||||
elif args.graph_type.startswith('community'): | |||||
num_communities = int(args.graph_type[-1]) | |||||
print('Creating dataset with ', num_communities, ' communities') | |||||
c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities) | |||||
#c_sizes = [15] * num_communities | |||||
for k in range(3000): | |||||
graphs.append(n_community(c_sizes, p_inter=0.01)) | |||||
args.max_prev_node = 80 | |||||
elif args.graph_type=='grid': | |||||
graphs = [] | |||||
for i in range(10,20): | |||||
for j in range(10,20): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
args.max_prev_node = 40 | |||||
elif args.graph_type=='grid_small': | |||||
graphs = [] | |||||
for i in range(2,5): | |||||
for j in range(2,6): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
args.max_prev_node = 15 | |||||
elif args.graph_type=='barabasi': | |||||
graphs = [] | |||||
for i in range(100,200): | |||||
for j in range(4,5): | |||||
for k in range(5): | |||||
graphs.append(nx.barabasi_albert_graph(i,j)) | |||||
args.max_prev_node = 130 | |||||
elif args.graph_type=='barabasi_small': | |||||
graphs = [] | |||||
for i in range(4,21): | |||||
for j in range(3,4): | |||||
for k in range(10): | |||||
graphs.append(nx.barabasi_albert_graph(i,j)) | |||||
args.max_prev_node = 20 | |||||
elif args.graph_type=='grid_big': | |||||
graphs = [] | |||||
for i in range(36, 46): | |||||
for j in range(36, 46): | |||||
graphs.append(nx.grid_2d_graph(i, j)) | |||||
args.max_prev_node = 90 | |||||
elif 'barabasi_noise' in args.graph_type: | |||||
graphs = [] | |||||
for i in range(100,101): | |||||
for j in range(4,5): | |||||
for k in range(500): | |||||
graphs.append(nx.barabasi_albert_graph(i,j)) | |||||
graphs = perturb_new(graphs,p=args.noise/10.0) | |||||
args.max_prev_node = 99 | |||||
# real graphs | |||||
elif args.graph_type == 'enzymes': | |||||
graphs= Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
args.max_prev_node = 25 | |||||
elif args.graph_type == 'enzymes_small': | |||||
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
graphs = [] | |||||
for G in graphs_raw: | |||||
if G.number_of_nodes()<=20: | |||||
graphs.append(G) | |||||
args.max_prev_node = 15 | |||||
elif args.graph_type == 'protein': | |||||
graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full') | |||||
args.max_prev_node = 80 | |||||
elif args.graph_type == 'DD': | |||||
graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',node_attributes=False,graph_labels=True) | |||||
args.max_prev_node = 230 | |||||
elif args.graph_type == 'citeseer': | |||||
_, _, G = Graph_load(dataset='citeseer') | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
G = nx.convert_node_labels_to_integers(G) | |||||
graphs = [] | |||||
for i in range(G.number_of_nodes()): | |||||
G_ego = nx.ego_graph(G, i, radius=3) | |||||
if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400): | |||||
graphs.append(G_ego) | |||||
args.max_prev_node = 250 | |||||
elif args.graph_type == 'citeseer_small': | |||||
_, _, G = Graph_load(dataset='citeseer') | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
G = nx.convert_node_labels_to_integers(G) | |||||
graphs = [] | |||||
for i in range(G.number_of_nodes()): | |||||
G_ego = nx.ego_graph(G, i, radius=1) | |||||
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||||
graphs.append(G_ego) | |||||
shuffle(graphs) | |||||
graphs = graphs[0:200] | |||||
args.max_prev_node = 15 | |||||
return graphs | |||||
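A minimal usage sketch (not part of the repository): calling `create` with a stand-in for the `Args` object from `args.py`. Only the fields read or written above (`graph_type`, `noise`, `max_prev_node`) are assumed to exist.
```python
# Minimal sketch, assuming create() only reads args.graph_type / args.noise
# and writes args.max_prev_node, as in the function above.
from types import SimpleNamespace

args = SimpleNamespace(graph_type='grid', noise=0, max_prev_node=None)
graphs = create(args)   # 'grid': 10x10 up to 19x19 grids -> 100 graphs
print(len(graphs), 'graphs; max_prev_node =', args.max_prev_node)  # 100 graphs; 40
```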
README for dataset DD | |||||
=== Usage === | |||||
This folder contains the following comma separated text files | |||||
(replace DS by the name of the dataset): | |||||
n = total number of nodes | |||||
m = total number of edges | |||||
N = number of graphs | |||||
(1) DS_A.txt (m lines) | |||||
sparse (block diagonal) adjacency matrix for all graphs, | |||||
each line corresponds to (row, col) resp. (node_id, node_id) | |||||
(2) DS_graph_indicator.txt (n lines) | |||||
column vector of graph identifiers for all nodes of all graphs, | |||||
the value in the i-th line is the graph_id of the node with node_id i | |||||
(3) DS_graph_labels.txt (N lines) | |||||
class labels for all graphs in the dataset, | |||||
the value in the i-th line is the class label of the graph with graph_id i | |||||
(4) DS_node_labels.txt (n lines) | |||||
column vector of node labels, | |||||
the value in the i-th line corresponds to the node with node_id i | |||||
There are OPTIONAL files if the respective information is available: | |||||
(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)
labels for the edges in DS_A.txt
(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)
attributes for the edges in DS_A.txt
(7) DS_node_attributes.txt (n lines)
matrix of node attributes,
the comma separated values in the i-th line form the attribute vector of the node with node_id i
(8) DS_graph_attributes.txt (N lines) | |||||
regression values for all graphs in the dataset, | |||||
the value in the i-th line is the attribute of the graph with graph_id i | |||||
=== Description === | |||||
D&D is a dataset of 1178 protein structures (Dobson and Doig, 2003). Each protein is | |||||
represented by a graph, in which the nodes are amino acids and two nodes are connected | |||||
by an edge if they are less than 6 Angstroms apart. The prediction task is to classify | |||||
the protein structures into enzymes and non-enzymes. | |||||
=== Previous Use of the Dataset === | |||||
Neumann, M., Garnett R., Bauckhage Ch., Kersting K.: Propagation Kernels: Efficient Graph | |||||
Kernels from Propagated Information. Under review at MLJ. | |||||
Neumann, M., Patricia, N., Garnett, R., Kersting, K.: Efficient Graph Kernels by | |||||
Randomization. In: P.A. Flach, T.D. Bie, N. Cristianini (eds.) ECML/PKDD, Notes in | |||||
Computer Science, vol. 7523, pp. 378-393. Springer (2012). | |||||
Shervashidze, N., Schweitzer, P., van Leeuwen, E., Mehlhorn, K., Borgwardt, K.: | |||||
Weisfeiler-Lehman Graph Kernels. Journal of Machine Learning Research 12, 2539-2561 (2011) | |||||
=== References === | |||||
P. D. Dobson and A. J. Doig. Distinguishing enzyme structures from non-enzymes without | |||||
alignments. J. Mol. Biol., 330(4):771–783, Jul 2003. | |||||
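As an illustration of the DS_* file format described above, here is a minimal, self-contained loading sketch. It is independent of the repository's own `Graph_load_batch` loader in `data.py`, and it assumes the dataset files sit in the working directory.
```python
# Minimal sketch: build one networkx graph per graph_id from the DS_* files.
import numpy as np
import networkx as nx

DS = 'DD'
edges = np.loadtxt(DS + '_A.txt', delimiter=',').astype(int)      # m rows of (node_id, node_id)
graph_ids = np.loadtxt(DS + '_graph_indicator.txt').astype(int)   # graph_id of node i on line i
graph_labels = np.loadtxt(DS + '_graph_labels.txt').astype(int)   # class label per graph

G_all = nx.Graph()
G_all.add_edges_from(map(tuple, edges))
graphs = []
for gid in range(1, graph_ids.max() + 1):
    nodes = np.where(graph_ids == gid)[0] + 1       # node_ids are 1-based
    G = G_all.subgraph(nodes).copy()
    G.graph['label'] = graph_labels[gid - 1]
    graphs.append(G)
print(len(graphs), 'graphs loaded')
```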
[ENZYMES_graph_labels.txt — 600 lines, one graph class label per line: 100 graphs each with labels 6, 5, 1, 2, 3, 4.]
README for dataset ENZYMES | |||||
=== Usage === | |||||
This folder contains the following comma separated text files | |||||
(replace DS by the name of the dataset): | |||||
n = total number of nodes | |||||
m = total number of edges | |||||
N = number of graphs | |||||
(1) DS_A.txt (m lines) | |||||
sparse (block diagonal) adjacency matrix for all graphs, | |||||
each line corresponds to (row, col) resp. (node_id, node_id) | |||||
(2) DS_graph_indicator.txt (n lines) | |||||
column vector of graph identifiers for all nodes of all graphs, | |||||
the value in the i-th line is the graph_id of the node with node_id i | |||||
(3) DS_graph_labels.txt (N lines) | |||||
class labels for all graphs in the dataset, | |||||
the value in the i-th line is the class label of the graph with graph_id i | |||||
(4) DS_node_labels.txt (n lines) | |||||
column vector of node labels, | |||||
the value in the i-th line corresponds to the node with node_id i | |||||
There are OPTIONAL files if the respective information is available: | |||||
(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)
labels for the edges in DS_A.txt
(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)
attributes for the edges in DS_A.txt
(7) DS_node_attributes.txt (n lines)
matrix of node attributes,
the comma separated values in the i-th line form the attribute vector of the node with node_id i
(8) DS_graph_attributes.txt (N lines) | |||||
regression values for all graphs in the dataset, | |||||
the value in the i-th line is the attribute of the graph with graph_id i | |||||
=== Description === | |||||
ENZYMES is a dataset of protein tertiary structures obtained from (Borgwardt et al., 2005) | |||||
consisting of 600 enzymes from the BRENDA enzyme database (Schomburg et al., 2004). | |||||
In this case the task is to correctly assign each enzyme to one of the 6 EC top-level | |||||
classes. | |||||
=== Previous Use of the Dataset === | |||||
Feragen, A., Kasenburg, N., Petersen, J., de Bruijne, M., Borgwardt, K.M.: Scalable | |||||
kernels for graphs with continuous attributes. In: C.J.C. Burges, L. Bottou, Z. Ghahra- | |||||
mani, K.Q. Weinberger (eds.) NIPS, pp. 216-224 (2013) | |||||
Neumann, M., Garnett R., Bauckhage Ch., Kersting K.: Propagation Kernels: Efficient Graph | |||||
Kernels from Propagated Information. Under review at MLJ. | |||||
=== References === | |||||
K. M. Borgwardt, C. S. Ong, S. Schoenauer, S. V. N. Vishwanathan, A. J. Smola, and H. P. | |||||
Kriegel. Protein function prediction via graph kernels. Bioinformatics, 21(Suppl 1):i47–i56, | |||||
Jun 2005. | |||||
I. Schomburg, A. Chang, C. Ebeling, M. Gremse, C. Heldt, G. Huhn, and D. Schomburg. Brenda, | |||||
the enzyme database: updates and major new developments. Nucleic Acids Research, 32D:431–433, 2004. |
# Exploratory script: load the raw ENZYMES files into one big networkx graph,
# then split it into the 600 per-graph subgraphs using the graph indicator.
import numpy as np
import networkx as nx
G = nx.Graph()
# load data
data_adj = np.loadtxt('ENZYMES_A.txt', delimiter=',').astype(int) | |||||
data_node_att = np.loadtxt('ENZYMES_node_attributes.txt', delimiter=',') | |||||
data_node_label = np.loadtxt('ENZYMES_node_labels.txt', delimiter=',').astype(int) | |||||
data_graph_indicator = np.loadtxt('ENZYMES_graph_indicator.txt', delimiter=',').astype(int) | |||||
data_graph_labels = np.loadtxt('ENZYMES_graph_labels.txt', delimiter=',').astype(int) | |||||
data_tuple = list(map(tuple, data_adj)) | |||||
print(len(data_tuple)) | |||||
print(data_tuple[0]) | |||||
# add edges | |||||
G.add_edges_from(data_tuple) | |||||
# add node attributes | |||||
for i in range(data_node_att.shape[0]): | |||||
G.add_node(i+1, feature = data_node_att[i]) | |||||
G.add_node(i+1, label = data_node_label[i]) | |||||
G.remove_nodes_from(nx.isolates(G)) | |||||
print(G.number_of_nodes()) | |||||
print(G.number_of_edges()) | |||||
# split into graphs | |||||
graph_num = 600 | |||||
node_list = np.arange(data_graph_indicator.shape[0])+1 | |||||
graphs = [] | |||||
node_num_list = [] | |||||
for i in range(graph_num): | |||||
# find the nodes for each graph | |||||
nodes = node_list[data_graph_indicator==i+1] | |||||
G_sub = G.subgraph(nodes) | |||||
graphs.append(G_sub) | |||||
G_sub.graph['label'] = data_graph_labels[i] | |||||
# print('nodes', G_sub.number_of_nodes()) | |||||
# print('edges', G_sub.number_of_edges()) | |||||
# print('label', G_sub.graph) | |||||
node_num_list.append(G_sub.number_of_nodes()) | |||||
print('average', sum(node_num_list)/len(node_num_list)) | |||||
print('all', len(node_num_list)) | |||||
node_num_list = np.array(node_num_list) | |||||
print('selected', len(node_num_list[node_num_list>10])) | |||||
# print(graphs[0].nodes(data=True)[0][1]['feature']) | |||||
# print(graphs[0].nodes()) | |||||
keys = tuple(graphs[0].nodes()) | |||||
# print(nx.get_node_attributes(graphs[0], 'feature')) | |||||
dictionary = nx.get_node_attributes(graphs[0], 'feature') | |||||
# print('keys', keys) | |||||
# print('keys from dict', list(dictionary.keys())) | |||||
# print('valuse from dict', list(dictionary.values())) | |||||
features = np.zeros((len(dictionary), list(dictionary.values())[0].shape[0])) | |||||
for i in range(len(dictionary)): | |||||
features[i,:] = list(dictionary.values())[i] | |||||
# print(features) | |||||
# print(features.shape) |
README for dataset PROTEINS_full | |||||
=== Usage === | |||||
This folder contains the following comma separated text files | |||||
(replace DS by the name of the dataset): | |||||
n = total number of nodes | |||||
m = total number of edges | |||||
N = number of graphs | |||||
(1) DS_A.txt (m lines) | |||||
sparse (block diagonal) adjacency matrix for all graphs, | |||||
each line corresponds to (row, col) resp. (node_id, node_id) | |||||
(2) DS_graph_indicator.txt (n lines) | |||||
column vector of graph identifiers for all nodes of all graphs, | |||||
the value in the i-th line is the graph_id of the node with node_id i | |||||
(3) DS_graph_labels.txt (N lines) | |||||
class labels for all graphs in the dataset, | |||||
the value in the i-th line is the class label of the graph with graph_id i | |||||
(4) DS_node_labels.txt (n lines) | |||||
column vector of node labels, | |||||
the value in the i-th line corresponds to the node with node_id i | |||||
There are OPTIONAL files if the respective information is available: | |||||
(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)
labels for the edges in DS_A.txt
(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)
attributes for the edges in DS_A.txt
(7) DS_node_attributes.txt (n lines)
matrix of node attributes,
the comma separated values in the i-th line form the attribute vector of the node with node_id i
(8) DS_graph_attributes.txt (N lines) | |||||
regression values for all graphs in the dataset, | |||||
the value in the i-th line is the attribute of the graph with graph_id i | |||||
=== Previous Use of the Dataset === | |||||
Neumann, M., Garnett R., Bauckhage Ch., Kersting K.: Propagation Kernels: Efficient Graph | |||||
Kernels from Propagated Information. Under review at MLJ. | |||||
=== References === | |||||
K. M. Borgwardt, C. S. Ong, S. Schoenauer, S. V. N. Vishwanathan, A. J. Smola, and H. P. | |||||
Kriegel. Protein function prediction via graph kernels. Bioinformatics, 21(Suppl 1):i47–i56, | |||||
Jun 2005. | |||||
P. D. Dobson and A. J. Doig. Distinguishing enzyme structures from non-enzymes without | |||||
alignments. J. Mol. Biol., 330(4):771–783, Jul 2003. | |||||
name: root | |||||
channels: | |||||
- soumith | |||||
- conda-forge | |||||
- defaults | |||||
dependencies: | |||||
- conda=4.3.29=py36_0 | |||||
- conda-env=2.6.0=0 | |||||
- gensim=3.0.0=py36_0 | |||||
- smart_open=1.5.3=py36_0 | |||||
- _ipyw_jlab_nb_ext_conf=0.1.0=py36he11e457_0 | |||||
- alabaster=0.7.10=py36h306e16b_0 | |||||
- anaconda=5.0.0=py36h06de3c5_0 | |||||
- anaconda-client=1.6.5=py36h19c0dcd_0 | |||||
- anaconda-navigator=1.6.8=py36h672ccc7_0 | |||||
- anaconda-project=0.8.0=py36h29abdf5_0 | |||||
- asn1crypto=0.22.0=py36h265ca7c_1 | |||||
- astroid=1.5.3=py36hbdb9df2_0 | |||||
- astropy=2.0.2=py36ha51211e_4 | |||||
- babel=2.5.0=py36h7d14adf_0 | |||||
- backports=1.0=py36hfa02d7e_1 | |||||
- backports.shutil_get_terminal_size=1.0.0=py36hfea85ff_2 | |||||
- beautifulsoup4=4.6.0=py36h49b8c8c_1 | |||||
- bitarray=0.8.1=py36h5834eb8_0 | |||||
- bkcharts=0.2=py36h735825a_0 | |||||
- blaze=0.11.3=py36h4e06776_0 | |||||
- bleach=2.0.0=py36h688b259_0 | |||||
- bokeh=0.12.7=py36h169c5fd_1 | |||||
- boto=2.48.0=py36h6e4cd66_1 | |||||
- bottleneck=1.2.1=py36haac1ea0_0 | |||||
- bz2file=0.98=py36_0 | |||||
- ca-certificates=2017.08.26=h1d4fec5_0 | |||||
- cairo=1.14.10=h58b644b_4 | |||||
- certifi=2017.7.27.1=py36h8b7b77e_0 | |||||
- cffi=1.10.0=py36had8d393_1 | |||||
- chardet=3.0.4=py36h0f667ec_1 | |||||
- click=6.7=py36h5253387_0 | |||||
- cloudpickle=0.4.0=py36h30f8c20_0 | |||||
- clyent=1.2.2=py36h7e57e65_1 | |||||
- colorama=0.3.9=py36h489cec4_0 | |||||
- conda-build=3.0.22=py36ha23cd1e_0 | |||||
- conda-verify=2.0.0=py36h98955d8_0 | |||||
- contextlib2=0.5.5=py36h6c84a62_0 | |||||
- cryptography=2.0.3=py36ha225213_1 | |||||
- curl=7.55.1=hcb0b314_2 | |||||
- cycler=0.10.0=py36h93f1223_0 | |||||
- cython=0.26.1=py36h21c49d0_0 | |||||
- cytoolz=0.8.2=py36h708bfd4_0 | |||||
- dask=0.15.2=py36h9b48dc4_0 | |||||
- dask-core=0.15.2=py36h0f988a8_0 | |||||
- datashape=0.5.4=py36h3ad6b5c_0 | |||||
- dbus=1.10.22=h3b5a359_0 | |||||
- decorator=4.1.2=py36hd076ac8_0 | |||||
- distributed=1.18.3=py36h73cd4ae_0 | |||||
- docutils=0.14=py36hb0f60f5_0 | |||||
- entrypoints=0.2.3=py36h1aec115_2 | |||||
- et_xmlfile=1.0.1=py36hd6bccc3_0 | |||||
- expat=2.2.4=hc00ebd1_1 | |||||
- fastcache=1.0.2=py36h5b0c431_0 | |||||
- filelock=2.0.12=py36hacfa1f5_0 | |||||
- flask=0.12.2=py36hb24657c_0 | |||||
- flask-cors=3.0.3=py36h2d857d3_0 | |||||
- fontconfig=2.12.4=h88586e7_1 | |||||
- freetype=2.8=h52ed37b_0 | |||||
- get_terminal_size=1.0.0=haa9412d_0 | |||||
- gevent=1.2.2=py36h2fe25dc_0 | |||||
- glib=2.53.6=hc861d11_1 | |||||
- glob2=0.5=py36h2c1b292_1 | |||||
- gmp=6.1.2=hb3b607b_0 | |||||
- gmpy2=2.0.8=py36h55090d7_1 | |||||
- graphite2=1.3.10=hc526e54_0 | |||||
- greenlet=0.4.12=py36h2d503a6_0 | |||||
- gst-plugins-base=1.12.2=he3457e5_0 | |||||
- gstreamer=1.12.2=h4f93127_0 | |||||
- h5py=2.7.0=py36he81ebca_1 | |||||
- harfbuzz=1.5.0=h2545bd6_0 | |||||
- hdf5=1.10.1=hb0523eb_0 | |||||
- heapdict=1.0.0=py36h79797d7_0 | |||||
- html5lib=0.999999999=py36h2cfc398_0 | |||||
- icu=58.2=h211956c_0 | |||||
- idna=2.6=py36h82fb2a8_1 | |||||
- imageio=2.2.0=py36he555465_0 | |||||
- imagesize=0.7.1=py36h52d8127_0 | |||||
- intel-openmp=2018.0.0=h15fc484_7 | |||||
- ipykernel=4.6.1=py36hbf841aa_0 | |||||
- ipython=6.1.0=py36hc72a948_1 | |||||
- ipython_genutils=0.2.0=py36hb52b0d5_0 | |||||
- ipywidgets=7.0.0=py36h7b55c3a_0 | |||||
- isort=4.2.15=py36had401c0_0 | |||||
- itsdangerous=0.24=py36h93cc618_1 | |||||
- jbig=2.1=hdba287a_0 | |||||
- jdcal=1.3=py36h4c697fb_0 | |||||
- jedi=0.10.2=py36h552def0_0 | |||||
- jinja2=2.9.6=py36h489bce4_1 | |||||
- jpeg=9b=habf39ab_1 | |||||
- jsonschema=2.6.0=py36h006f8b5_0 | |||||
- jupyter=1.0.0=py36h9896ce5_0 | |||||
- jupyter_client=5.1.0=py36h614e9ea_0 | |||||
- jupyter_console=5.2.0=py36he59e554_1 | |||||
- jupyter_core=4.3.0=py36h357a921_0 | |||||
- jupyterlab=0.27.0=py36h86377d0_2 | |||||
- jupyterlab_launcher=0.4.0=py36h4d8058d_0 | |||||
- lazy-object-proxy=1.3.1=py36h10fcdad_0 | |||||
- libedit=3.1=heed3624_0 | |||||
- libffi=3.2.1=h4deb6c0_3 | |||||
- libgcc=7.2.0=h69d50b8_2 | |||||
- libgcc-ng=7.2.0=hcbc56d2_1 | |||||
- libgfortran-ng=7.2.0=h6fcbd8e_1 | |||||
- libpng=1.6.32=hda9c8bc_2 | |||||
- libsodium=1.0.13=h31c71d8_2 | |||||
- libssh2=1.8.0=h8c220ad_2 | |||||
- libstdcxx-ng=7.2.0=h24385c6_1 | |||||
- libtiff=4.0.8=h90200ff_9 | |||||
- libtool=2.4.6=hd50d1a6_0 | |||||
- libxcb=1.12=he6ee5dd_2 | |||||
- libxml2=2.9.4=h6b072ca_5 | |||||
- libxslt=1.1.29=hcf9102b_5 | |||||
- llvmlite=0.20.0=py36_0 | |||||
- locket=0.2.0=py36h787c0ad_1 | |||||
- lxml=3.8.0=py36h6c6e760_0 | |||||
- lzo=2.10=hc0eb8fc_0 | |||||
- markupsafe=1.0=py36hd9260cd_1 | |||||
- matplotlib=2.0.2=py36h2acb4ad_1 | |||||
- mccabe=0.6.1=py36h5ad9710_1 | |||||
- mistune=0.7.4=py36hbab8784_0 | |||||
- mkl=2018.0.0=hb491cac_4 | |||||
- mkl-service=1.1.2=py36h17a0993_4 | |||||
- mpc=1.0.3=hf803216_4 | |||||
- mpfr=3.1.5=h12ff648_1 | |||||
- mpmath=0.19=py36h8cc018b_2 | |||||
- msgpack-python=0.4.8=py36hec4c5d1_0 | |||||
- multipledispatch=0.4.9=py36h41da3fb_0 | |||||
- navigator-updater=0.1.0=py36h14770f7_0 | |||||
- nbconvert=5.3.1=py36hb41ffb7_0 | |||||
- nbformat=4.4.0=py36h31c9010_0 | |||||
- ncurses=6.0=h06874d7_1 | |||||
- networkx=1.11=py36hfb3574a_0 | |||||
- nltk=3.2.4=py36h1a0979f_0 | |||||
- nose=1.3.7=py36hcdf7029_2 | |||||
- notebook=5.0.0=py36h0b20546_2 | |||||
- numba=0.35.0=np113py36_10 | |||||
- numexpr=2.6.2=py36hdd3393f_1 | |||||
- numpy=1.13.1=py36h5bc529a_2 | |||||
- numpydoc=0.7.0=py36h18f165f_0 | |||||
- odo=0.5.1=py36h90ed295_0 | |||||
- olefile=0.44=py36h79f9f78_0 | |||||
- openpyxl=2.4.8=py36h41dd2a8_1 | |||||
- openssl=1.0.2l=h9d1a558_3 | |||||
- packaging=16.8=py36ha668100_1 | |||||
- pandas=0.20.3=py36h842e28d_2 | |||||
- pandoc=1.19.2.1=hea2e7c5_1 | |||||
- pandocfilters=1.4.2=py36ha6701b7_1 | |||||
- pango=1.40.11=hedb6d6b_0 | |||||
- partd=0.3.8=py36h36fd896_0 | |||||
- patchelf=0.9=hf79760b_2 | |||||
- path.py=10.3.1=py36he0c6f6d_0 | |||||
- pathlib2=2.3.0=py36h49efa8e_0 | |||||
- patsy=0.4.1=py36ha3be15e_0 | |||||
- pcre=8.41=hc71a17e_0 | |||||
- pep8=1.7.0=py36h26ade29_0 | |||||
- pexpect=4.2.1=py36h3b9d41b_0 | |||||
- pickleshare=0.7.4=py36h63277f8_0 | |||||
- pillow=4.2.1=py36h9119f52_0 | |||||
- pip=9.0.1=py36h30f8307_2 | |||||
- pixman=0.34.0=ha72d70b_1 | |||||
- pkginfo=1.4.1=py36h215d178_1 | |||||
- ply=3.10=py36hed35086_0 | |||||
- prompt_toolkit=1.0.15=py36h17d85b1_0 | |||||
- psutil=5.2.2=py36h74c8701_0 | |||||
- ptyprocess=0.5.2=py36h69acd42_0 | |||||
- py=1.4.34=py36h0712aa3_1 | |||||
- pycodestyle=2.3.1=py36hf609f19_0 | |||||
- pycosat=0.6.2=py36h1a0ea17_1 | |||||
- pycparser=2.18=py36hf9f622e_1 | |||||
- pycrypto=2.6.1=py36h6998063_1 | |||||
- pycurl=7.43.0=py36h5e72054_3 | |||||
- pyflakes=1.5.0=py36h5510808_1 | |||||
- pygments=2.2.0=py36h0d3125c_0 | |||||
- pylint=1.7.2=py36h484ab97_0 | |||||
- pyodbc=4.0.17=py36h999153c_0 | |||||
- pyopenssl=17.2.0=py36h5cc804b_0 | |||||
- pyparsing=2.2.0=py36hee85983_1 | |||||
- pyqt=5.6.0=py36h0386399_5 | |||||
- pysocks=1.6.7=py36hd97a5b1_1 | |||||
- pytables=3.4.2=py36hdce54c9_1 | |||||
- pytest=3.2.1=py36h11ad3bb_1 | |||||
- python=3.6.2=h02fb82a_12 | |||||
- python-dateutil=2.6.1=py36h88d3b88_1 | |||||
- pytz=2017.2=py36hc2ccc2a_1 | |||||
- pywavelets=0.5.2=py36he602eb0_0 | |||||
- pyyaml=3.12=py36hafb9ca4_1 | |||||
- pyzmq=16.0.2=py36h3b0cf96_2 | |||||
- qt=5.6.2=h974d657_12 | |||||
- qtawesome=0.4.4=py36h609ed8c_0 | |||||
- qtconsole=4.3.1=py36h8f73b5b_0 | |||||
- qtpy=1.3.1=py36h3691cc8_0 | |||||
- readline=7.0=hac23ff0_3 | |||||
- requests=2.18.4=py36he2e5f8d_1 | |||||
- rope=0.10.5=py36h1f8c17e_0 | |||||
- ruamel_yaml=0.11.14=py36ha2fb22d_2 | |||||
- scikit-image=0.13.0=py36had3c07a_1 | |||||
- scikit-learn=0.19.0=py36h97ac459_2 | |||||
- scipy=0.19.1=py36h9976243_3 | |||||
- seaborn=0.8.0=py36h197244f_0 | |||||
- setuptools=36.5.0=py36he42e2e1_0 | |||||
- simplegeneric=0.8.1=py36h2cb9092_0 | |||||
- singledispatch=3.4.0.3=py36h7a266c3_0 | |||||
- sip=4.18.1=py36h51ed4ed_2 | |||||
- six=1.10.0=py36hcac75e4_1 | |||||
- snowballstemmer=1.2.1=py36h6febd40_0 | |||||
- sortedcollections=0.5.3=py36h3c761f9_0 | |||||
- sortedcontainers=1.5.7=py36hdf89491_0 | |||||
- sphinx=1.6.3=py36he5f0bdb_0 | |||||
- sphinxcontrib=1.0=py36h6d0f590_1 | |||||
- sphinxcontrib-websupport=1.0.1=py36hb5cb234_1 | |||||
- spyder=3.2.3=py36he38cbf7_1 | |||||
- sqlalchemy=1.1.13=py36hfb5efd7_0 | |||||
- sqlite=3.20.1=h6d8b0f3_1 | |||||
- statsmodels=0.8.0=py36h8533d0b_0 | |||||
- sympy=1.1.1=py36hc6d1c1c_0 | |||||
- tblib=1.3.2=py36h34cf8b6_0 | |||||
- terminado=0.6=py36ha25a19f_0 | |||||
- testpath=0.3.1=py36h8cadb63_0 | |||||
- tk=8.6.7=h5979e9b_1 | |||||
- toolz=0.8.2=py36h81f2dff_0 | |||||
- tornado=4.5.2=py36h1283b2a_0 | |||||
- traitlets=4.3.2=py36h674d592_0 | |||||
- typing=3.6.2=py36h7da032a_0 | |||||
- unicodecsv=0.14.1=py36ha668878_0 | |||||
- unixodbc=2.3.4=hc36303a_1 | |||||
- urllib3=1.22=py36hbe7ace6_0 | |||||
- wcwidth=0.1.7=py36hdf4376a_0 | |||||
- webencodings=0.5.1=py36h800622e_1 | |||||
- werkzeug=0.12.2=py36hc703753_0 | |||||
- wheel=0.29.0=py36he7f4e38_1 | |||||
- widgetsnbextension=3.0.2=py36hd01bb71_1 | |||||
- wrapt=1.10.11=py36h28b7045_0 | |||||
- xlrd=1.1.0=py36h1db9f0c_1 | |||||
- xlsxwriter=0.9.8=py36hf41c223_0 | |||||
- xlwt=1.3.0=py36h7b00a1f_0 | |||||
- xz=5.2.3=h2bcbf08_1 | |||||
- yaml=0.1.7=h96e3832_1 | |||||
- zeromq=4.2.2=hb0b69da_1 | |||||
- zict=0.1.2=py36ha0d441b_0 | |||||
- zlib=1.2.11=hfbfcf68_1 | |||||
- cuda80=1.0=0 | |||||
- pytorch=0.2.0=py36h53baedd_4cu80 | |||||
- torchvision=0.1.9=py36h7584368_1 | |||||
- pip: | |||||
- backports.shutil-get-terminal-size==1.0.0 | |||||
- et-xmlfile==1.0.1 | |||||
- gae==0.0.1 | |||||
- ipython-genutils==0.2.0 | |||||
- jupyter-client==5.1.0 | |||||
- jupyter-console==5.2.0 | |||||
- jupyter-core==4.3.0 | |||||
- jupyterlab-launcher==0.4.0 | |||||
- prompt-toolkit==1.0.15 | |||||
- protobuf==3.4.0 | |||||
- python-louvain==0.9 | |||||
- ruamel-yaml==0.11.14 | |||||
- smart-open==1.5.3 | |||||
- tables==3.4.2 | |||||
- tensorboard-logger==0.0.4 | |||||
- torch==0.2.0.post4 | |||||
prefix: /lfs/hyperion/0/jiaxuany/anaconda3 | |||||
include orca/orca.h | |||||
import concurrent.futures | |||||
from functools import partial | |||||
import networkx as nx | |||||
import numpy as np | |||||
from scipy.linalg import toeplitz | |||||
import pyemd | |||||
def emd(x, y, distance_scaling=1.0): | |||||
support_size = max(len(x), len(y)) | |||||
d_mat = toeplitz(range(support_size)).astype(np.float) | |||||
distance_mat = d_mat / distance_scaling | |||||
# convert histogram values x and y to float, and make them equal len | |||||
x = x.astype(np.float) | |||||
y = y.astype(np.float) | |||||
if len(x) < len(y): | |||||
x = np.hstack((x, [0.0] * (support_size - len(x)))) | |||||
elif len(y) < len(x): | |||||
y = np.hstack((y, [0.0] * (support_size - len(y)))) | |||||
emd = pyemd.emd(x, y, distance_mat) | |||||
return emd | |||||
def l2(x, y): | |||||
dist = np.linalg.norm(x - y, 2) | |||||
return dist | |||||
def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0): | |||||
''' Gaussian kernel with squared distance in exponential term replaced by EMD | |||||
Args: | |||||
x, y: 1D pmf of two distributions with the same support | |||||
sigma: standard deviation | |||||
''' | |||||
support_size = max(len(x), len(y)) | |||||
d_mat = toeplitz(range(support_size)).astype(np.float) | |||||
distance_mat = d_mat / distance_scaling | |||||
# convert histogram values x and y to float, and make them equal len | |||||
x = x.astype(np.float) | |||||
y = y.astype(np.float) | |||||
if len(x) < len(y): | |||||
x = np.hstack((x, [0.0] * (support_size - len(x)))) | |||||
elif len(y) < len(x): | |||||
y = np.hstack((y, [0.0] * (support_size - len(y)))) | |||||
emd = pyemd.emd(x, y, distance_mat) | |||||
return np.exp(-emd * emd / (2 * sigma * sigma)) | |||||
def gaussian(x, y, sigma=1.0): | |||||
dist = np.linalg.norm(x - y, 2) | |||||
return np.exp(-dist * dist / (2 * sigma * sigma)) | |||||
def kernel_parallel_unpacked(x, samples2, kernel): | |||||
d = 0 | |||||
for s2 in samples2: | |||||
d += kernel(x, s2) | |||||
return d | |||||
def kernel_parallel_worker(t): | |||||
return kernel_parallel_unpacked(*t) | |||||
def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs): | |||||
''' Discrepancy between 2 samples | |||||
''' | |||||
d = 0 | |||||
if not is_parallel: | |||||
for s1 in samples1: | |||||
for s2 in samples2: | |||||
d += kernel(s1, s2, *args, **kwargs) | |||||
else: | |||||
with concurrent.futures.ProcessPoolExecutor() as executor: | |||||
for dist in executor.map(kernel_parallel_worker, | |||||
[(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]): | |||||
d += dist | |||||
d /= len(samples1) * len(samples2) | |||||
return d | |||||
def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): | |||||
''' Squared MMD estimate between two samples: mean k(s1,s1') + mean k(s2,s2') - 2 * mean k(s1,s2)
''' | |||||
# normalize histograms into pmf | |||||
if is_hist: | |||||
samples1 = [s1 / np.sum(s1) for s1 in samples1] | |||||
samples2 = [s2 / np.sum(s2) for s2 in samples2] | |||||
# print('===============================') | |||||
# print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs)) | |||||
# print('--------------------------') | |||||
# print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs)) | |||||
# print('--------------------------') | |||||
# print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs)) | |||||
# print('===============================') | |||||
return disc(samples1, samples1, kernel, *args, **kwargs) + \ | |||||
disc(samples2, samples2, kernel, *args, **kwargs) - \ | |||||
2 * disc(samples1, samples2, kernel, *args, **kwargs) | |||||
def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): | |||||
''' EMD between average of two samples | |||||
''' | |||||
# normalize histograms into pmf | |||||
if is_hist: | |||||
samples1 = [np.mean(samples1)] | |||||
samples2 = [np.mean(samples2)] | |||||
# print('===============================') | |||||
# print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs)) | |||||
# print('--------------------------') | |||||
# print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs)) | |||||
# print('--------------------------') | |||||
# print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs)) | |||||
# print('===============================') | |||||
return disc(samples1, samples2, kernel, *args, **kwargs),[samples1[0],samples2[0]] | |||||
def test(): | |||||
s1 = np.array([0.2, 0.8]) | |||||
s2 = np.array([0.3, 0.7]) | |||||
samples1 = [s1, s2] | |||||
s3 = np.array([0.25, 0.75]) | |||||
s4 = np.array([0.35, 0.65]) | |||||
samples2 = [s3, s4] | |||||
s5 = np.array([0.8, 0.2]) | |||||
s6 = np.array([0.7, 0.3]) | |||||
samples3 = [s5, s6] | |||||
print('between samples1 and samples2: ', compute_mmd(samples1, samples2, kernel=gaussian_emd, | |||||
is_parallel=False, sigma=1.0)) | |||||
print('between samples1 and samples3: ', compute_mmd(samples1, samples3, kernel=gaussian_emd, | |||||
is_parallel=False, sigma=1.0)) | |||||
if __name__ == '__main__': | |||||
test() | |||||
4 4 | |||||
0 1 | |||||
1 2 | |||||
2 3 | |||||
3 0 | |||||
#include <cstdio> | |||||
#include <cstdlib> | |||||
#include <cstring> | |||||
#include <Python.h> | |||||
#include "orca/orca.h" | |||||
static PyObject * | |||||
orca_motifs(PyObject *self, PyObject *args) | |||||
{ | |||||
const char *orbit_type; | |||||
int graphlet_size; | |||||
const char *input_filename; | |||||
const char *output_filename; | |||||
if (!PyArg_ParseTuple(args, "siss", &orbit_type, &graphlet_size, &input_filename, &output_filename))
return NULL;
/* run the ORCA orbit counting routine on the parsed arguments */
motif_counts(orbit_type, graphlet_size, input_filename, output_filename);
return PyLong_FromLong(0);
} | |||||
static PyMethodDef OrcaMethods[] = { | |||||
{"motifs", orca_motifs, METH_VARARGS, | |||||
"Compute motif counts."}, | |||||
}; | |||||
static struct PyModuleDef orcamodule = { | |||||
PyModuleDef_HEAD_INIT, | |||||
"orca", /* name of module */ | |||||
NULL, /* module documentation, may be NULL */ | |||||
-1, /* size of per-interpreter state of the module, | |||||
or -1 if the module keeps state in global variables. */ | |||||
OrcaMethods | |||||
}; | |||||
PyMODINIT_FUNC | |||||
PyInit_orca(void) | |||||
{ | |||||
return PyModule_Create(&orcamodule); | |||||
} | |||||
int main(int argc, char *argv[]) { | |||||
wchar_t *program = Py_DecodeLocale(argv[0], NULL); | |||||
if (program == NULL) { | |||||
fprintf(stderr, "Fatal error: cannot decode argv[0]\n"); | |||||
exit(1); | |||||
} | |||||
/* Add a built-in module, before Py_Initialize */ | |||||
PyImport_AppendInittab("orca", PyInit_orca); | |||||
/* Pass argv[0] to the Python interpreter */ | |||||
Py_SetProgramName(program); | |||||
/* Initialize the Python interpreter. Required. */ | |||||
Py_Initialize(); | |||||
/* Optionally import the module; alternatively, | |||||
import can be deferred until the embedded script | |||||
imports it. */ | |||||
PyImport_ImportModule("orca"); | |||||
PyMem_RawFree(program); | |||||
} | |||||
from distutils.core import setup, Extension | |||||
orca_module = Extension('orca', | |||||
sources = ['orcamodule.cpp'], | |||||
extra_compile_args=['-std=c++11'],) | |||||
setup (name = 'orca', | |||||
version = '1.0', | |||||
description = 'ORCA motif counting package', | |||||
ext_modules = [orca_module]) | |||||
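A minimal usage sketch for the extension above, assuming it has been built (e.g. with `python setup.py build_ext --inplace`). The argument order mirrors the C wrapper's `"siss"` signature: orbit type, graphlet size, input file, output file. `graph.in` / `graph.out` are hypothetical names; the input follows the "n m" header plus edge-list format shown earlier (the 4-node cycle), and the output layout is assumed to match what the evaluation code parses from the standalone orca binary.
```python
# Sketch only: assumes the 'orca' extension above has been compiled and that
# graph.in holds an "n m" header followed by one edge per line.
import orca

orca.motifs('node', 4, 'graph.in', 'graph.out')

# Assumption: graph.out contains one row of orbit counts per node.
with open('graph.out') as f:
    counts = [list(map(int, line.split())) for line in f if line.strip()]
print(len(counts), 'nodes,', len(counts[0]), 'orbit columns')
```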
import concurrent.futures | |||||
from datetime import datetime | |||||
from functools import partial | |||||
import numpy as np | |||||
import networkx as nx | |||||
import os | |||||
import pickle as pkl | |||||
import subprocess as sp | |||||
import time | |||||
import eval.mmd as mmd | |||||
PRINT_TIME = False | |||||
def degree_worker(G): | |||||
return np.array(nx.degree_histogram(G)) | |||||
def add_tensor(x,y): | |||||
support_size = max(len(x), len(y)) | |||||
if len(x) < len(y): | |||||
x = np.hstack((x, [0.0] * (support_size - len(x)))) | |||||
elif len(y) < len(x): | |||||
y = np.hstack((y, [0.0] * (support_size - len(y)))) | |||||
return x+y | |||||
def degree_stats(graph_ref_list, graph_pred_list, is_parallel=False): | |||||
''' Compute the distance between the degree distributions of two unordered sets of graphs. | |||||
Args: | |||||
graph_ref_list, graph_pred_list: two lists of networkx graphs to be compared
''' | |||||
sample_ref = [] | |||||
sample_pred = [] | |||||
# in case an empty graph is generated | |||||
graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] | |||||
prev = datetime.now() | |||||
if is_parallel: | |||||
with concurrent.futures.ProcessPoolExecutor() as executor: | |||||
for deg_hist in executor.map(degree_worker, graph_ref_list): | |||||
sample_ref.append(deg_hist) | |||||
with concurrent.futures.ProcessPoolExecutor() as executor: | |||||
for deg_hist in executor.map(degree_worker, graph_pred_list_remove_empty): | |||||
sample_pred.append(deg_hist) | |||||
else: | |||||
for i in range(len(graph_ref_list)): | |||||
degree_temp = np.array(nx.degree_histogram(graph_ref_list[i])) | |||||
sample_ref.append(degree_temp) | |||||
for i in range(len(graph_pred_list_remove_empty)): | |||||
degree_temp = np.array(nx.degree_histogram(graph_pred_list_remove_empty[i])) | |||||
sample_pred.append(degree_temp) | |||||
print(len(sample_ref),len(sample_pred)) | |||||
mmd_dist = mmd.compute_mmd(sample_ref, sample_pred, kernel=mmd.gaussian_emd) | |||||
elapsed = datetime.now() - prev | |||||
if PRINT_TIME: | |||||
print('Time computing degree mmd: ', elapsed) | |||||
return mmd_dist | |||||
def clustering_worker(param): | |||||
G, bins = param | |||||
clustering_coeffs_list = list(nx.clustering(G).values()) | |||||
hist, _ = np.histogram( | |||||
clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) | |||||
return hist | |||||
def clustering_stats(graph_ref_list, graph_pred_list, bins=100, is_parallel=True): | |||||
sample_ref = [] | |||||
sample_pred = [] | |||||
graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] | |||||
prev = datetime.now() | |||||
if is_parallel: | |||||
with concurrent.futures.ProcessPoolExecutor() as executor: | |||||
for clustering_hist in executor.map(clustering_worker, | |||||
[(G, bins) for G in graph_ref_list]): | |||||
sample_ref.append(clustering_hist) | |||||
with concurrent.futures.ProcessPoolExecutor() as executor: | |||||
for clustering_hist in executor.map(clustering_worker, | |||||
[(G, bins) for G in graph_pred_list_remove_empty]): | |||||
sample_pred.append(clustering_hist) | |||||
# check non-zero elements in hist | |||||
#total = 0 | |||||
#for i in range(len(sample_pred)): | |||||
# nz = np.nonzero(sample_pred[i])[0].shape[0] | |||||
# total += nz | |||||
#print(total) | |||||
else: | |||||
for i in range(len(graph_ref_list)): | |||||
clustering_coeffs_list = list(nx.clustering(graph_ref_list[i]).values()) | |||||
hist, _ = np.histogram( | |||||
clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) | |||||
sample_ref.append(hist) | |||||
for i in range(len(graph_pred_list_remove_empty)): | |||||
clustering_coeffs_list = list(nx.clustering(graph_pred_list_remove_empty[i]).values()) | |||||
hist, _ = np.histogram( | |||||
clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) | |||||
sample_pred.append(hist) | |||||
mmd_dist = mmd.compute_mmd(sample_ref, sample_pred, kernel=mmd.gaussian_emd, | |||||
sigma=1.0/10, distance_scaling=bins) | |||||
elapsed = datetime.now() - prev | |||||
if PRINT_TIME: | |||||
print('Time computing clustering mmd: ', elapsed) | |||||
return mmd_dist | |||||
# maps motif/orbit name string to its corresponding list of indices from orca output | |||||
motif_to_indices = { | |||||
'3path' : [1, 2], | |||||
'4cycle' : [8], | |||||
} | |||||
COUNT_START_STR = 'orbit counts: \n' | |||||
def edge_list_reindexed(G): | |||||
idx = 0 | |||||
id2idx = dict() | |||||
for u in G.nodes(): | |||||
id2idx[str(u)] = idx | |||||
idx += 1 | |||||
edges = [] | |||||
for (u, v) in G.edges(): | |||||
edges.append((id2idx[str(u)], id2idx[str(v)])) | |||||
return edges | |||||
def orca(graph): | |||||
tmp_fname = 'eval/orca/tmp.txt' | |||||
f = open(tmp_fname, 'w') | |||||
f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\n') | |||||
for (u, v) in edge_list_reindexed(graph): | |||||
f.write(str(u) + ' ' + str(v) + '\n') | |||||
f.close() | |||||
output = sp.check_output(['./eval/orca/orca', 'node', '4', 'eval/orca/tmp.txt', 'std']) | |||||
output = output.decode('utf8').strip() | |||||
idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) | |||||
output = output[idx:] | |||||
node_orbit_counts = np.array([list(map(int, node_cnts.strip().split(' ') )) | |||||
for node_cnts in output.strip('\n').split('\n')]) | |||||
try: | |||||
os.remove(tmp_fname) | |||||
except OSError: | |||||
pass | |||||
return node_orbit_counts | |||||
def motif_stats(graph_ref_list, graph_pred_list, motif_type='4cycle', ground_truth_match=None, bins=100): | |||||
# graph motif counts (int for each graph) | |||||
# normalized by graph size | |||||
total_counts_ref = [] | |||||
total_counts_pred = [] | |||||
num_matches_ref = [] | |||||
num_matches_pred = [] | |||||
graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] | |||||
indices = motif_to_indices[motif_type] | |||||
for G in graph_ref_list: | |||||
orbit_counts = orca(G) | |||||
motif_counts = np.sum(orbit_counts[:, indices], axis=1) | |||||
if ground_truth_match is not None: | |||||
match_cnt = 0 | |||||
for elem in motif_counts: | |||||
if elem == ground_truth_match: | |||||
match_cnt += 1 | |||||
num_matches_ref.append(match_cnt / G.number_of_nodes()) | |||||
#hist, _ = np.histogram( | |||||
# motif_counts, bins=bins, density=False) | |||||
motif_temp = np.sum(motif_counts) / G.number_of_nodes() | |||||
total_counts_ref.append(motif_temp) | |||||
for G in graph_pred_list_remove_empty: | |||||
orbit_counts = orca(G) | |||||
motif_counts = np.sum(orbit_counts[:, indices], axis=1) | |||||
if ground_truth_match is not None: | |||||
match_cnt = 0 | |||||
for elem in motif_counts: | |||||
if elem == ground_truth_match: | |||||
match_cnt += 1 | |||||
num_matches_pred.append(match_cnt / G.number_of_nodes()) | |||||
motif_temp = np.sum(motif_counts) / G.number_of_nodes() | |||||
total_counts_pred.append(motif_temp) | |||||
mmd_dist = mmd.compute_mmd(total_counts_ref, total_counts_pred, kernel=mmd.gaussian, | |||||
is_hist=False) | |||||
#print('-------------------------') | |||||
#print(np.sum(total_counts_ref) / len(total_counts_ref)) | |||||
#print('...') | |||||
#print(np.sum(total_counts_pred) / len(total_counts_pred)) | |||||
#print('-------------------------') | |||||
return mmd_dist | |||||
def orbit_stats_all(graph_ref_list, graph_pred_list): | |||||
total_counts_ref = [] | |||||
total_counts_pred = [] | |||||
graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] | |||||
for G in graph_ref_list: | |||||
try: | |||||
orbit_counts = orca(G) | |||||
except: | |||||
continue | |||||
orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() | |||||
total_counts_ref.append(orbit_counts_graph) | |||||
for G in graph_pred_list: | |||||
try: | |||||
orbit_counts = orca(G) | |||||
except: | |||||
continue | |||||
orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() | |||||
total_counts_pred.append(orbit_counts_graph) | |||||
total_counts_ref = np.array(total_counts_ref) | |||||
total_counts_pred = np.array(total_counts_pred) | |||||
mmd_dist = mmd.compute_mmd(total_counts_ref, total_counts_pred, kernel=mmd.gaussian, | |||||
is_hist=False, sigma=30.0) | |||||
print('-------------------------') | |||||
print(np.sum(total_counts_ref, axis=0) / len(total_counts_ref)) | |||||
print('...') | |||||
print(np.sum(total_counts_pred, axis=0) / len(total_counts_pred)) | |||||
print('-------------------------') | |||||
return mmd_dist | |||||
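A minimal usage sketch for the statistics above: compare a reference set of graphs with a "generated" set. It assumes this file is importable as `eval.stats`, as in the evaluation script that follows; `orbit_stats_all` would additionally require the compiled `./eval/orca/orca` binary, so it is omitted here.
```python
# Sketch only: degree and clustering MMD between two synthetic graph sets.
import networkx as nx
import eval.stats

graphs_ref = [nx.ladder_graph(n) for n in range(20, 30)]
graphs_pred = [nx.barabasi_albert_graph(2 * n, 2) for n in range(20, 30)]

print('degree MMD:    ', eval.stats.degree_stats(graphs_ref, graphs_pred))
print('clustering MMD:', eval.stats.clustering_stats(graphs_ref, graphs_pred, bins=100))
```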
import argparse
import logging
import numpy as np
import os
import re
from random import shuffle
import eval.stats
import utils
# import main.Args
from baselines.baseline_simple import *
class Args_evaluate(): | |||||
def __init__(self): | |||||
# loop over the settings | |||||
# self.model_name_all = ['GraphRNN_MLP','GraphRNN_RNN','Internal','Noise'] | |||||
# self.model_name_all = ['E-R', 'B-A'] | |||||
self.model_name_all = ['GraphRNN_RNN'] | |||||
# self.model_name_all = ['Baseline_DGMG'] | |||||
# list of dataset to evaluate | |||||
# use a list of 1 element to evaluate a single dataset | |||||
self.dataset_name_all = ['caveman', 'grid', 'barabasi', 'citeseer', 'DD'] | |||||
# self.dataset_name_all = ['citeseer_small','caveman_small'] | |||||
# self.dataset_name_all = ['barabasi_noise0','barabasi_noise2','barabasi_noise4','barabasi_noise6','barabasi_noise8','barabasi_noise10'] | |||||
# self.dataset_name_all = ['caveman_small', 'ladder_small', 'grid_small', 'ladder_small', 'enzymes_small', 'barabasi_small','citeseer_small'] | |||||
self.epoch_start=100 | |||||
self.epoch_end=3001 | |||||
self.epoch_step=100 | |||||
def find_nearest_idx(array,value): | |||||
idx = (np.abs(array-value)).argmin() | |||||
return idx | |||||
def extract_result_id_and_epoch(name, prefix, suffix): | |||||
''' | |||||
Args: | |||||
name: result filename to parse
prefix: substring that immediately precedes the epoch number in the filename
suffix: 'real_' or 'pred_', which immediately precedes the result id
Returns: | |||||
A tuple of (id, epoch number) extracted from the filename string | |||||
''' | |||||
pos = name.find(suffix) + len(suffix) | |||||
end_pos = name.find('.dat') | |||||
result_id = name[pos:end_pos] | |||||
pos = name.find(prefix) + len(prefix) | |||||
end_pos = name.find('_', pos) | |||||
epochs = int(name[pos:end_pos]) | |||||
return result_id, epochs | |||||
def eval_list(real_graphs_filename, pred_graphs_filename, prefix, eval_every): | |||||
real_graphs_dict = {} | |||||
pred_graphs_dict = {} | |||||
for fname in real_graphs_filename: | |||||
result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'real_') | |||||
if not epochs % eval_every == 0: | |||||
continue | |||||
if result_id not in real_graphs_dict: | |||||
real_graphs_dict[result_id] = {} | |||||
real_graphs_dict[result_id][epochs] = fname | |||||
for fname in pred_graphs_filename: | |||||
result_id, epochs = extract_result_id_and_epoch(fname, prefix, 'pred_') | |||||
if not epochs % eval_every == 0: | |||||
continue | |||||
if result_id not in pred_graphs_dict: | |||||
pred_graphs_dict[result_id] = {} | |||||
pred_graphs_dict[result_id][epochs] = fname | |||||
for result_id in real_graphs_dict.keys(): | |||||
for epochs in sorted(real_graphs_dict[result_id]): | |||||
real_g_list = utils.load_graph_list(real_graphs_dict[result_id][epochs]) | |||||
pred_g_list = utils.load_graph_list(pred_graphs_dict[result_id][epochs]) | |||||
shuffle(real_g_list) | |||||
shuffle(pred_g_list) | |||||
perturbed_g_list = perturb(real_g_list, 0.05) | |||||
#dist = eval.stats.degree_stats(real_g_list, pred_g_list) | |||||
dist = eval.stats.clustering_stats(real_g_list, pred_g_list) | |||||
print('dist between real and pred (', result_id, ') at epoch ', epochs, ': ', dist) | |||||
#dist = eval.stats.degree_stats(real_g_list, perturbed_g_list) | |||||
dist = eval.stats.clustering_stats(real_g_list, perturbed_g_list) | |||||
print('dist between real and perturbed: ', dist) | |||||
mid = len(real_g_list) // 2 | |||||
#dist = eval.stats.degree_stats(real_g_list[:mid], real_g_list[mid:]) | |||||
dist = eval.stats.clustering_stats(real_g_list[:mid], real_g_list[mid:]) | |||||
print('dist among real: ', dist) | |||||
def compute_basic_stats(real_g_list, target_g_list): | |||||
dist_degree = eval.stats.degree_stats(real_g_list, target_g_list) | |||||
dist_clustering = eval.stats.clustering_stats(real_g_list, target_g_list) | |||||
return dist_degree, dist_clustering | |||||
def clean_graphs(graph_real, graph_pred): | |||||
''' Select generated graphs whose sizes are similar to those of the real graphs.
This is usually necessary for the GraphRNN-S variant, but not for the full GraphRNN model.
''' | |||||
shuffle(graph_real) | |||||
shuffle(graph_pred) | |||||
# get length | |||||
real_graph_len = np.array([len(graph_real[i]) for i in range(len(graph_real))]) | |||||
pred_graph_len = np.array([len(graph_pred[i]) for i in range(len(graph_pred))]) | |||||
# select pred samples whose node counts follow a distribution similar to the training set
pred_graph_new = [] | |||||
pred_graph_len_new = [] | |||||
for value in real_graph_len: | |||||
pred_idx = find_nearest_idx(pred_graph_len, value) | |||||
pred_graph_new.append(graph_pred[pred_idx]) | |||||
pred_graph_len_new.append(pred_graph_len[pred_idx]) | |||||
return graph_real, pred_graph_new | |||||
def load_ground_truth(dir_input, dataset_name, model_name='GraphRNN_RNN'): | |||||
''' Read ground truth graphs. | |||||
''' | |||||
if not 'small' in dataset_name: | |||||
hidden = 128 | |||||
else: | |||||
hidden = 64 | |||||
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||||
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||||
hidden) + '_test_' + str(0) + '.dat' | |||||
else: | |||||
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||||
hidden) + '_test_' + str(0) + '.dat' | |||||
try: | |||||
graph_test = utils.load_graph_list(fname_test,is_real=True) | |||||
except: | |||||
print('Not found: ' + fname_test) | |||||
logging.warning('Not found: ' + fname_test) | |||||
return None | |||||
return graph_test | |||||
def eval_single_list(graphs, dir_input, dataset_name): | |||||
''' Evaluate a list of graphs by comparing with graphs in directory dir_input. | |||||
Args: | |||||
dir_input: directory where ground truth graph list is stored | |||||
dataset_name: name of the dataset (ground truth) | |||||
''' | |||||
graph_test = load_ground_truth(dir_input, dataset_name) | |||||
graph_test_len = len(graph_test) | |||||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||||
mmd_degree = eval.stats.degree_stats(graph_test, graphs) | |||||
mmd_clustering = eval.stats.clustering_stats(graph_test, graphs) | |||||
try: | |||||
mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graphs) | |||||
except: | |||||
mmd_4orbits = -1 | |||||
print('deg: ', mmd_degree) | |||||
print('clustering: ', mmd_clustering) | |||||
print('orbits: ', mmd_4orbits) | |||||
def evaluation_epoch(dir_input, fname_output, model_name, dataset_name, args, is_clean=True, epoch_start=1000,epoch_end=3001,epoch_step=100): | |||||
with open(fname_output, 'w+') as f: | |||||
f.write('sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test\n') | |||||
# TODO: Maybe refactor into a separate file/function that specifies THE naming convention | |||||
# across main and evaluate | |||||
if not 'small' in dataset_name: | |||||
hidden = 128 | |||||
else: | |||||
hidden = 64 | |||||
# read real graph | |||||
if model_name=='Internal' or model_name=='Noise' or model_name=='B-A' or model_name=='E-R': | |||||
fname_test = dir_input + 'GraphRNN_MLP' + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||||
hidden) + '_test_' + str(0) + '.dat' | |||||
elif 'Baseline' in model_name: | |||||
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(64) + '_test_' + str(0) + '.dat' | |||||
else: | |||||
fname_test = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str( | |||||
hidden) + '_test_' + str(0) + '.dat' | |||||
try: | |||||
graph_test = utils.load_graph_list(fname_test,is_real=True) | |||||
except: | |||||
print('Not found: ' + fname_test) | |||||
logging.warning('Not found: ' + fname_test) | |||||
return None | |||||
graph_test_len = len(graph_test) | |||||
graph_train = graph_test[0:int(0.8 * graph_test_len)] # train | |||||
graph_validate = graph_test[0:int(0.2 * graph_test_len)] # validate | |||||
graph_test = graph_test[int(0.8 * graph_test_len):] # test on a hold out test set | |||||
graph_test_aver = 0 | |||||
for graph in graph_test: | |||||
graph_test_aver+=graph.number_of_nodes() | |||||
graph_test_aver /= len(graph_test) | |||||
print('test average len',graph_test_aver) | |||||
# get performance for proposed approaches | |||||
if 'GraphRNN' in model_name: | |||||
# read test graph | |||||
for epoch in range(epoch_start,epoch_end,epoch_step): | |||||
for sample_time in range(1,4): | |||||
# get filename | |||||
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str(args.num_layers) + '_' + str(hidden) + '_pred_' + str(epoch) + '_' + str(sample_time) + '.dat' | |||||
# load graphs | |||||
try: | |||||
graph_pred = utils.load_graph_list(fname_pred,is_real=False) # default False | |||||
except: | |||||
print('Not found: '+ fname_pred) | |||||
logging.warning('Not found: '+ fname_pred) | |||||
continue | |||||
# clean graphs | |||||
if is_clean: | |||||
graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||||
else: | |||||
shuffle(graph_pred) | |||||
graph_pred = graph_pred[0:len(graph_test)] | |||||
print('len graph_test', len(graph_test)) | |||||
print('len graph_validate', len(graph_validate)) | |||||
print('len graph_pred', len(graph_pred)) | |||||
graph_pred_aver = 0 | |||||
for graph in graph_pred: | |||||
graph_pred_aver += graph.number_of_nodes() | |||||
graph_pred_aver /= len(graph_pred) | |||||
print('pred average len', graph_pred_aver) | |||||
# evaluate MMD test | |||||
mmd_degree = eval.stats.degree_stats(graph_test, graph_pred) | |||||
mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred) | |||||
try: | |||||
mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred) | |||||
except: | |||||
mmd_4orbits = -1 | |||||
# evaluate MMD validate | |||||
mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred) | |||||
mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
# write results | |||||
f.write(str(sample_time)+','+ | |||||
str(epoch)+','+ | |||||
str(mmd_degree_validate)+','+ | |||||
str(mmd_clustering_validate)+','+ | |||||
str(mmd_4orbits_validate)+','+ | |||||
str(mmd_degree)+','+ | |||||
str(mmd_clustering)+','+ | |||||
str(mmd_4orbits)+'\n') | |||||
print('degree',mmd_degree,'clustering',mmd_clustering,'orbits',mmd_4orbits) | |||||
# get internal MMD (MMD between ground truth validation and test sets) | |||||
if model_name == 'Internal': | |||||
mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate) | |||||
mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str( | |||||
mmd_clustering_validate) + ',' + str(mmd_4orbits_validate) | |||||
+ ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n') | |||||
# get MMD between ground truth and its perturbed graphs | |||||
if model_name == 'Noise': | |||||
graph_validate_perturbed = perturb(graph_validate, 0.05) | |||||
mmd_degree_validate = eval.stats.degree_stats(graph_test, graph_validate_perturbed) | |||||
mmd_clustering_validate = eval.stats.clustering_stats(graph_test, graph_validate_perturbed) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_validate_perturbed) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
f.write(str(-1) + ',' + str(-1) + ',' + str(mmd_degree_validate) + ',' + str( | |||||
mmd_clustering_validate) + ',' + str(mmd_4orbits_validate) | |||||
+ ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + '\n') | |||||
# get E-R MMD | |||||
if model_name == 'E-R': | |||||
graph_pred = Graph_generator_baseline(graph_train,generator='Gnp') | |||||
# clean graphs | |||||
if is_clean: | |||||
graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||||
print('len graph_test', len(graph_test)) | |||||
print('len graph_pred', len(graph_pred)) | |||||
mmd_degree = eval.stats.degree_stats(graph_test, graph_pred) | |||||
mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) | |||||
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n') | |||||
# get B-A MMD | |||||
if model_name == 'B-A': | |||||
graph_pred = Graph_generator_baseline(graph_train, generator='BA') | |||||
# clean graphs | |||||
if is_clean: | |||||
graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||||
print('len graph_test', len(graph_test)) | |||||
print('len graph_pred', len(graph_pred)) | |||||
mmd_degree = eval.stats.degree_stats(graph_test, graph_pred) | |||||
mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_test, graph_pred) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
f.write(str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) + ',' + str(-1) | |||||
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits_validate) + '\n') | |||||
# get performance for baseline approaches | |||||
if 'Baseline' in model_name: | |||||
# load the predicted graphs saved at each checkpoint epoch
for epoch in range(epoch_start, epoch_end, epoch_step): | |||||
# get filename | |||||
fname_pred = dir_input + model_name + '_' + dataset_name + '_' + str( | |||||
64) + '_pred_' + str(epoch) + '.dat' | |||||
# load graphs | |||||
try: | |||||
graph_pred = utils.load_graph_list(fname_pred, is_real=True)  # note: is_real=True here, unlike the GraphRNN branch
except: | |||||
print('Not found: ' + fname_pred) | |||||
logging.warning('Not found: ' + fname_pred) | |||||
continue | |||||
# clean graphs | |||||
if is_clean: | |||||
graph_test, graph_pred = clean_graphs(graph_test, graph_pred) | |||||
else: | |||||
shuffle(graph_pred) | |||||
graph_pred = graph_pred[0:len(graph_test)] | |||||
print('len graph_test', len(graph_test)) | |||||
print('len graph_validate', len(graph_validate)) | |||||
print('len graph_pred', len(graph_pred)) | |||||
graph_pred_aver = 0 | |||||
for graph in graph_pred: | |||||
graph_pred_aver += graph.number_of_nodes() | |||||
graph_pred_aver /= len(graph_pred) | |||||
print('pred average len', graph_pred_aver) | |||||
# evaluate MMD test | |||||
mmd_degree = eval.stats.degree_stats(graph_test, graph_pred) | |||||
mmd_clustering = eval.stats.clustering_stats(graph_test, graph_pred) | |||||
try: | |||||
mmd_4orbits = eval.stats.orbit_stats_all(graph_test, graph_pred) | |||||
except: | |||||
mmd_4orbits = -1 | |||||
# evaluate MMD validate | |||||
mmd_degree_validate = eval.stats.degree_stats(graph_validate, graph_pred) | |||||
mmd_clustering_validate = eval.stats.clustering_stats(graph_validate, graph_pred) | |||||
try: | |||||
mmd_4orbits_validate = eval.stats.orbit_stats_all(graph_validate, graph_pred) | |||||
except: | |||||
mmd_4orbits_validate = -1 | |||||
# write results | |||||
f.write(str(-1) + ',' + str(epoch) + ',' + str(mmd_degree_validate) + ',' + str( | |||||
mmd_clustering_validate) + ',' + str(mmd_4orbits_validate) | |||||
+ ',' + str(mmd_degree) + ',' + str(mmd_clustering) + ',' + str(mmd_4orbits) + '\n') | |||||
print('degree', mmd_degree, 'clustering', mmd_clustering, 'orbits', mmd_4orbits) | |||||
return True | |||||
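# Each row written by evaluation_epoch has the format:
#   sample_time, epoch,
#   degree / clustering / orbit MMD for the validation comparison,
#   degree / clustering / orbit MMD for the test comparison,
# with -1 used as a placeholder for fields that do not apply to a given model.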
def evaluation(args_evaluate,dir_input, dir_output, model_name_all, dataset_name_all, args, overwrite = True): | |||||
''' Evaluate the performance of a set of models on a set of datasets. | |||||
''' | |||||
for model_name in model_name_all: | |||||
for dataset_name in dataset_name_all: | |||||
# check output exist | |||||
fname_output = dir_output+model_name+'_'+dataset_name+'.csv' | |||||
print('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||||
logging.info('processing: '+dir_output + model_name + '_' + dataset_name + '.csv') | |||||
if overwrite==False and os.path.isfile(fname_output): | |||||
print(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||||
logging.info(dir_output+model_name+'_'+dataset_name+'.csv exists!') | |||||
continue | |||||
evaluation_epoch(dir_input,fname_output,model_name,dataset_name,args,is_clean=True, epoch_start=args_evaluate.epoch_start,epoch_end=args_evaluate.epoch_end,epoch_step=args_evaluate.epoch_step) | |||||
def eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||||
eval_every, epoch_range=None, out_file_prefix=None): | |||||
''' Evaluate list of predicted graphs compared to ground truth, stored in files. | |||||
Args: | |||||
baselines: dict mapping name of the baseline to list of generated graphs. | |||||
''' | |||||
if out_file_prefix is not None: | |||||
out_files = { | |||||
'train': open(out_file_prefix + '_train.txt', 'w+'), | |||||
'compare': open(out_file_prefix + '_compare.txt', 'w+') | |||||
} | |||||
out_files['train'].write('degree,clustering,orbits4\n') | |||||
line = 'metric,real,ours,perturbed' | |||||
for bl in baselines: | |||||
line += ',' + bl | |||||
line += '\n' | |||||
out_files['compare'].write(line) | |||||
results = { | |||||
'deg': { | |||||
'real': 0, | |||||
'ours': 100, # take min over all training epochs | |||||
'perturbed': 0, | |||||
'kron': 0}, | |||||
'clustering': { | |||||
'real': 0, | |||||
'ours': 100, | |||||
'perturbed': 0, | |||||
'kron': 0}, | |||||
'orbits4': { | |||||
'real': 0, | |||||
'ours': 100, | |||||
'perturbed': 0, | |||||
'kron': 0} | |||||
} | |||||
num_evals = len(pred_graphs_filename) | |||||
if epoch_range is None: | |||||
epoch_range = [i * eval_every for i in range(num_evals)] | |||||
for i in range(num_evals): | |||||
real_g_list = utils.load_graph_list(real_graph_filename) | |||||
#pred_g_list = utils.load_graph_list(pred_graphs_filename[i]) | |||||
# contains all predicted G | |||||
pred_g_list_raw = utils.load_graph_list(pred_graphs_filename[i]) | |||||
if len(real_g_list)>200: | |||||
real_g_list = real_g_list[0:200] | |||||
shuffle(real_g_list) | |||||
shuffle(pred_g_list_raw) | |||||
# get length | |||||
real_g_len_list = np.array([len(real_g_list[i]) for i in range(len(real_g_list))]) | |||||
pred_g_len_list_raw = np.array([len(pred_g_list_raw[i]) for i in range(len(pred_g_list_raw))]) | |||||
# get perturb real | |||||
#perturbed_g_list_001 = perturb(real_g_list, 0.01) | |||||
perturbed_g_list_005 = perturb(real_g_list, 0.05) | |||||
#perturbed_g_list_010 = perturb(real_g_list, 0.10) | |||||
# select pred samples
# for each real graph, pick the predicted graph with the closest number of nodes, so the
# sampled graph sizes follow a distribution similar to that of the real (training) graphs
pred_g_list = [] | |||||
pred_g_len_list = [] | |||||
for value in real_g_len_list: | |||||
pred_idx = find_nearest_idx(pred_g_len_list_raw, value) | |||||
pred_g_list.append(pred_g_list_raw[pred_idx]) | |||||
pred_g_len_list.append(pred_g_len_list_raw[pred_idx]) | |||||
# delete | |||||
pred_g_len_list_raw = np.delete(pred_g_len_list_raw, pred_idx) | |||||
del pred_g_list_raw[pred_idx] | |||||
if len(pred_g_list) == len(real_g_list): | |||||
break | |||||
# pred_g_len_list = np.array(pred_g_len_list) | |||||
print('################## epoch {} ##################'.format(epoch_range[i])) | |||||
# info about graph size | |||||
print('real average nodes', | |||||
sum([real_g_list[i].number_of_nodes() for i in range(len(real_g_list))]) / len(real_g_list)) | |||||
print('pred average nodes', | |||||
sum([pred_g_list[i].number_of_nodes() for i in range(len(pred_g_list))]) / len(pred_g_list)) | |||||
print('num of real graphs', len(real_g_list)) | |||||
print('num of pred graphs', len(pred_g_list)) | |||||
# ======================================== | |||||
# Evaluation | |||||
# ======================================== | |||||
mid = len(real_g_list) // 2 | |||||
dist_degree, dist_clustering = compute_basic_stats(real_g_list[:mid], real_g_list[mid:]) | |||||
#dist_4cycle = eval.stats.motif_stats(real_g_list[:mid], real_g_list[mid:]) | |||||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list[:mid], real_g_list[mid:]) | |||||
print('degree dist among real: ', dist_degree) | |||||
print('clustering dist among real: ', dist_clustering) | |||||
#print('4 cycle dist among real: ', dist_4cycle) | |||||
print('orbits dist among real: ', dist_4orbits) | |||||
results['deg']['real'] += dist_degree | |||||
results['clustering']['real'] += dist_clustering | |||||
results['orbits4']['real'] += dist_4orbits | |||||
dist_degree, dist_clustering = compute_basic_stats(real_g_list, pred_g_list) | |||||
#dist_4cycle = eval.stats.motif_stats(real_g_list, pred_g_list) | |||||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, pred_g_list) | |||||
print('degree dist between real and pred at epoch ', epoch_range[i], ': ', dist_degree) | |||||
print('clustering dist between real and pred at epoch ', epoch_range[i], ': ', dist_clustering) | |||||
#print('4 cycle dist between real and pred at epoch: ', epoch_range[i], dist_4cycle) | |||||
print('orbits dist between real and pred at epoch ', epoch_range[i], ': ', dist_4orbits) | |||||
results['deg']['ours'] = min(dist_degree, results['deg']['ours']) | |||||
results['clustering']['ours'] = min(dist_clustering, results['clustering']['ours']) | |||||
results['orbits4']['ours'] = min(dist_4orbits, results['orbits4']['ours']) | |||||
# performance at training time | |||||
out_files['train'].write(str(dist_degree) + ',') | |||||
out_files['train'].write(str(dist_clustering) + ',') | |||||
out_files['train'].write(str(dist_4orbits) + ',') | |||||
dist_degree, dist_clustering = compute_basic_stats(real_g_list, perturbed_g_list_005) | |||||
#dist_4cycle = eval.stats.motif_stats(real_g_list, perturbed_g_list_005) | |||||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, perturbed_g_list_005) | |||||
print('degree dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_degree) | |||||
print('clustering dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_clustering) | |||||
#print('4 cycle dist between real and perturbed at epoch: ', epoch_range[i], dist_4cycle) | |||||
print('orbits dist between real and perturbed at epoch ', epoch_range[i], ': ', dist_4orbits) | |||||
results['deg']['perturbed'] += dist_degree | |||||
results['clustering']['perturbed'] += dist_clustering | |||||
results['orbits4']['perturbed'] += dist_4orbits | |||||
if i == 0: | |||||
# Baselines | |||||
for baseline in baselines: | |||||
dist_degree, dist_clustering = compute_basic_stats(real_g_list, baselines[baseline]) | |||||
dist_4orbits = eval.stats.orbit_stats_all(real_g_list, baselines[baseline]) | |||||
results['deg'][baseline] = dist_degree | |||||
results['clustering'][baseline] = dist_clustering | |||||
results['orbits4'][baseline] = dist_4orbits | |||||
print(baseline + ': deg=', dist_degree, ', clustering=', dist_clustering,
', orbits4=', dist_4orbits)
out_files['train'].write('\n') | |||||
for metric, methods in results.items(): | |||||
methods['real'] /= num_evals | |||||
methods['perturbed'] /= num_evals | |||||
# Write results | |||||
for metric, methods in results.items(): | |||||
line = metric+','+ \ | |||||
str(methods['real'])+','+ \ | |||||
str(methods['ours'])+','+ \ | |||||
str(methods['perturbed']) | |||||
for baseline in baselines: | |||||
line += ',' + str(methods[baseline]) | |||||
line += '\n' | |||||
out_files['compare'].write(line) | |||||
for _, out_f in out_files.items(): | |||||
out_f.close() | |||||
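# eval_performance collects the ground-truth and predicted graph files for a run (either by
# matching filenames in datadir, or by reconstructing them from args) and forwards them to
# eval_list / eval_list_fname for MMD evaluation.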
def eval_performance(datadir, prefix=None, args=None, eval_every=200, out_file_prefix=None, | |||||
sample_time = 2, baselines={}): | |||||
if args is None: | |||||
real_graphs_filename = [datadir + f for f in os.listdir(datadir)
if re.match(prefix + r'.*real.*\.dat', f)]
pred_graphs_filename = [datadir + f for f in os.listdir(datadir)
if re.match(prefix + r'.*pred.*\.dat', f)]
eval_list(real_graphs_filename, pred_graphs_filename, prefix, 200) | |||||
else: | |||||
# # for vanilla graphrnn | |||||
# real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||||
# pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str(args.bptt_len) + '.dat' for epoch in range(0,50001,eval_every)] | |||||
real_graph_filename = datadir+args.graph_save_path + args.fname_test + '0.dat' | |||||
# for proposed model | |||||
end_epoch = 3001 | |||||
epoch_range = range(eval_every, end_epoch, eval_every) | |||||
pred_graphs_filename = [datadir+args.graph_save_path + args.fname_pred+str(epoch)+'_'+str(sample_time)+'.dat' | |||||
for epoch in epoch_range] | |||||
# for baseline model | |||||
#pred_graphs_filename = [datadir+args.fname_baseline+'.dat'] | |||||
#real_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||||
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||||
#pred_graphs_filename = [datadir + args.graph_save_path + args.note + '_' + args.graph_type + '_' + \ | |||||
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt) + '_' + str( | |||||
# args.bptt_len) + '_' + str(args.gumbel) + '.dat' for epoch in range(10000, 50001, eval_every)] | |||||
eval_list_fname(real_graph_filename, pred_graphs_filename, baselines, | |||||
epoch_range=epoch_range, | |||||
eval_every=eval_every, | |||||
out_file_prefix=out_file_prefix) | |||||
def process_kron(kron_dir): | |||||
txt_files = [] | |||||
for f in os.listdir(kron_dir): | |||||
filename = os.fsdecode(f) | |||||
if filename.endswith('.txt'): | |||||
txt_files.append(filename) | |||||
elif filename.endswith('.dat'): | |||||
return utils.load_graph_list(os.path.join(kron_dir, filename)) | |||||
G_list = [] | |||||
for filename in txt_files: | |||||
G_list.append(utils.snap_txt_output_to_nx(os.path.join(kron_dir, filename))) | |||||
return G_list | |||||
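# Command-line modes of this evaluation script:
#   --export-real      export ground-truth graphs of --graph-type to txt files
#   --kron-dir DIR     convert SNAP Kronecker txt output in DIR into a .dat graph list
#   --testfile FILE    evaluate a single .dat file containing a list of generated graphs
#   (default)          run the full evaluation over Args_evaluate's models and datasets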
if __name__ == '__main__': | |||||
args = Args() | |||||
args_evaluate = Args_evaluate() | |||||
parser = argparse.ArgumentParser(description='Evaluation arguments.') | |||||
feature_parser = parser.add_mutually_exclusive_group(required=False) | |||||
feature_parser.add_argument('--export-real', dest='export', action='store_true') | |||||
feature_parser.add_argument('--no-export-real', dest='export', action='store_false') | |||||
feature_parser.add_argument('--kron-dir', dest='kron_dir', | |||||
help='Directory where graphs generated by the Kronecker method are stored.')
parser.add_argument('--testfile', dest='test_file', | |||||
help='The file that stores list of graphs to be evaluated. Only used when 1 list of ' | |||||
'graphs is to be evaluated.') | |||||
parser.add_argument('--dir-prefix', dest='dir_prefix',
help='Directory prefix under which graphs and evaluation results are stored. Can be used when evaluating multiple '
'models on multiple datasets.')
parser.add_argument('--graph-type', dest='graph_type', | |||||
help='Type of graphs / dataset.') | |||||
parser.set_defaults(export=False, kron_dir='', test_file='', | |||||
dir_prefix='', | |||||
graph_type=args.graph_type) | |||||
prog_args = parser.parse_args() | |||||
# dir_prefix = prog_args.dir_prefix | |||||
# dir_prefix = "/dfs/scratch0/jiaxuany0/" | |||||
dir_prefix = args.dir_input | |||||
time_now = strftime("%Y-%m-%d %H:%M:%S", gmtime()) | |||||
if not os.path.isdir('logs/'): | |||||
os.makedirs('logs/') | |||||
logging.basicConfig(filename='logs/evaluate' + time_now + '.log', level=logging.INFO) | |||||
if prog_args.export: | |||||
if not os.path.isdir('eval_results'): | |||||
os.makedirs('eval_results') | |||||
if not os.path.isdir('eval_results/ground_truth'): | |||||
os.makedirs('eval_results/ground_truth') | |||||
out_dir = os.path.join('eval_results/ground_truth', prog_args.graph_type) | |||||
if not os.path.isdir(out_dir): | |||||
os.makedirs(out_dir) | |||||
output_prefix = os.path.join(out_dir, prog_args.graph_type) | |||||
print('Export ground truth to prefix: ', output_prefix) | |||||
if prog_args.graph_type == 'grid': | |||||
graphs = [] | |||||
for i in range(10,20): | |||||
for j in range(10,20): | |||||
graphs.append(nx.grid_2d_graph(i,j)) | |||||
utils.export_graphs_to_txt(graphs, output_prefix) | |||||
elif prog_args.graph_type == 'caveman': | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(30, 81): | |||||
for k in range(10): | |||||
graphs.append(caveman_special(i,j, p_edge=0.3)) | |||||
utils.export_graphs_to_txt(graphs, output_prefix) | |||||
elif prog_args.graph_type == 'citeseer': | |||||
graphs = utils.citeseer_ego() | |||||
utils.export_graphs_to_txt(graphs, output_prefix) | |||||
else: | |||||
# load from directory
# NOTE: assumes real_graph_filename has been defined for this dataset
input_path = dir_prefix + real_graph_filename
g_list = utils.load_graph_list(input_path) | |||||
utils.export_graphs_to_txt(g_list, output_prefix) | |||||
elif not prog_args.kron_dir == '': | |||||
kron_g_list = process_kron(prog_args.kron_dir) | |||||
fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + '.dat') | |||||
print([g.number_of_nodes() for g in kron_g_list]) | |||||
utils.save_graph_list(kron_g_list, fname) | |||||
elif not prog_args.test_file == '': | |||||
# evaluate single .dat file containing list of test graphs (networkx format) | |||||
graphs = utils.load_graph_list(prog_args.test_file) | |||||
eval_single_list(graphs, dir_input=dir_prefix+'graphs/', dataset_name='grid') | |||||
## if you are not evaluating the Kronecker baseline, only the following part is needed
else: | |||||
if not os.path.isdir(dir_prefix+'eval_results'): | |||||
os.makedirs(dir_prefix+'eval_results') | |||||
evaluation(args_evaluate,dir_input=dir_prefix+"graphs/", dir_output=dir_prefix+"eval_results/", | |||||
model_name_all=args_evaluate.model_name_all,dataset_name_all=args_evaluate.dataset_name_all,args=args,overwrite=True) | |||||
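# main training script: prepares the target graph dataset with create_graphs, builds the
# selected GraphRNN variant, and launches training via train().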
from train import * | |||||
if __name__ == '__main__': | |||||
# All necessary arguments are defined in args.py | |||||
args = Args() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) | |||||
print('CUDA', args.cuda) | |||||
print('File name prefix',args.fname) | |||||
# check if necessary directories exist | |||||
if not os.path.isdir(args.model_save_path): | |||||
os.makedirs(args.model_save_path) | |||||
if not os.path.isdir(args.graph_save_path): | |||||
os.makedirs(args.graph_save_path) | |||||
if not os.path.isdir(args.figure_save_path): | |||||
os.makedirs(args.figure_save_path) | |||||
if not os.path.isdir(args.timing_save_path): | |||||
os.makedirs(args.timing_save_path) | |||||
if not os.path.isdir(args.figure_prediction_save_path): | |||||
os.makedirs(args.figure_prediction_save_path) | |||||
if not os.path.isdir(args.nll_save_path): | |||||
os.makedirs(args.nll_save_path) | |||||
time = strftime("%Y-%m-%d %H:%M:%S", gmtime()) | |||||
# logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG) | |||||
if args.clean_tensorboard: | |||||
if os.path.isdir("tensorboard"): | |||||
shutil.rmtree("tensorboard") | |||||
configure("tensorboard/run"+time, flush_secs=5) | |||||
graphs = create_graphs.create(args) | |||||
# split datasets | |||||
random.seed(123) | |||||
shuffle(graphs) | |||||
graphs_len = len(graphs) | |||||
graphs_test = graphs[int(0.8 * graphs_len):] | |||||
graphs_train = graphs[0:int(0.8*graphs_len)] | |||||
graphs_validate = graphs[0:int(0.2*graphs_len)]  # note: the validation set is drawn from the training portion
# if use pre-saved graphs | |||||
# dir_input = "/dfs/scratch0/jiaxuany0/graphs/" | |||||
# fname_test = dir_input + args.note + '_' + args.graph_type + '_' + str(args.num_layers) + '_' + str( | |||||
# args.hidden_size_rnn) + '_test_' + str(0) + '.dat' | |||||
# graphs = load_graph_list(fname_test, is_real=True) | |||||
# graphs_test = graphs[int(0.8 * graphs_len):] | |||||
# graphs_train = graphs[0:int(0.8 * graphs_len)] | |||||
# graphs_validate = graphs[int(0.2 * graphs_len):int(0.4 * graphs_len)] | |||||
graph_validate_len = 0 | |||||
for graph in graphs_validate: | |||||
graph_validate_len += graph.number_of_nodes() | |||||
graph_validate_len /= len(graphs_validate) | |||||
print('graph_validate_len', graph_validate_len) | |||||
graph_test_len = 0 | |||||
for graph in graphs_test: | |||||
graph_test_len += graph.number_of_nodes() | |||||
graph_test_len /= len(graphs_test) | |||||
print('graph_test_len', graph_test_len) | |||||
args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) | |||||
max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))]) | |||||
min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))]) | |||||
# args.max_num_node = 2000 | |||||
# show graphs statistics | |||||
print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train))) | |||||
print('max number node: {}'.format(args.max_num_node)) | |||||
print('max/min number edge: {}; {}'.format(max_num_edge,min_num_edge)) | |||||
print('max previous node: {}'.format(args.max_prev_node)) | |||||
# save ground truth graphs | |||||
## To recover the train and test sets, slice the loaded list manually (the full graph list is saved under both filenames)
save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat') | |||||
save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat') | |||||
print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat') | |||||
### comment when normal training, for graph completion only | |||||
# p = 0.5 | |||||
# for graph in graphs_train: | |||||
# for node in list(graph.nodes()): | |||||
# # print('node',node) | |||||
# if np.random.rand()>p: | |||||
# graph.remove_node(node) | |||||
# for edge in list(graph.edges()): | |||||
# # print('edge',edge) | |||||
# if np.random.rand()>p: | |||||
# graph.remove_edge(edge[0],edge[1]) | |||||
### dataset initialization | |||||
if 'nobfs' in args.note: | |||||
print('nobfs') | |||||
dataset = Graph_sequence_sampler_pytorch_nobfs(graphs_train, max_num_node=args.max_num_node) | |||||
args.max_prev_node = args.max_num_node-1 | |||||
elif 'barabasi_noise' in args.graph_type:
print('barabasi_noise') | |||||
dataset = Graph_sequence_sampler_pytorch_canonical(graphs_train,max_prev_node=args.max_prev_node) | |||||
args.max_prev_node = args.max_num_node - 1 | |||||
else: | |||||
dataset = Graph_sequence_sampler_pytorch(graphs_train,max_prev_node=args.max_prev_node,max_num_node=args.max_num_node) | |||||
sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))], | |||||
num_samples=args.batch_size*args.batch_ratio, replacement=True) | |||||
dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, | |||||
sampler=sample_strategy) | |||||
### model initialization | |||||
## Graph RNN VAE model | |||||
# lstm = LSTM_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_lstm, | |||||
# hidden_size=args.hidden_size, num_layers=args.num_layers).cuda() | |||||
if 'GraphRNN_VAE_conditional' in args.note: | |||||
rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn, | |||||
hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True, | |||||
has_output=False).cuda() | |||||
output = MLP_VAE_conditional_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output, y_size=args.max_prev_node).cuda() | |||||
elif 'GraphRNN_MLP' in args.note: | |||||
rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn, | |||||
hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True, | |||||
has_output=False).cuda() | |||||
output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output, y_size=args.max_prev_node).cuda() | |||||
elif 'GraphRNN_RNN' in args.note: | |||||
rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn, | |||||
hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True, | |||||
has_output=True, output_size=args.hidden_size_rnn_output).cuda() | |||||
output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output, | |||||
hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True, | |||||
has_output=True, output_size=1).cuda() | |||||
### start training | |||||
train(args, dataset_loader, rnn, output) | |||||
### graph completion | |||||
# train_graph_completion(args,dataset_loader,rnn,output) | |||||
### nll evaluation | |||||
# train_nll(args, dataset_loader, dataset_loader, rnn, output, max_iter = 200, graph_validate_len=graph_validate_len,graph_test_len=graph_test_len) | |||||
# an implementation of "Learning Deep Generative Models of Graphs" (the DeepGMG baseline)
from main import * | |||||
class Args_DGMG(): | |||||
def __init__(self): | |||||
### CUDA | |||||
self.cuda = 2 | |||||
### model type | |||||
self.note = 'Baseline_DGMG' # do GCN after adding each edge | |||||
# self.note = 'Baseline_DGMG_fast' # do GCN only after adding each node | |||||
### data config | |||||
self.graph_type = 'caveman_small' | |||||
# self.graph_type = 'grid_small' | |||||
# self.graph_type = 'ladder_small' | |||||
# self.graph_type = 'enzymes_small' | |||||
# self.graph_type = 'barabasi_small' | |||||
# self.graph_type = 'citeseer_small' | |||||
self.max_num_node = 20 | |||||
### network config | |||||
self.node_embedding_size = 64 | |||||
self.test_graph_num = 200 | |||||
### training config | |||||
self.epochs = 2000 # one epoch = one pass over the training graph list (see train_DGMG_epoch)
self.load_epoch = 2000 | |||||
self.epochs_test_start = 100 | |||||
self.epochs_test = 100 | |||||
self.epochs_log = 100 | |||||
self.epochs_save = 100 | |||||
if 'fast' in self.note: | |||||
self.is_fast = True | |||||
else: | |||||
self.is_fast = False | |||||
self.lr = 0.001 | |||||
self.milestones = [300, 600, 1000] | |||||
self.lr_rate = 0.3 | |||||
### output config | |||||
self.model_save_path = 'model_save/' | |||||
self.graph_save_path = 'graphs/' | |||||
self.figure_save_path = 'figures/' | |||||
self.timing_save_path = 'timing/' | |||||
self.figure_prediction_save_path = 'figures_prediction/' | |||||
self.nll_save_path = 'nll/' | |||||
self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) | |||||
self.fname_pred = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_pred_' | |||||
self.fname_train = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_train_' | |||||
self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_test_' | |||||
self.load = False | |||||
self.save = True | |||||
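# DeepGMG training: for every graph (under a random node ordering), nodes are added one at
# a time; f_an decides whether to add a node, f_ae whether to attach an edge to it, and f_s
# which existing node the new edge connects to. Each decision contributes a cross-entropy
# term to the per-graph loss.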
def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast = False): | |||||
model.train() | |||||
graph_num = len(dataset) | |||||
order = list(range(graph_num)) | |||||
shuffle(order) | |||||
loss_addnode = 0 | |||||
loss_addedge = 0 | |||||
loss_node = 0 | |||||
for i in order: | |||||
model.zero_grad() | |||||
graph = dataset[i] | |||||
# do random ordering: relabel nodes | |||||
node_order = list(range(graph.number_of_nodes())) | |||||
shuffle(node_order) | |||||
order_mapping = dict(zip(graph.nodes(), node_order)) | |||||
graph = nx.relabel_nodes(graph, order_mapping, copy=True) | |||||
# NOTE: when starting loop, we assume a node has already been generated | |||||
node_count = 1 | |||||
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
loss = 0 | |||||
while node_count<=graph.number_of_nodes(): | |||||
node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list() # list of lists (first node is zero) | |||||
node_neighbor_new = graph.subgraph(list(range(node_count+1))).adjacency_list()[-1] # list of new node's neighbors | |||||
# 1 message passing | |||||
# do 2 times message passing | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
# 2 graph embedding and new node embedding | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
# 3 f_addnode | |||||
p_addnode = model.f_an(graph_embedding) | |||||
if node_count < graph.number_of_nodes(): | |||||
# add node | |||||
node_neighbor.append([]) | |||||
node_embedding.append(init_embedding) | |||||
if is_fast: | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
# calc loss | |||||
loss_addnode_step = F.binary_cross_entropy(p_addnode,Variable(torch.ones((1,1))).cuda()) | |||||
# loss_addnode_step.backward(retain_graph=True) | |||||
loss += loss_addnode_step | |||||
loss_addnode += loss_addnode_step.data | |||||
else: | |||||
# calc loss | |||||
loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda()) | |||||
# loss_addnode_step.backward(retain_graph=True) | |||||
loss += loss_addnode_step | |||||
loss_addnode += loss_addnode_step.data | |||||
break | |||||
edge_count = 0 | |||||
while edge_count<=len(node_neighbor_new): | |||||
if not is_fast: | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
# 4 f_addedge | |||||
p_addedge = model.f_ae(graph_embedding) | |||||
if edge_count < len(node_neighbor_new): | |||||
# calc loss | |||||
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda()) | |||||
# loss_addedge_step.backward(retain_graph=True) | |||||
loss += loss_addedge_step | |||||
loss_addedge += loss_addedge_step.data | |||||
# 5 f_nodes | |||||
# excluding the last node (which is the new node) | |||||
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1)) | |||||
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1)) | |||||
p_node = F.softmax(s_node.permute(1,0)) | |||||
# get ground truth | |||||
a_node = torch.zeros((1,p_node.size(1))) | |||||
# print('node_neighbor_new',node_neighbor_new, edge_count) | |||||
a_node[0,node_neighbor_new[edge_count]] = 1 | |||||
a_node = Variable(a_node).cuda() | |||||
# add edge | |||||
node_neighbor[-1].append(node_neighbor_new[edge_count]) | |||||
node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor)-1) | |||||
# calc loss | |||||
loss_node_step = F.binary_cross_entropy(p_node,a_node) | |||||
# loss_node_step.backward(retain_graph=True) | |||||
loss += loss_node_step | |||||
loss_node += loss_node_step.data | |||||
else: | |||||
# calc loss | |||||
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda()) | |||||
# loss_addedge_step.backward(retain_graph=True) | |||||
loss += loss_addedge_step | |||||
loss_addedge += loss_addedge_step.data | |||||
break | |||||
edge_count += 1 | |||||
node_count += 1 | |||||
# update deterministic and lstm | |||||
loss.backward() | |||||
optimizer.step() | |||||
scheduler.step() | |||||
loss_all = loss_addnode + loss_addedge + loss_node | |||||
if epoch % args.epochs_log==0: | |||||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss_all[0], args.graph_type, args.node_embedding_size)) | |||||
# loss_sum += loss.data[0]*x.size(0) | |||||
# return loss_sum | |||||
def train_DGMG_forward_epoch(args, model, dataset, is_fast = False): | |||||
model.train() | |||||
graph_num = len(dataset) | |||||
order = list(range(graph_num)) | |||||
shuffle(order) | |||||
loss_addnode = 0 | |||||
loss_addedge = 0 | |||||
loss_node = 0 | |||||
for i in order: | |||||
model.zero_grad() | |||||
graph = dataset[i] | |||||
# do random ordering: relabel nodes | |||||
node_order = list(range(graph.number_of_nodes())) | |||||
shuffle(node_order) | |||||
order_mapping = dict(zip(graph.nodes(), node_order)) | |||||
graph = nx.relabel_nodes(graph, order_mapping, copy=True) | |||||
# NOTE: when starting loop, we assume a node has already been generated | |||||
node_count = 1 | |||||
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
loss = 0 | |||||
while node_count<=graph.number_of_nodes(): | |||||
node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list() # list of lists (first node is zero) | |||||
node_neighbor_new = graph.subgraph(list(range(node_count+1))).adjacency_list()[-1] # list of new node's neighbors | |||||
# 1 message passing | |||||
# do 2 times message passing | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
# 2 graph embedding and new node embedding | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
# 3 f_addnode | |||||
p_addnode = model.f_an(graph_embedding) | |||||
if node_count < graph.number_of_nodes(): | |||||
# add node | |||||
node_neighbor.append([]) | |||||
node_embedding.append(init_embedding) | |||||
if is_fast: | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
# calc loss | |||||
loss_addnode_step = F.binary_cross_entropy(p_addnode,Variable(torch.ones((1,1))).cuda()) | |||||
# loss_addnode_step.backward(retain_graph=True) | |||||
loss += loss_addnode_step | |||||
loss_addnode += loss_addnode_step.data | |||||
else: | |||||
# calc loss | |||||
loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda()) | |||||
# loss_addnode_step.backward(retain_graph=True) | |||||
loss += loss_addnode_step | |||||
loss_addnode += loss_addnode_step.data | |||||
break | |||||
edge_count = 0 | |||||
while edge_count<=len(node_neighbor_new): | |||||
if not is_fast: | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
# 4 f_addedge | |||||
p_addedge = model.f_ae(graph_embedding) | |||||
if edge_count < len(node_neighbor_new): | |||||
# calc loss | |||||
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda()) | |||||
# loss_addedge_step.backward(retain_graph=True) | |||||
loss += loss_addedge_step | |||||
loss_addedge += loss_addedge_step.data | |||||
# 5 f_nodes | |||||
# excluding the last node (which is the new node) | |||||
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1)) | |||||
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1)) | |||||
p_node = F.softmax(s_node.permute(1,0)) | |||||
# get ground truth | |||||
a_node = torch.zeros((1,p_node.size(1))) | |||||
# print('node_neighbor_new',node_neighbor_new, edge_count) | |||||
a_node[0,node_neighbor_new[edge_count]] = 1 | |||||
a_node = Variable(a_node).cuda() | |||||
# add edge | |||||
node_neighbor[-1].append(node_neighbor_new[edge_count]) | |||||
node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor)-1) | |||||
# calc loss | |||||
loss_node_step = F.binary_cross_entropy(p_node,a_node) | |||||
# loss_node_step.backward(retain_graph=True) | |||||
loss += loss_node_step | |||||
loss_node += loss_node_step.data*p_node.size(1) | |||||
else: | |||||
# calc loss | |||||
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda()) | |||||
# loss_addedge_step.backward(retain_graph=True) | |||||
loss += loss_addedge_step | |||||
loss_addedge += loss_addedge_step.data | |||||
break | |||||
edge_count += 1 | |||||
node_count += 1 | |||||
loss_all = loss_addnode + loss_addedge + loss_node | |||||
# if epoch % args.epochs_log==0: | |||||
# print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format( | |||||
# epoch, args.epochs,loss_all[0], args.graph_type, args.node_embedding_size)) | |||||
return loss_all[0]/len(dataset) | |||||
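# test_DGMG_epoch samples args.test_graph_num graphs from the trained model by running the
# same add-node / add-edge decisions generatively (sampling from f_an, f_ae and the
# node-scoring network f_s) until the model stops or max_num_node is reached.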
def test_DGMG_epoch(args, model, is_fast=False): | |||||
model.eval() | |||||
graph_num = args.test_graph_num | |||||
graphs_generated = [] | |||||
for i in range(graph_num): | |||||
# NOTE: when starting loop, we assume a node has already been generated | |||||
node_neighbor = [[]] # list of lists (first node is zero) | |||||
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden | |||||
node_count = 1 | |||||
while node_count<=args.max_num_node: | |||||
# 1 message passing | |||||
# do 2 times message passing | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
# 2 graph embedding and new node embedding | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
init_embedding = calc_init_embedding(node_embedding_cat, model) | |||||
# 3 f_addnode | |||||
p_addnode = model.f_an(graph_embedding) | |||||
a_addnode = sample_tensor(p_addnode) | |||||
# print(a_addnode.data[0][0]) | |||||
if a_addnode.data[0][0]==1: | |||||
# print('add node') | |||||
# add node | |||||
node_neighbor.append([]) | |||||
node_embedding.append(init_embedding) | |||||
if is_fast: | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
else: | |||||
break | |||||
edge_count = 0 | |||||
while edge_count<args.max_num_node: | |||||
if not is_fast: | |||||
node_embedding = message_passing(node_neighbor, node_embedding, model) | |||||
node_embedding_cat = torch.cat(node_embedding, dim=0) | |||||
graph_embedding = calc_graph_embedding(node_embedding_cat, model) | |||||
# 4 f_addedge | |||||
p_addedge = model.f_ae(graph_embedding) | |||||
a_addedge = sample_tensor(p_addedge) | |||||
# print(a_addedge.data[0][0]) | |||||
if a_addedge.data[0][0]==1: | |||||
# print('add edge') | |||||
# 5 f_nodes | |||||
# excluding the last node (which is the new node) | |||||
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1)) | |||||
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1)) | |||||
p_node = F.softmax(s_node.permute(1,0)) | |||||
a_node = gumbel_softmax(p_node, temperature=0.01) | |||||
_, a_node_id = a_node.topk(1) | |||||
a_node_id = int(a_node_id.data[0][0]) | |||||
# add edge | |||||
node_neighbor[-1].append(a_node_id) | |||||
node_neighbor[a_node_id].append(len(node_neighbor)-1) | |||||
else: | |||||
break | |||||
edge_count += 1 | |||||
node_count += 1 | |||||
# save graph | |||||
node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor)) | |||||
graph = nx.from_dict_of_lists(node_neighbor_dict) | |||||
graphs_generated.append(graph) | |||||
return graphs_generated | |||||
########### train function for DeepGMG
def train_DGMG(args, dataset_train, model): | |||||
# check if load existing model | |||||
if args.load: | |||||
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | |||||
model.load_state_dict(torch.load(fname)) | |||||
args.lr = 0.00001 | |||||
epoch = args.load_epoch | |||||
print('model loaded! lr: {}'.format(args.lr))
else: | |||||
epoch = 1 | |||||
# initialize optimizer | |||||
optimizer = optim.Adam(list(model.parameters()), lr=args.lr) | |||||
scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_rate) | |||||
# start main loop | |||||
time_all = np.zeros(args.epochs) | |||||
while epoch <= args.epochs: | |||||
time_start = tm.time() | |||||
# train | |||||
train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast) | |||||
time_end = tm.time() | |||||
time_all[epoch - 1] = time_end - time_start | |||||
# print('time used',time_all[epoch - 1]) | |||||
# test | |||||
if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start: | |||||
graphs = test_DGMG_epoch(args,model, is_fast=args.is_fast) | |||||
fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat' | |||||
save_graph_list(graphs, fname) | |||||
# print('test done, graphs saved') | |||||
# save model checkpoint | |||||
if args.save: | |||||
if epoch % args.epochs_save == 0: | |||||
fname = args.model_save_path + args.fname + 'model_' + str(epoch) + '.dat' | |||||
torch.save(model.state_dict(), fname) | |||||
epoch += 1 | |||||
np.save(args.timing_save_path + args.fname, time_all) | |||||
########### NLL evaluation for DeepGMG
def train_DGMG_nll(args, dataset_train,dataset_test, model,max_iter=1000): | |||||
# check if load existing model | |||||
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' | |||||
model.load_state_dict(torch.load(fname)) | |||||
fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv' | |||||
with open(fname_output, 'w+') as f: | |||||
f.write('train,test\n') | |||||
# start main loop | |||||
for iter in range(max_iter): | |||||
nll_train = train_DGMG_forward_epoch(args, model, dataset_train, is_fast=args.is_fast) | |||||
nll_test = train_DGMG_forward_epoch(args, model, dataset_test, is_fast=args.is_fast) | |||||
print('train', nll_train, 'test', nll_test) | |||||
f.write(str(nll_train) + ',' + str(nll_test) + '\n') | |||||
if __name__ == '__main__': | |||||
args = Args_DGMG() | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) | |||||
print('CUDA', args.cuda) | |||||
print('File name prefix',args.fname) | |||||
graphs = [] | |||||
for i in range(4, 10): | |||||
graphs.append(nx.ladder_graph(i)) | |||||
model = DGM_graphs(h_size = args.node_embedding_size).cuda() | |||||
if args.graph_type == 'ladder_small': | |||||
graphs = [] | |||||
for i in range(2, 11): | |||||
graphs.append(nx.ladder_graph(i)) | |||||
args.max_prev_node = 10 | |||||
# if args.graph_type == 'caveman_small': | |||||
# graphs = [] | |||||
# for i in range(2, 5): | |||||
# for j in range(2, 6): | |||||
# for k in range(10): | |||||
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1)) | |||||
# args.max_prev_node = 20 | |||||
if args.graph_type=='caveman_small': | |||||
graphs = [] | |||||
for i in range(2, 3): | |||||
for j in range(6, 11): | |||||
for k in range(20): | |||||
graphs.append(caveman_special(i, j, p_edge=0.8)) | |||||
args.max_prev_node = 20 | |||||
if args.graph_type == 'grid_small': | |||||
graphs = [] | |||||
for i in range(2, 5): | |||||
for j in range(2, 6): | |||||
graphs.append(nx.grid_2d_graph(i, j)) | |||||
args.max_prev_node = 15 | |||||
if args.graph_type == 'barabasi_small': | |||||
graphs = [] | |||||
for i in range(4, 21): | |||||
for j in range(3, 4): | |||||
for k in range(10): | |||||
graphs.append(nx.barabasi_albert_graph(i, j)) | |||||
args.max_prev_node = 20 | |||||
if args.graph_type == 'enzymes_small': | |||||
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES') | |||||
graphs = [] | |||||
for G in graphs_raw: | |||||
if G.number_of_nodes()<=20: | |||||
graphs.append(G) | |||||
args.max_prev_node = 15 | |||||
if args.graph_type == 'citeseer_small': | |||||
_, _, G = Graph_load(dataset='citeseer') | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
G = nx.convert_node_labels_to_integers(G) | |||||
graphs = [] | |||||
for i in range(G.number_of_nodes()): | |||||
G_ego = nx.ego_graph(G, i, radius=1) | |||||
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): | |||||
graphs.append(G_ego) | |||||
shuffle(graphs) | |||||
graphs = graphs[0:200] | |||||
args.max_prev_node = 15 | |||||
# remove self loops | |||||
for graph in graphs: | |||||
edges_with_selfloops = graph.selfloop_edges() | |||||
if len(edges_with_selfloops) > 0: | |||||
graph.remove_edges_from(edges_with_selfloops) | |||||
# split datasets | |||||
random.seed(123) | |||||
shuffle(graphs) | |||||
graphs_len = len(graphs) | |||||
graphs_test = graphs[int(0.8 * graphs_len):] | |||||
graphs_train = graphs[0:int(0.8 * graphs_len)] | |||||
args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) | |||||
# args.max_num_node = 2000 | |||||
# show graphs statistics | |||||
print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train))) | |||||
print('max number node: {}'.format(args.max_num_node)) | |||||
print('max previous node: {}'.format(args.max_prev_node)) | |||||
# save ground truth graphs | |||||
# save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat') | |||||
# save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat') | |||||
# print('train and test graphs saved') | |||||
## if use pre-saved graphs | |||||
# dir_input = "graphs/" | |||||
# fname_test = args.graph_save_path + args.fname_test + '0.dat' | |||||
# graphs = load_graph_list(fname_test, is_real=True) | |||||
# graphs_test = graphs[int(0.8 * graphs_len):] | |||||
# graphs_train = graphs[0:int(0.8 * graphs_len)] | |||||
# graphs_validate = graphs[0:int(0.2 * graphs_len)] | |||||
# print('train') | |||||
# for graph in graphs_validate: | |||||
# print(graph.number_of_nodes()) | |||||
# print('test') | |||||
# for graph in graphs_test: | |||||
# print(graph.number_of_nodes()) | |||||
### train | |||||
train_DGMG(args,graphs,model) | |||||
### calc nll | |||||
# train_DGMG_nll(args, graphs_validate,graphs_test, model,max_iter=1000) | |||||
# for j in range(1000): | |||||
# graph = graphs[0] | |||||
# # do random ordering: relabel nodes | |||||
# node_order = list(range(graph.number_of_nodes())) | |||||
# shuffle(node_order) | |||||
# order_mapping = dict(zip(graph.nodes(), node_order)) | |||||
# graph = nx.relabel_nodes(graph, order_mapping, copy=True) | |||||
# print(graph.nodes()) |
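# plotting script for the robustness experiment: degree and clustering MMD versus noise
# level for E-R, B-A and GraphRNN, with the values hard-coded below; figures are written
# to figures_paper/.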
import numpy as np | |||||
import matplotlib as mpl | |||||
import matplotlib.pyplot as plt | |||||
import seaborn as sns | |||||
sns.set() | |||||
sns.set_style("ticks") | |||||
sns.set_context("poster",font_scale=1.28,rc={"lines.linewidth": 3}) | |||||
### plot robustness result | |||||
noise = np.array([0,0.2,0.4,0.6,0.8,1.0]) | |||||
MLP_degree = np.array([0.3440, 0.1365, 0.0663, 0.0430, 0.0214, 0.0201]) | |||||
RNN_degree = np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) | |||||
BA_degree = np.array([0.0892,0.3558,1.1754,1.5914,1.7037,1.7502]) | |||||
Gnp_degree = np.array([1.7115,1.5536,0.5529,0.1433,0.0725,0.0503]) | |||||
MLP_clustering = np.array([0.0096, 0.0056, 0.0027, 0.0020, 0.0012, 0.0028]) | |||||
RNN_clustering = np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) | |||||
BA_clustering = np.array([0.0255,0.0881,0.3433,0.4237,0.6041,0.7851]) | |||||
Gnp_clustering = np.array([0.7683,0.1849,0.1081,0.0146,0.0210,0.0329]) | |||||
plt.plot(noise,Gnp_degree) | |||||
plt.plot(noise,BA_degree) | |||||
plt.plot(noise, MLP_degree) | |||||
# plt.plot(noise, RNN_degree) | |||||
# plt.rc('text', usetex=True) | |||||
plt.legend(['E-R','B-A','GraphRNN']) | |||||
plt.xlabel('Noise level') | |||||
plt.ylabel('MMD degree') | |||||
plt.tight_layout() | |||||
plt.savefig('figures_paper/robustness_degree.png',dpi=300) | |||||
plt.close() | |||||
plt.plot(noise,Gnp_clustering) | |||||
plt.plot(noise,BA_clustering) | |||||
plt.plot(noise, MLP_clustering) | |||||
# plt.plot(noise, RNN_clustering) | |||||
plt.legend(['E-R','B-A','GraphRNN']) | |||||
plt.xlabel('Noise level') | |||||
plt.ylabel('MMD clustering') | |||||
plt.tight_layout() | |||||
plt.savefig('figures_paper/robustness_clustering.png',dpi=300) | |||||
plt.close() | |||||
tensorboard-logger | |||||
tensorflow | |||||
networkx==1.11 | |||||
pyemd |
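# small standalone script that exercises the Gaussian-kernel MMD estimator on random 1-D
# samples (a random baseline vs. an all-zeros "prediction").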
import torch | |||||
import numpy as np | |||||
import time | |||||
def compute_kernel(x,y): | |||||
x_size = x.size(0) | |||||
y_size = y.size(0) | |||||
dim = x.size(1) | |||||
x_tile = x.view(x_size,1,dim) | |||||
x_tile = x_tile.repeat(1,y_size,1) | |||||
y_tile = y.view(1,y_size,dim) | |||||
y_tile = y_tile.repeat(x_size,1,1) | |||||
return torch.exp(-torch.mean((x_tile-y_tile)**2,dim = 2)/float(dim)) | |||||
def compute_mmd(x,y): | |||||
x_kernel = compute_kernel(x,x) | |||||
# print(x_kernel) | |||||
y_kernel = compute_kernel(y,y) | |||||
# print(y_kernel) | |||||
xy_kernel = compute_kernel(x,y) | |||||
# print(xy_kernel) | |||||
return torch.mean(x_kernel)+torch.mean(y_kernel)-2*torch.mean(xy_kernel) | |||||
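# compute_mmd implements the (biased) estimator of the squared Maximum Mean Discrepancy,
#   MMD^2(X, Y) = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)],
# where the Gaussian kernel is realized above as exp(-mean((x - y)^2) / dim), i.e. a fixed
# bandwidth tied to the feature dimension.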
# start = time.time() | |||||
# x = torch.randn(4000,1).cuda() | |||||
# y = torch.randn(4000,1).cuda() | |||||
# print(compute_mmd(x,y)) | |||||
# end = time.time() | |||||
# print('GPU time:', end-start) | |||||
start = time.time() | |||||
torch.manual_seed(123) | |||||
batch = 1000 | |||||
x = torch.randn(batch,1) | |||||
y_baseline = torch.randn(batch,1) | |||||
y_pred = torch.zeros(batch,1) | |||||
print('MMD baseline', compute_mmd(x,y_baseline)) | |||||
print('MMD prediction', compute_mmd(x,y_pred)) | |||||
# | |||||
# print('before',x) | |||||
# print('MMD', compute_mmd(x,y)) | |||||
# x_idx = np.random.permutation(x.size(0)) | |||||
# x = x[x_idx,:] | |||||
# print('after permutation',x) | |||||
# print('MMD', compute_mmd(x,y)) | |||||
# | |||||
# | |||||
# end = time.time() | |||||
# print('CPU time:', end-start) |
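# training and sampling routines for the GraphRNN models (imported by main.py via
# `from train import *`); each output head below has its own train_*_epoch / test_*_epoch pair.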
import networkx as nx | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.init as init | |||||
from torch.autograd import Variable | |||||
import matplotlib.pyplot as plt | |||||
import torch.nn.functional as F | |||||
from torch import optim | |||||
from torch.optim.lr_scheduler import MultiStepLR | |||||
from sklearn.decomposition import PCA | |||||
import logging | |||||
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence | |||||
from time import gmtime, strftime | |||||
from sklearn.metrics import roc_curve | |||||
from sklearn.metrics import roc_auc_score | |||||
from sklearn.metrics import average_precision_score | |||||
from random import shuffle | |||||
import pickle | |||||
from tensorboard_logger import configure, log_value | |||||
import scipy.misc | |||||
import time as tm | |||||
from utils import * | |||||
from model import * | |||||
from data import * | |||||
from args import Args | |||||
import create_graphs | |||||
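# The train_*_epoch functions below share the same skeleton: truncate the padded batch to
# the longest sequence, sort sequences by length so they can be packed for the RNN, run the
# forward pass with teacher forcing, and minimize a weighted binary cross-entropy over the
# predicted adjacency rows (the VAE variant adds a KL term).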
def train_vae_epoch(epoch, args, rnn, output, data_loader, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output): | |||||
rnn.train() | |||||
output.train() | |||||
loss_sum = 0 | |||||
for batch_idx, data in enumerate(data_loader): | |||||
rnn.zero_grad() | |||||
output.zero_grad() | |||||
x_unsorted = data['x'].float() | |||||
y_unsorted = data['y'].float() | |||||
y_len_unsorted = data['len'] | |||||
y_len_max = max(y_len_unsorted) | |||||
x_unsorted = x_unsorted[:, 0:y_len_max, :] | |||||
y_unsorted = y_unsorted[:, 0:y_len_max, :] | |||||
# initialize lstm hidden state according to batch size | |||||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||||
# sort input | |||||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||||
y_len = y_len.numpy().tolist() | |||||
x = torch.index_select(x_unsorted,0,sort_index) | |||||
y = torch.index_select(y_unsorted,0,sort_index) | |||||
x = Variable(x).cuda() | |||||
y = Variable(y).cuda() | |||||
# if using ground truth to train | |||||
h = rnn(x, pack=True, input_len=y_len) | |||||
y_pred,z_mu,z_lsgms = output(h) | |||||
y_pred = F.sigmoid(y_pred) | |||||
# clean | |||||
y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) | |||||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||||
z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True) | |||||
z_mu = pad_packed_sequence(z_mu, batch_first=True)[0] | |||||
z_lsgms = pack_padded_sequence(z_lsgms, y_len, batch_first=True) | |||||
z_lsgms = pad_packed_sequence(z_lsgms, batch_first=True)[0] | |||||
# use cross entropy loss | |||||
loss_bce = binary_cross_entropy_weight(y_pred, y) | |||||
loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp()) | |||||
loss_kl /= y.size(0)*y.size(1)*sum(y_len) # normalize | |||||
loss = loss_bce + loss_kl | |||||
loss.backward() | |||||
# update deterministic and lstm | |||||
optimizer_output.step() | |||||
optimizer_rnn.step() | |||||
scheduler_output.step() | |||||
scheduler_rnn.step() | |||||
z_mu_mean = torch.mean(z_mu.data) | |||||
z_sgm_mean = torch.mean(z_lsgms.mul(0.5).exp_().data) | |||||
z_mu_min = torch.min(z_mu.data) | |||||
z_sgm_min = torch.min(z_lsgms.mul(0.5).exp_().data) | |||||
z_mu_max = torch.max(z_mu.data) | |||||
z_sgm_max = torch.max(z_lsgms.mul(0.5).exp_().data) | |||||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||||
print('Epoch: {}/{}, train bce loss: {:.6f}, train kl loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss_bce.data[0], loss_kl.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||||
print('z_mu_mean', z_mu_mean, 'z_mu_min', z_mu_min, 'z_mu_max', z_mu_max, 'z_sgm_mean', z_sgm_mean, 'z_sgm_min', z_sgm_min, 'z_sgm_max', z_sgm_max) | |||||
# logging | |||||
log_value('bce_loss_'+args.fname, loss_bce.data[0], epoch*args.batch_ratio+batch_idx) | |||||
log_value('kl_loss_' +args.fname, loss_kl.data[0], epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_mu_mean_'+args.fname, z_mu_mean, epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_mu_min_'+args.fname, z_mu_min, epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_mu_max_'+args.fname, z_mu_max, epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_sgm_mean_'+args.fname, z_sgm_mean, epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_sgm_min_'+args.fname, z_sgm_min, epoch*args.batch_ratio + batch_idx) | |||||
log_value('z_sgm_max_'+args.fname, z_sgm_max, epoch*args.batch_ratio + batch_idx) | |||||
loss_sum += loss.data[0] | |||||
return loss_sum/(batch_idx+1) | |||||
def test_vae_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False, sample_time = 1): | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
rnn.eval() | |||||
output.eval() | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
h = rnn(x_step) | |||||
y_pred_step, _, _ = output(h) | |||||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time) | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_data = y_pred.data | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
G_pred_list = [] | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
# save prediction histograms, plot histogram over each time step | |||||
# if save_histogram: | |||||
# save_prediction_histogram(y_pred_data.cpu().numpy(), | |||||
# fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg', | |||||
# max_num_node=max_num_node) | |||||
return G_pred_list | |||||
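# test_vae_partial_epoch samples graphs while conditioning on ground-truth rows through
# sample_sigmoid_supervised; this appears to be the generation path used by the
# graph-completion experiment (see the commented-out train_graph_completion call in main.py).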
def test_vae_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||||
rnn.eval() | |||||
output.eval() | |||||
G_pred_list = [] | |||||
for batch_idx, data in enumerate(data_loader): | |||||
x = data['x'].float() | |||||
y = data['y'].float() | |||||
y_len = data['len'] | |||||
test_batch_size = x.size(0) | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
            print('generating node', i)
h = rnn(x_step) | |||||
y_pred_step, _, _ = output(h) | |||||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_data = y_pred.data | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
return G_pred_list | |||||
def train_mlp_epoch(epoch, args, rnn, output, data_loader, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output): | |||||
rnn.train() | |||||
output.train() | |||||
loss_sum = 0 | |||||
for batch_idx, data in enumerate(data_loader): | |||||
rnn.zero_grad() | |||||
output.zero_grad() | |||||
x_unsorted = data['x'].float() | |||||
y_unsorted = data['y'].float() | |||||
y_len_unsorted = data['len'] | |||||
y_len_max = max(y_len_unsorted) | |||||
x_unsorted = x_unsorted[:, 0:y_len_max, :] | |||||
y_unsorted = y_unsorted[:, 0:y_len_max, :] | |||||
# initialize lstm hidden state according to batch size | |||||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||||
# sort input | |||||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||||
y_len = y_len.numpy().tolist() | |||||
x = torch.index_select(x_unsorted,0,sort_index) | |||||
y = torch.index_select(y_unsorted,0,sort_index) | |||||
x = Variable(x).cuda() | |||||
y = Variable(y).cuda() | |||||
h = rnn(x, pack=True, input_len=y_len) | |||||
y_pred = output(h) | |||||
y_pred = F.sigmoid(y_pred) | |||||
# clean | |||||
y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) | |||||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||||
# use cross entropy loss | |||||
loss = binary_cross_entropy_weight(y_pred, y) | |||||
loss.backward() | |||||
# update deterministic and lstm | |||||
optimizer_output.step() | |||||
optimizer_rnn.step() | |||||
scheduler_output.step() | |||||
scheduler_rnn.step() | |||||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||||
# logging | |||||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||||
loss_sum += loss.data[0] | |||||
return loss_sum/(batch_idx+1) | |||||
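# Minimal sketch of the sort-then-pack pattern used in the training loops above
# (hypothetical helper, defined only for illustration and never called):
# pack_padded_sequence requires sequences sorted by length in descending order,
# so batches are reordered with torch.sort / index_select, and padded positions
# are masked out of the loss by packing and re-padding y_pred.
def _pack_sort_demo():
    lengths = torch.LongTensor([2, 3, 1])                # unsorted sequence lengths
    x = torch.rand(3, 3, 4)                              # (batch, max_len, feature)
    sorted_len, sort_index = torch.sort(lengths, 0, descending=True)
    x = torch.index_select(x, 0, sort_index)             # reorder batch to match
    packed = pack_padded_sequence(Variable(x), sorted_len.numpy().tolist(), batch_first=True)
    padded, _ = pad_packed_sequence(packed, batch_first=True)
    return padded                                        # padding positions are zero again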
def test_mlp_epoch(epoch, args, rnn, output, test_batch_size=16, save_histogram=False,sample_time=1): | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
rnn.eval() | |||||
output.eval() | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
h = rnn(x_step) | |||||
y_pred_step = output(h) | |||||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
x_step = sample_sigmoid(y_pred_step, sample=True, sample_time=sample_time) | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_data = y_pred.data | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
G_pred_list = [] | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
# # save prediction histograms, plot histogram over each time step | |||||
# if save_histogram: | |||||
# save_prediction_histogram(y_pred_data.cpu().numpy(), | |||||
# fname_pred=args.figure_prediction_save_path+args.fname_pred+str(epoch)+'.jpg', | |||||
# max_num_node=max_num_node) | |||||
return G_pred_list | |||||
def test_mlp_partial_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||||
rnn.eval() | |||||
output.eval() | |||||
G_pred_list = [] | |||||
for batch_idx, data in enumerate(data_loader): | |||||
x = data['x'].float() | |||||
y = data['y'].float() | |||||
y_len = data['len'] | |||||
test_batch_size = x.size(0) | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
            print('generating node', i)
h = rnn(x_step) | |||||
y_pred_step = output(h) | |||||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_data = y_pred.data | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
return G_pred_list | |||||
def test_mlp_partial_simple_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||||
rnn.eval() | |||||
output.eval() | |||||
G_pred_list = [] | |||||
for batch_idx, data in enumerate(data_loader): | |||||
x = data['x'].float() | |||||
y = data['y'].float() | |||||
y_len = data['len'] | |||||
test_batch_size = x.size(0) | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
            print('generating node', i)
h = rnn(x_step) | |||||
y_pred_step = output(h) | |||||
y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
x_step = sample_sigmoid_supervised_simple(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_data = y_pred.data | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
return G_pred_list | |||||
def train_mlp_forward_epoch(epoch, args, rnn, output, data_loader): | |||||
rnn.train() | |||||
output.train() | |||||
loss_sum = 0 | |||||
for batch_idx, data in enumerate(data_loader): | |||||
rnn.zero_grad() | |||||
output.zero_grad() | |||||
x_unsorted = data['x'].float() | |||||
y_unsorted = data['y'].float() | |||||
y_len_unsorted = data['len'] | |||||
y_len_max = max(y_len_unsorted) | |||||
x_unsorted = x_unsorted[:, 0:y_len_max, :] | |||||
y_unsorted = y_unsorted[:, 0:y_len_max, :] | |||||
# initialize lstm hidden state according to batch size | |||||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||||
# sort input | |||||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||||
y_len = y_len.numpy().tolist() | |||||
x = torch.index_select(x_unsorted,0,sort_index) | |||||
y = torch.index_select(y_unsorted,0,sort_index) | |||||
x = Variable(x).cuda() | |||||
y = Variable(y).cuda() | |||||
h = rnn(x, pack=True, input_len=y_len) | |||||
y_pred = output(h) | |||||
y_pred = F.sigmoid(y_pred) | |||||
# clean | |||||
y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) | |||||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||||
# use cross entropy loss | |||||
loss = 0 | |||||
for j in range(y.size(1)): | |||||
# print('y_pred',y_pred[0,j,:],'y',y[0,j,:]) | |||||
end_idx = min(j+1,y.size(2)) | |||||
loss += binary_cross_entropy_weight(y_pred[:,j,0:end_idx], y[:,j,0:end_idx])*end_idx | |||||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||||
# logging | |||||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||||
loss_sum += loss.data[0] | |||||
return loss_sum/(batch_idx+1) | |||||
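# Note on the weighting above (assuming binary_cross_entropy_weight averages over
# the entries it is given): row j of y only has min(j+1, max_prev_node) valid
# adjacency entries, so multiplying the per-row mean BCE by end_idx converts it
# back into a sum over exactly those entries, i.e. an (unnormalized) negative
# log-likelihood of the observed rows, which is what train_nll accumulates.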
## too complicated, deprecated | |||||
# def test_mlp_partial_bfs_epoch(epoch, args, rnn, output, data_loader, save_histogram=False,sample_time=1): | |||||
# rnn.eval() | |||||
# output.eval() | |||||
# G_pred_list = [] | |||||
# for batch_idx, data in enumerate(data_loader): | |||||
# x = data['x'].float() | |||||
# y = data['y'].float() | |||||
# y_len = data['len'] | |||||
# test_batch_size = x.size(0) | |||||
# rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
# # generate graphs | |||||
# max_num_node = int(args.max_num_node) | |||||
# y_pred = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # normalized prediction score | |||||
# y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
# x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
# for i in range(max_num_node): | |||||
# # 1 back up hidden state | |||||
# hidden_prev = Variable(rnn.hidden.data).cuda() | |||||
# h = rnn(x_step) | |||||
# y_pred_step = output(h) | |||||
# y_pred[:, i:i + 1, :] = F.sigmoid(y_pred_step) | |||||
# x_step = sample_sigmoid_supervised(y_pred_step, y[:,i:i+1,:].cuda(), current=i, y_len=y_len, sample_time=sample_time) | |||||
# y_pred_long[:, i:i + 1, :] = x_step | |||||
# | |||||
# rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
# | |||||
# print('finish node', i) | |||||
# y_pred_data = y_pred.data | |||||
# y_pred_long_data = y_pred_long.data.long() | |||||
# | |||||
# # save graphs as pickle | |||||
# for i in range(test_batch_size): | |||||
# adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
# G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
# G_pred_list.append(G_pred) | |||||
# return G_pred_list | |||||
def train_rnn_epoch(epoch, args, rnn, output, data_loader, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output): | |||||
rnn.train() | |||||
output.train() | |||||
loss_sum = 0 | |||||
for batch_idx, data in enumerate(data_loader): | |||||
rnn.zero_grad() | |||||
output.zero_grad() | |||||
x_unsorted = data['x'].float() | |||||
y_unsorted = data['y'].float() | |||||
y_len_unsorted = data['len'] | |||||
y_len_max = max(y_len_unsorted) | |||||
x_unsorted = x_unsorted[:, 0:y_len_max, :] | |||||
y_unsorted = y_unsorted[:, 0:y_len_max, :] | |||||
# initialize lstm hidden state according to batch size | |||||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||||
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||||
# sort input | |||||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||||
y_len = y_len.numpy().tolist() | |||||
x = torch.index_select(x_unsorted,0,sort_index) | |||||
y = torch.index_select(y_unsorted,0,sort_index) | |||||
        # build inputs and targets for the edge-level output RNN
        # pack_padded_sequence flattens the padded batch time-major: b1_t1, b2_t1, ..., b1_t2, b2_t2, ...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||||
        # reverse the packed rows so their valid lengths are in descending order, then add a feature dimension
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||||
idx = torch.LongTensor(idx) | |||||
y_reshape = y_reshape.index_select(0, idx) | |||||
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||||
output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||||
output_y = y_reshape | |||||
# batch size for output module: sum(y_len) | |||||
output_y_len = [] | |||||
output_y_len_bin = np.bincount(np.array(y_len)) | |||||
for i in range(len(output_y_len_bin)-1,0,-1): | |||||
            count_temp = np.sum(output_y_len_bin[i:]) # number of sequences with length >= i
            output_y_len.extend([min(i,y.size(2))]*count_temp) # cap each length at y.size(2), i.e. max_prev_node
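        # Worked example (comment only, with hypothetical lengths): for y_len = [3, 2] and
        # max_prev_node >= 3, pack_padded_sequence orders the rows time-major as
        #   [b1_t0, b2_t0, b1_t1, b2_t1, b1_t2]
        # with 1, 1, 2, 2, 3 valid adjacency entries respectively; reversing them gives
        # lengths [3, 2, 2, 1, 1], exactly the descending output_y_len built above, so the
        # rows can be packed again for the edge-level RNN without another sort.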
# pack into variable | |||||
x = Variable(x).cuda() | |||||
y = Variable(y).cuda() | |||||
output_x = Variable(output_x).cuda() | |||||
output_y = Variable(output_y).cuda() | |||||
# print(output_y_len) | |||||
# print('len',len(output_y_len)) | |||||
# print('y',y.size()) | |||||
# print('output_y',output_y.size()) | |||||
# if using ground truth to train | |||||
h = rnn(x, pack=True, input_len=y_len) | |||||
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||||
# reverse h | |||||
idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||||
idx = Variable(torch.LongTensor(idx)).cuda() | |||||
h = h.index_select(0, idx) | |||||
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||||
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||||
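        # each row of h (one per generated node across the batch) seeds the first layer of
        # the edge-level RNN's hidden state; the remaining num_layers-1 layers start at zero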
y_pred = output(output_x, pack=True, input_len=output_y_len) | |||||
y_pred = F.sigmoid(y_pred) | |||||
# clean | |||||
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||||
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||||
output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||||
# use cross entropy loss | |||||
loss = binary_cross_entropy_weight(y_pred, output_y) | |||||
loss.backward() | |||||
# update deterministic and lstm | |||||
optimizer_output.step() | |||||
optimizer_rnn.step() | |||||
scheduler_output.step() | |||||
scheduler_rnn.step() | |||||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||||
# logging | |||||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||||
feature_dim = y.size(1)*y.size(2) | |||||
loss_sum += loss.data[0]*feature_dim | |||||
return loss_sum/(batch_idx+1) | |||||
def test_rnn_epoch(epoch, args, rnn, output, test_batch_size=16): | |||||
rnn.hidden = rnn.init_hidden(test_batch_size) | |||||
rnn.eval() | |||||
output.eval() | |||||
# generate graphs | |||||
max_num_node = int(args.max_num_node) | |||||
y_pred_long = Variable(torch.zeros(test_batch_size, max_num_node, args.max_prev_node)).cuda() # discrete prediction | |||||
x_step = Variable(torch.ones(test_batch_size,1,args.max_prev_node)).cuda() | |||||
for i in range(max_num_node): | |||||
h = rnn(x_step) | |||||
# output.hidden = h.permute(1,0,2) | |||||
hidden_null = Variable(torch.zeros(args.num_layers - 1, h.size(0), h.size(2))).cuda() | |||||
output.hidden = torch.cat((h.permute(1,0,2), hidden_null), | |||||
dim=0) # num_layers, batch_size, hidden_size | |||||
x_step = Variable(torch.zeros(test_batch_size,1,args.max_prev_node)).cuda() | |||||
output_x_step = Variable(torch.ones(test_batch_size,1,1)).cuda() | |||||
for j in range(min(args.max_prev_node,i+1)): | |||||
output_y_pred_step = output(output_x_step) | |||||
output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1) | |||||
x_step[:,:,j:j+1] = output_x_step | |||||
output.hidden = Variable(output.hidden.data).cuda() | |||||
y_pred_long[:, i:i + 1, :] = x_step | |||||
rnn.hidden = Variable(rnn.hidden.data).cuda() | |||||
y_pred_long_data = y_pred_long.data.long() | |||||
# save graphs as pickle | |||||
G_pred_list = [] | |||||
for i in range(test_batch_size): | |||||
adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy()) | |||||
G_pred = get_graph(adj_pred) # get a graph from zero-padded adj | |||||
G_pred_list.append(G_pred) | |||||
return G_pred_list | |||||
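# Note: generation always runs for args.max_num_node steps; nodes whose adjacency
# row and column remain all-zero are stripped later by get_graph(), which is how
# sampled graphs end up with fewer than max_num_node nodes.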
def train_rnn_forward_epoch(epoch, args, rnn, output, data_loader): | |||||
rnn.train() | |||||
output.train() | |||||
loss_sum = 0 | |||||
for batch_idx, data in enumerate(data_loader): | |||||
rnn.zero_grad() | |||||
output.zero_grad() | |||||
x_unsorted = data['x'].float() | |||||
y_unsorted = data['y'].float() | |||||
y_len_unsorted = data['len'] | |||||
y_len_max = max(y_len_unsorted) | |||||
x_unsorted = x_unsorted[:, 0:y_len_max, :] | |||||
y_unsorted = y_unsorted[:, 0:y_len_max, :] | |||||
# initialize lstm hidden state according to batch size | |||||
rnn.hidden = rnn.init_hidden(batch_size=x_unsorted.size(0)) | |||||
# output.hidden = output.init_hidden(batch_size=x_unsorted.size(0)*x_unsorted.size(1)) | |||||
# sort input | |||||
y_len,sort_index = torch.sort(y_len_unsorted,0,descending=True) | |||||
y_len = y_len.numpy().tolist() | |||||
x = torch.index_select(x_unsorted,0,sort_index) | |||||
y = torch.index_select(y_unsorted,0,sort_index) | |||||
        # build inputs and targets for the edge-level output RNN
        # pack_padded_sequence flattens the padded batch time-major: b1_t1, b2_t1, ..., b1_t2, b2_t2, ...
y_reshape = pack_padded_sequence(y,y_len,batch_first=True).data | |||||
        # reverse the packed rows so their valid lengths are in descending order, then add a feature dimension
idx = [i for i in range(y_reshape.size(0)-1, -1, -1)] | |||||
idx = torch.LongTensor(idx) | |||||
y_reshape = y_reshape.index_select(0, idx) | |||||
y_reshape = y_reshape.view(y_reshape.size(0),y_reshape.size(1),1) | |||||
output_x = torch.cat((torch.ones(y_reshape.size(0),1,1),y_reshape[:,0:-1,0:1]),dim=1) | |||||
output_y = y_reshape | |||||
# batch size for output module: sum(y_len) | |||||
output_y_len = [] | |||||
output_y_len_bin = np.bincount(np.array(y_len)) | |||||
for i in range(len(output_y_len_bin)-1,0,-1): | |||||
            count_temp = np.sum(output_y_len_bin[i:]) # number of sequences with length >= i
            output_y_len.extend([min(i,y.size(2))]*count_temp) # cap each length at y.size(2), i.e. max_prev_node
# pack into variable | |||||
x = Variable(x).cuda() | |||||
y = Variable(y).cuda() | |||||
output_x = Variable(output_x).cuda() | |||||
output_y = Variable(output_y).cuda() | |||||
# print(output_y_len) | |||||
# print('len',len(output_y_len)) | |||||
# print('y',y.size()) | |||||
# print('output_y',output_y.size()) | |||||
# if using ground truth to train | |||||
h = rnn(x, pack=True, input_len=y_len) | |||||
h = pack_padded_sequence(h,y_len,batch_first=True).data # get packed hidden vector | |||||
# reverse h | |||||
idx = [i for i in range(h.size(0) - 1, -1, -1)] | |||||
idx = Variable(torch.LongTensor(idx)).cuda() | |||||
h = h.index_select(0, idx) | |||||
hidden_null = Variable(torch.zeros(args.num_layers-1, h.size(0), h.size(1))).cuda() | |||||
output.hidden = torch.cat((h.view(1,h.size(0),h.size(1)),hidden_null),dim=0) # num_layers, batch_size, hidden_size | |||||
y_pred = output(output_x, pack=True, input_len=output_y_len) | |||||
y_pred = F.sigmoid(y_pred) | |||||
# clean | |||||
y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True) | |||||
y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] | |||||
output_y = pack_padded_sequence(output_y,output_y_len,batch_first=True) | |||||
output_y = pad_packed_sequence(output_y,batch_first=True)[0] | |||||
# use cross entropy loss | |||||
loss = binary_cross_entropy_weight(y_pred, output_y) | |||||
if epoch % args.epochs_log==0 and batch_idx==0: # only output first batch's statistics | |||||
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format( | |||||
epoch, args.epochs,loss.data[0], args.graph_type, args.num_layers, args.hidden_size_rnn)) | |||||
# logging | |||||
log_value('loss_'+args.fname, loss.data[0], epoch*args.batch_ratio+batch_idx) | |||||
# print(y_pred.size()) | |||||
feature_dim = y_pred.size(0)*y_pred.size(1) | |||||
loss_sum += loss.data[0]*feature_dim/y.size(0) | |||||
return loss_sum/(batch_idx+1) | |||||
########### train function for LSTM + VAE | |||||
def train(args, dataset_train, rnn, output): | |||||
# check if load existing model | |||||
if args.load: | |||||
fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat' | |||||
rnn.load_state_dict(torch.load(fname)) | |||||
fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat' | |||||
output.load_state_dict(torch.load(fname)) | |||||
args.lr = 0.00001 | |||||
epoch = args.load_epoch | |||||
        print('model loaded! lr: {}'.format(args.lr))
else: | |||||
epoch = 1 | |||||
# initialize optimizer | |||||
optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr) | |||||
optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr) | |||||
scheduler_rnn = MultiStepLR(optimizer_rnn, milestones=args.milestones, gamma=args.lr_rate) | |||||
scheduler_output = MultiStepLR(optimizer_output, milestones=args.milestones, gamma=args.lr_rate) | |||||
# start main loop | |||||
time_all = np.zeros(args.epochs) | |||||
while epoch<=args.epochs: | |||||
time_start = tm.time() | |||||
# train | |||||
if 'GraphRNN_VAE' in args.note: | |||||
train_vae_epoch(epoch, args, rnn, output, dataset_train, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output) | |||||
elif 'GraphRNN_MLP' in args.note: | |||||
train_mlp_epoch(epoch, args, rnn, output, dataset_train, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output) | |||||
elif 'GraphRNN_RNN' in args.note: | |||||
train_rnn_epoch(epoch, args, rnn, output, dataset_train, | |||||
optimizer_rnn, optimizer_output, | |||||
scheduler_rnn, scheduler_output) | |||||
time_end = tm.time() | |||||
time_all[epoch - 1] = time_end - time_start | |||||
# test | |||||
if epoch % args.epochs_test == 0 and epoch>=args.epochs_test_start: | |||||
for sample_time in range(1,4): | |||||
G_pred = [] | |||||
while len(G_pred)<args.test_total_size: | |||||
if 'GraphRNN_VAE' in args.note: | |||||
G_pred_step = test_vae_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||||
elif 'GraphRNN_MLP' in args.note: | |||||
G_pred_step = test_mlp_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size,sample_time=sample_time) | |||||
elif 'GraphRNN_RNN' in args.note: | |||||
G_pred_step = test_rnn_epoch(epoch, args, rnn, output, test_batch_size=args.test_batch_size) | |||||
G_pred.extend(G_pred_step) | |||||
# save graphs | |||||
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + '.dat' | |||||
save_graph_list(G_pred, fname) | |||||
if 'GraphRNN_RNN' in args.note: | |||||
break | |||||
print('test done, graphs saved') | |||||
# save model checkpoint | |||||
if args.save: | |||||
if epoch % args.epochs_save == 0: | |||||
fname = args.model_save_path + args.fname + 'lstm_' + str(epoch) + '.dat' | |||||
torch.save(rnn.state_dict(), fname) | |||||
fname = args.model_save_path + args.fname + 'output_' + str(epoch) + '.dat' | |||||
torch.save(output.state_dict(), fname) | |||||
epoch += 1 | |||||
np.save(args.timing_save_path+args.fname,time_all) | |||||
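# Usage sketch (hedged; the actual wiring lives in main.py): once args, the data
# loader and the two modules (graph-level rnn, edge-level output) are constructed,
# training and the optional evaluation passes are single calls, e.g.
#   train(args, dataset_train, rnn, output)
#   train_nll(args, dataset_train, dataset_test, rnn, output, graph_validate_len, graph_test_len)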
########### for graph completion task | |||||
def train_graph_completion(args, dataset_test, rnn, output): | |||||
fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat' | |||||
rnn.load_state_dict(torch.load(fname)) | |||||
fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat' | |||||
output.load_state_dict(torch.load(fname)) | |||||
epoch = args.load_epoch | |||||
    print('model loaded! epoch: {}'.format(args.load_epoch))
for sample_time in range(1,4): | |||||
if 'GraphRNN_MLP' in args.note: | |||||
G_pred = test_mlp_partial_simple_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||||
if 'GraphRNN_VAE' in args.note: | |||||
G_pred = test_vae_partial_epoch(epoch, args, rnn, output, dataset_test,sample_time=sample_time) | |||||
# save graphs | |||||
fname = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time) + 'graph_completion.dat' | |||||
save_graph_list(G_pred, fname) | |||||
print('graph completion done, graphs saved') | |||||
########### for NLL evaluation | |||||
def train_nll(args, dataset_train, dataset_test, rnn, output,graph_validate_len,graph_test_len, max_iter = 1000): | |||||
fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat' | |||||
rnn.load_state_dict(torch.load(fname)) | |||||
fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat' | |||||
output.load_state_dict(torch.load(fname)) | |||||
epoch = args.load_epoch | |||||
    print('model loaded! epoch: {}'.format(args.load_epoch))
fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv' | |||||
with open(fname_output, 'w+') as f: | |||||
f.write(str(graph_validate_len)+','+str(graph_test_len)+'\n') | |||||
f.write('train,test\n') | |||||
for iter in range(max_iter): | |||||
if 'GraphRNN_MLP' in args.note: | |||||
nll_train = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_train) | |||||
nll_test = train_mlp_forward_epoch(epoch, args, rnn, output, dataset_test) | |||||
if 'GraphRNN_RNN' in args.note: | |||||
nll_train = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_train) | |||||
nll_test = train_rnn_forward_epoch(epoch, args, rnn, output, dataset_test) | |||||
print('train',nll_train,'test',nll_test) | |||||
f.write(str(nll_train)+','+str(nll_test)+'\n') | |||||
print('NLL evaluation done') |
import networkx as nx | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.init as init | |||||
from torch.autograd import Variable | |||||
import matplotlib.pyplot as plt | |||||
import torch.nn.functional as F | |||||
from torch import optim | |||||
from torch.optim.lr_scheduler import MultiStepLR | |||||
# import node2vec.src.main as nv | |||||
from sklearn.decomposition import PCA | |||||
import community | |||||
import pickle | |||||
import re | |||||
import data | |||||
def citeseer_ego(): | |||||
_, _, G = data.Graph_load(dataset='citeseer') | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
G = nx.convert_node_labels_to_integers(G) | |||||
graphs = [] | |||||
for i in range(G.number_of_nodes()): | |||||
G_ego = nx.ego_graph(G, i, radius=3) | |||||
if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400): | |||||
graphs.append(G_ego) | |||||
return graphs | |||||
def caveman_special(c=2,k=20,p_path=0.1,p_edge=0.3): | |||||
p = p_path | |||||
path_count = max(int(np.ceil(p * k)),1) | |||||
G = nx.caveman_graph(c, k) | |||||
    # remove intra-community edges with probability 1 - p_edge
    p = 1-p_edge
for (u, v) in list(G.edges()): | |||||
if np.random.rand() < p and ((u < k and v < k) or (u >= k and v >= k)): | |||||
G.remove_edge(u, v) | |||||
# add path_count links | |||||
for i in range(path_count): | |||||
u = np.random.randint(0, k) | |||||
v = np.random.randint(k, k * 2) | |||||
G.add_edge(u, v) | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
return G | |||||
def n_community(c_sizes, p_inter=0.01): | |||||
graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))] | |||||
G = nx.disjoint_union_all(graphs) | |||||
communities = list(nx.connected_component_subgraphs(G)) | |||||
for i in range(len(communities)): | |||||
subG1 = communities[i] | |||||
nodes1 = list(subG1.nodes()) | |||||
for j in range(i+1, len(communities)): | |||||
subG2 = communities[j] | |||||
nodes2 = list(subG2.nodes()) | |||||
has_inter_edge = False | |||||
for n1 in nodes1: | |||||
for n2 in nodes2: | |||||
if np.random.rand() < p_inter: | |||||
G.add_edge(n1, n2) | |||||
has_inter_edge = True | |||||
if not has_inter_edge: | |||||
G.add_edge(nodes1[0], nodes2[0]) | |||||
#print('connected comp: ', len(list(nx.connected_component_subgraphs(G)))) | |||||
return G | |||||
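# Example (hypothetical sizes): n_community([20, 20, 20, 20]) builds four dense
# G(n=20, p=0.7) communities and connects every pair of communities, adding at
# least one inter-community edge per pair and extra ones with probability p_inter.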
def perturb(graph_list, p_del, p_add=None): | |||||
    ''' Perturb the list of graphs by removing and adding edges.
    Args:
        p_del: probability of removing each existing edge
        p_add: probability of adding edges. If None, it is estimated from the graph
            density so that the expected number of added edges equals the number of
            deleted edges.
    Returns:
        A list of graphs that are perturbed from the original graphs
    '''
perturbed_graph_list = [] | |||||
for G_original in graph_list: | |||||
G = G_original.copy() | |||||
trials = np.random.binomial(1, p_del, size=G.number_of_edges()) | |||||
edges = list(G.edges()) | |||||
i = 0 | |||||
for (u, v) in edges: | |||||
if trials[i] == 1: | |||||
G.remove_edge(u, v) | |||||
i += 1 | |||||
if p_add is None: | |||||
num_nodes = G.number_of_nodes() | |||||
p_add_est = np.sum(trials) / (num_nodes * (num_nodes - 1) / 2 - | |||||
G.number_of_edges()) | |||||
else: | |||||
p_add_est = p_add | |||||
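        # p_add_est = (#deleted edges) / (#remaining non-adjacent pairs), so the expected
        # number of re-added edges matches the number just removed, keeping density roughly constant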
        nodes = list(G.nodes())
        for i in range(len(nodes)):
            u = nodes[i]
            trials = np.random.binomial(1, p_add_est, size=G.number_of_nodes())
            for j in range(i+1, len(nodes)):
                v = nodes[j]
                if trials[j] == 1:
                    G.add_edge(u, v)
perturbed_graph_list.append(G) | |||||
return perturbed_graph_list | |||||
def perturb_new(graph_list, p): | |||||
    ''' Perturb the list of graphs while preserving their edge counts.
    Args:
        p: probability of removing each existing edge; every removed edge is replaced
            by a new edge between a uniformly sampled non-adjacent pair of nodes.
    Returns:
        A list of graphs that are perturbed from the original graphs
    '''
perturbed_graph_list = [] | |||||
for G_original in graph_list: | |||||
G = G_original.copy() | |||||
edge_remove_count = 0 | |||||
for (u, v) in list(G.edges()): | |||||
if np.random.rand()<p: | |||||
G.remove_edge(u, v) | |||||
edge_remove_count += 1 | |||||
# randomly add the edges back | |||||
for i in range(edge_remove_count): | |||||
while True: | |||||
u = np.random.randint(0, G.number_of_nodes()) | |||||
v = np.random.randint(0, G.number_of_nodes()) | |||||
if (not G.has_edge(u,v)) and (u!=v): | |||||
break | |||||
G.add_edge(u, v) | |||||
perturbed_graph_list.append(G) | |||||
return perturbed_graph_list | |||||
def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None, origin=None): | |||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |||||
from matplotlib.figure import Figure | |||||
fig = Figure(figsize=arr.shape[::-1], dpi=1, frameon=False) | |||||
canvas = FigureCanvas(fig) | |||||
fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin) | |||||
fig.savefig(fname, dpi=1, format=format) | |||||
def save_prediction_histogram(y_pred_data, fname_pred, max_num_node, bin_n=20): | |||||
bin_edge = np.linspace(1e-6, 1, bin_n + 1) | |||||
output_pred = np.zeros((bin_n, max_num_node)) | |||||
for i in range(max_num_node): | |||||
output_pred[:, i], _ = np.histogram(y_pred_data[:, i, :], bins=bin_edge, density=False) | |||||
# normalize | |||||
output_pred[:, i] /= np.sum(output_pred[:, i]) | |||||
imsave(fname=fname_pred, arr=output_pred, origin='upper', cmap='Greys_r', vmin=0.0, vmax=3.0 / bin_n) | |||||
# draw a single graph G | |||||
def draw_graph(G, prefix = 'test'): | |||||
    parts = community.best_partition(G)
    values = [parts.get(node) for node in G.nodes()]
    # map each community id to a color; ids beyond the palette fall back to gray
    palette = ['red', 'green', 'blue', 'yellow', 'orange', 'pink', 'black']
    colors = [palette[v] if v < len(palette) else 'gray' for v in values]
# spring_pos = nx.spring_layout(G) | |||||
plt.switch_backend('agg') | |||||
plt.axis("off") | |||||
pos = nx.spring_layout(G) | |||||
nx.draw_networkx(G, with_labels=True, node_size=35, node_color=colors,pos=pos) | |||||
# plt.switch_backend('agg') | |||||
# options = { | |||||
# 'node_color': 'black', | |||||
# 'node_size': 10, | |||||
# 'width': 1 | |||||
# } | |||||
# plt.figure() | |||||
# plt.subplot() | |||||
# nx.draw_networkx(G, **options) | |||||
plt.savefig('figures/graph_view_'+prefix+'.png', dpi=200) | |||||
plt.close() | |||||
plt.switch_backend('agg') | |||||
G_deg = nx.degree_histogram(G) | |||||
G_deg = np.array(G_deg) | |||||
# plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2) | |||||
plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2) | |||||
plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||||
plt.close() | |||||
# degree_sequence = sorted(nx.degree(G).values(), reverse=True) # degree sequence | |||||
# plt.loglog(degree_sequence, 'b-', marker='o') | |||||
# plt.title("Degree rank plot") | |||||
# plt.ylabel("degree") | |||||
# plt.xlabel("rank") | |||||
# plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||||
# plt.close() | |||||
# G = nx.grid_2d_graph(8,8) | |||||
# G = nx.karate_club_graph() | |||||
# draw_graph(G) | |||||
# draw a list of graphs [G] | |||||
def draw_graph_list(G_list, row, col, fname = 'figures/test', layout='spring', is_single=False,k=1,node_size=55,alpha=1,width=1.3): | |||||
# # draw graph view | |||||
# from pylab import rcParams | |||||
# rcParams['figure.figsize'] = 12,3 | |||||
plt.switch_backend('agg') | |||||
for i,G in enumerate(G_list): | |||||
plt.subplot(row,col,i+1) | |||||
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, | |||||
wspace=0, hspace=0) | |||||
# if i%2==0: | |||||
# plt.title('real nodes: '+str(G.number_of_nodes()), fontsize = 4) | |||||
# else: | |||||
# plt.title('pred nodes: '+str(G.number_of_nodes()), fontsize = 4) | |||||
# plt.title('num of nodes: '+str(G.number_of_nodes()), fontsize = 4) | |||||
# parts = community.best_partition(G) | |||||
# values = [parts.get(node) for node in G.nodes()] | |||||
# colors = [] | |||||
# for i in range(len(values)): | |||||
# if values[i] == 0: | |||||
# colors.append('red') | |||||
# if values[i] == 1: | |||||
# colors.append('green') | |||||
# if values[i] == 2: | |||||
# colors.append('blue') | |||||
# if values[i] == 3: | |||||
# colors.append('yellow') | |||||
# if values[i] == 4: | |||||
# colors.append('orange') | |||||
# if values[i] == 5: | |||||
# colors.append('pink') | |||||
# if values[i] == 6: | |||||
# colors.append('black') | |||||
plt.axis("off") | |||||
if layout=='spring': | |||||
pos = nx.spring_layout(G,k=k/np.sqrt(G.number_of_nodes()),iterations=100) | |||||
# pos = nx.spring_layout(G) | |||||
elif layout=='spectral': | |||||
pos = nx.spectral_layout(G) | |||||
# # nx.draw_networkx(G, with_labels=True, node_size=2, width=0.15, font_size = 1.5, node_color=colors,pos=pos) | |||||
# nx.draw_networkx(G, with_labels=False, node_size=1.5, width=0.2, font_size = 1.5, linewidths=0.2, node_color = 'k',pos=pos,alpha=0.2) | |||||
if is_single: | |||||
# node_size default 60, edge_width default 1.5 | |||||
nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='#336699', alpha=1, linewidths=0, font_size=0) | |||||
nx.draw_networkx_edges(G, pos, alpha=alpha, width=width) | |||||
else: | |||||
nx.draw_networkx_nodes(G, pos, node_size=1.5, node_color='#336699',alpha=1, linewidths=0.2, font_size = 1.5) | |||||
nx.draw_networkx_edges(G, pos, alpha=0.3,width=0.2) | |||||
# plt.axis('off') | |||||
# plt.title('Complete Graph of Odd-degree Nodes') | |||||
# plt.show() | |||||
plt.tight_layout() | |||||
plt.savefig(fname+'.png', dpi=600) | |||||
plt.close() | |||||
# # draw degree distribution | |||||
# plt.switch_backend('agg') | |||||
# for i, G in enumerate(G_list): | |||||
# plt.subplot(row, col, i + 1) | |||||
# G_deg = np.array(list(G.degree(G.nodes()).values())) | |||||
# bins = np.arange(20) | |||||
# plt.hist(np.array(G_deg), bins=bins, align='left') | |||||
# plt.xlabel('degree', fontsize = 3) | |||||
# plt.ylabel('count', fontsize = 3) | |||||
# G_deg_mean = 2*G.number_of_edges()/float(G.number_of_nodes()) | |||||
# # if i % 2 == 0: | |||||
# # plt.title('real average degree: {:.2f}'.format(G_deg_mean), fontsize=4) | |||||
# # else: | |||||
# # plt.title('pred average degree: {:.2f}'.format(G_deg_mean), fontsize=4) | |||||
# plt.title('average degree: {:.2f}'.format(G_deg_mean), fontsize=4) | |||||
# plt.tick_params(axis='both', which='major', labelsize=3) | |||||
# plt.tick_params(axis='both', which='minor', labelsize=3) | |||||
# plt.tight_layout() | |||||
# plt.savefig(fname+'_degree.png', dpi=600) | |||||
# plt.close() | |||||
# | |||||
# # draw clustering distribution | |||||
# plt.switch_backend('agg') | |||||
# for i, G in enumerate(G_list): | |||||
# plt.subplot(row, col, i + 1) | |||||
# G_cluster = list(nx.clustering(G).values()) | |||||
# bins = np.linspace(0,1,20) | |||||
# plt.hist(np.array(G_cluster), bins=bins, align='left') | |||||
# plt.xlabel('clustering coefficient', fontsize=3) | |||||
# plt.ylabel('count', fontsize=3) | |||||
# G_cluster_mean = sum(G_cluster) / len(G_cluster) | |||||
# # if i % 2 == 0: | |||||
# # plt.title('real average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4) | |||||
# # else: | |||||
# # plt.title('pred average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4) | |||||
# plt.title('average clustering: {:.4f}'.format(G_cluster_mean), fontsize=4) | |||||
# plt.tick_params(axis='both', which='major', labelsize=3) | |||||
# plt.tick_params(axis='both', which='minor', labelsize=3) | |||||
# plt.tight_layout() | |||||
# plt.savefig(fname+'_clustering.png', dpi=600) | |||||
# plt.close() | |||||
# | |||||
# # draw circle distribution | |||||
# plt.switch_backend('agg') | |||||
# for i, G in enumerate(G_list): | |||||
# plt.subplot(row, col, i + 1) | |||||
# cycle_len = [] | |||||
# cycle_all = nx.cycle_basis(G) | |||||
# for item in cycle_all: | |||||
# cycle_len.append(len(item)) | |||||
# | |||||
# bins = np.arange(20) | |||||
# plt.hist(np.array(cycle_len), bins=bins, align='left') | |||||
# plt.xlabel('cycle length', fontsize=3) | |||||
# plt.ylabel('count', fontsize=3) | |||||
# G_cycle_mean = 0 | |||||
# if len(cycle_len)>0: | |||||
# G_cycle_mean = sum(cycle_len) / len(cycle_len) | |||||
# # if i % 2 == 0: | |||||
# # plt.title('real average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4) | |||||
# # else: | |||||
# # plt.title('pred average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4) | |||||
# plt.title('average cycle: {:.4f}'.format(G_cycle_mean), fontsize=4) | |||||
# plt.tick_params(axis='both', which='major', labelsize=3) | |||||
# plt.tick_params(axis='both', which='minor', labelsize=3) | |||||
# plt.tight_layout() | |||||
# plt.savefig(fname+'_cycle.png', dpi=600) | |||||
# plt.close() | |||||
# | |||||
# # draw community distribution | |||||
# plt.switch_backend('agg') | |||||
# for i, G in enumerate(G_list): | |||||
# plt.subplot(row, col, i + 1) | |||||
# parts = community.best_partition(G) | |||||
# values = np.array([parts.get(node) for node in G.nodes()]) | |||||
# counts = np.sort(np.bincount(values)[::-1]) | |||||
# pos = np.arange(len(counts)) | |||||
# plt.bar(pos,counts,align = 'edge') | |||||
# plt.xlabel('community ID', fontsize=3) | |||||
# plt.ylabel('count', fontsize=3) | |||||
# G_community_count = len(counts) | |||||
# # if i % 2 == 0: | |||||
# # plt.title('real average clustering: {}'.format(G_community_count), fontsize=4) | |||||
# # else: | |||||
# # plt.title('pred average clustering: {}'.format(G_community_count), fontsize=4) | |||||
    # plt.title('community count: {}'.format(G_community_count), fontsize=4)
# plt.tick_params(axis='both', which='major', labelsize=3) | |||||
# plt.tick_params(axis='both', which='minor', labelsize=3) | |||||
# plt.tight_layout() | |||||
# plt.savefig(fname+'_community.png', dpi=600) | |||||
# plt.close() | |||||
# plt.switch_backend('agg') | |||||
# G_deg = nx.degree_histogram(G) | |||||
# G_deg = np.array(G_deg) | |||||
# # plt.plot(range(len(G_deg)), G_deg, 'r', linewidth = 2) | |||||
# plt.loglog(np.arange(len(G_deg))[G_deg>0], G_deg[G_deg>0], 'r', linewidth=2) | |||||
# plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||||
# plt.close() | |||||
# degree_sequence = sorted(nx.degree(G).values(), reverse=True) # degree sequence | |||||
# plt.loglog(degree_sequence, 'b-', marker='o') | |||||
# plt.title("Degree rank plot") | |||||
# plt.ylabel("degree") | |||||
# plt.xlabel("rank") | |||||
# plt.savefig('figures/degree_view_' + prefix + '.png', dpi=200) | |||||
# plt.close() | |||||
# print basic statistics of a graph built directly from an adjacency matrix (deprecated)
def decode_graph(adj, prefix): | |||||
adj = np.asmatrix(adj) | |||||
G = nx.from_numpy_matrix(adj) | |||||
# G.remove_nodes_from(nx.isolates(G)) | |||||
print('num of nodes: {}'.format(G.number_of_nodes())) | |||||
print('num of edges: {}'.format(G.number_of_edges())) | |||||
G_deg = nx.degree_histogram(G) | |||||
G_deg_sum = [a * b for a, b in zip(G_deg, range(0, len(G_deg)))] | |||||
print('average degree: {}'.format(sum(G_deg_sum) / G.number_of_nodes())) | |||||
if nx.is_connected(G): | |||||
print('average path length: {}'.format(nx.average_shortest_path_length(G))) | |||||
print('average diameter: {}'.format(nx.diameter(G))) | |||||
G_cluster = sorted(list(nx.clustering(G).values())) | |||||
print('average clustering coefficient: {}'.format(sum(G_cluster) / len(G_cluster))) | |||||
cycle_len = [] | |||||
cycle_all = nx.cycle_basis(G, 0) | |||||
for item in cycle_all: | |||||
cycle_len.append(len(item)) | |||||
print('cycles', cycle_len) | |||||
print('cycle count', len(cycle_len)) | |||||
draw_graph(G, prefix=prefix) | |||||
def get_graph(adj): | |||||
    '''
    Build a networkx graph from a zero-padded adjacency matrix.
    :param adj: adjacency matrix, possibly padded with all-zero rows/columns
    :return: networkx Graph over the non-zero block
    '''
# remove all zeros rows and columns | |||||
adj = adj[~np.all(adj == 0, axis=1)] | |||||
adj = adj[:, ~np.all(adj == 0, axis=0)] | |||||
adj = np.asmatrix(adj) | |||||
G = nx.from_numpy_matrix(adj) | |||||
return G | |||||
# save a list of graphs | |||||
def save_graph_list(G_list, fname): | |||||
with open(fname, "wb") as f: | |||||
pickle.dump(G_list, f) | |||||
# pick the first connected component | |||||
def pick_connected_component(G): | |||||
node_list = nx.node_connected_component(G,0) | |||||
return G.subgraph(node_list) | |||||
def pick_connected_component_new(G): | |||||
adj_list = G.adjacency_list() | |||||
for id,adj in enumerate(adj_list): | |||||
id_min = min(adj) | |||||
if id<id_min and id>=1: | |||||
# if id<id_min and id>=4: | |||||
break | |||||
    node_list = list(range(id)) # only keep nodes that come before node id
G = G.subgraph(node_list) | |||||
G = max(nx.connected_component_subgraphs(G), key=len) | |||||
return G | |||||
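# Heuristic used for generated graphs (the is_real=False branch of load_graph_list):
# nodes are produced in BFS-like order, so the first node (id >= 1) whose neighbors
# all have larger ids marks the start of a piece detached from everything generated
# before it; only the nodes preceding it are kept, then the largest component is taken.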
# load a list of graphs | |||||
def load_graph_list(fname,is_real=True): | |||||
with open(fname, "rb") as f: | |||||
graph_list = pickle.load(f) | |||||
for i in range(len(graph_list)): | |||||
edges_with_selfloops = graph_list[i].selfloop_edges() | |||||
if len(edges_with_selfloops)>0: | |||||
graph_list[i].remove_edges_from(edges_with_selfloops) | |||||
if is_real: | |||||
graph_list[i] = max(nx.connected_component_subgraphs(graph_list[i]), key=len) | |||||
graph_list[i] = nx.convert_node_labels_to_integers(graph_list[i]) | |||||
else: | |||||
graph_list[i] = pick_connected_component_new(graph_list[i]) | |||||
return graph_list | |||||
def export_graphs_to_txt(g_list, output_filename_prefix): | |||||
    for i, G in enumerate(g_list):
        with open(output_filename_prefix + '_' + str(i) + '.txt', 'w+') as f:
            for (u, v) in G.edges():
                idx_u = G.nodes().index(u)
                idx_v = G.nodes().index(v)
                f.write(str(idx_u) + '\t' + str(idx_v) + '\n')
def snap_txt_output_to_nx(in_fname): | |||||
G = nx.Graph() | |||||
with open(in_fname, 'r') as f: | |||||
for line in f: | |||||
if not line[0] == '#': | |||||
splitted = re.split('[ \t]', line) | |||||
# self loop might be generated, but should be removed | |||||
u = int(splitted[0]) | |||||
v = int(splitted[1]) | |||||
if not u == v: | |||||
G.add_edge(int(u), int(v)) | |||||
return G | |||||
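# Presumably used to read edge lists written by the SNAP krongen/kronfit baselines
# mentioned in the README: lines starting with '#' are SNAP comments and self-loops are dropped.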
def test_perturbed(): | |||||
graphs = [] | |||||
for i in range(100,101): | |||||
for j in range(4,5): | |||||
for k in range(500): | |||||
graphs.append(nx.barabasi_albert_graph(i,j)) | |||||
g_perturbed = perturb(graphs, 0.9) | |||||
print([g.number_of_edges() for g in graphs]) | |||||
print([g.number_of_edges() for g in g_perturbed]) | |||||
if __name__ == '__main__': | |||||
#test_perturbed() | |||||
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_train_0.dat') | |||||
#graphs = load_graph_list('graphs/' + 'GraphRNN_RNN_community4_4_128_pred_2500_1.dat') | |||||
graphs = load_graph_list('eval_results/mmsb/' + 'community41.dat') | |||||
for i in range(0, 160, 16): | |||||
draw_graph_list(graphs[i:i+16], 4, 4, fname='figures/community4_' + str(i)) | |||||