first commit

master · JiaxuanYou · 6 years ago · commit 1ef475d957
66 changed files with 2798769 additions and 0 deletions
  1. LICENSE (+21, -0)
  2. README.md (+86, -0)
  3. __init__.py (+0, -0)
  4. analysis.py (+216, -0)
  5. args.py (+110, -0)
  6. baselines/baseline_simple.py (+275, -0)
  7. baselines/graphvae/data.py (+58, -0)
  8. baselines/graphvae/model.py (+208, -0)
  9. baselines/graphvae/train.py (+132, -0)
  10. baselines/mmsb.py (+154, -0)
  11. create_graphs.py (+155, -0)
  12. data.py (+1392, -0)
  13. dataset/DD/DD_A.txt (+1686092, -0)
  14. dataset/DD/DD_graph_indicator.txt (+334925, -0)
  15. dataset/DD/DD_graph_labels.txt (+1178, -0)
  16. dataset/DD/DD_node_labels.txt (+334925, -0)
  17. dataset/DD/README.txt (+75, -0)
  18. dataset/ENZYMES/ENZYMES_A.txt (+74564, -0)
  19. dataset/ENZYMES/ENZYMES_graph_indicator.txt (+19580, -0)
  20. dataset/ENZYMES/ENZYMES_graph_labels.txt (+600, -0)
  21. dataset/ENZYMES/ENZYMES_node_attributes.txt (+19580, -0)
  22. dataset/ENZYMES/ENZYMES_node_labels.txt (+19580, -0)
  23. dataset/ENZYMES/README.txt (+71, -0)
  24. dataset/ENZYMES/load_data.py (+60, -0)
  25. dataset/PROTEINS_full/PROTEINS_full_A.txt (+162088, -0)
  26. dataset/PROTEINS_full/PROTEINS_full_graph_indicator.txt (+43471, -0)
  27. dataset/PROTEINS_full/PROTEINS_full_graph_labels.txt (+1113, -0)
  28. dataset/PROTEINS_full/PROTEINS_full_node_attributes.txt (+43471, -0)
  29. dataset/PROTEINS_full/PROTEINS_full_node_labels.txt (+43471, -0)
  30. dataset/PROTEINS_full/README.txt (+61, -0)
  31. dataset/ind.citeseer.allx (BIN)
  32. dataset/ind.citeseer.graph (BIN)
  33. dataset/ind.citeseer.test.index (+1000, -0)
  34. dataset/ind.citeseer.tx (BIN)
  35. dataset/ind.citeseer.x (BIN)
  36. dataset/ind.cora.allx (BIN)
  37. dataset/ind.cora.graph (BIN)
  38. dataset/ind.cora.test.index (+1000, -0)
  39. dataset/ind.cora.tx (BIN)
  40. dataset/ind.cora.x (BIN)
  41. dataset/ind.pubmed.allx (BIN)
  42. dataset/ind.pubmed.graph (BIN)
  43. dataset/ind.pubmed.test.index (+1000, -0)
  44. dataset/ind.pubmed.tx (BIN)
  45. dataset/ind.pubmed.x (BIN)
  46. dataset/test_load_data.py (+0, -0)
  47. environment.yml (+267, -0)
  48. eval/MANIFEST.in (+2, -0)
  49. eval/__init__.py (+0, -0)
  50. eval/mmd.py (+135, -0)
  51. eval/orca/orca (BIN)
  52. eval/orca/orca.cpp (+1532, -0)
  53. eval/orca/orca.h (+1488, -0)
  54. eval/orca/test.txt (+6, -0)
  55. eval/orcamodule.cpp (+69, -0)
  56. eval/setup.py (+11, -0)
  57. eval/stats.py (+233, -0)
  58. evaluate.py (+692, -0)
  59. main.py (+141, -0)
  60. main_DeepGMG.py (+594, -0)
  61. model.py (+1500, -0)
  62. plot.py (+50, -0)
  63. requirements.txt (+4, -0)
  64. test_MMD.py (+55, -0)
  65. train.py (+760, -0)
  66. utils.py (+518, -0)

LICENSE (+21, -0)

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 Jiaxuan You, Rex Ying

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (+86, -0)

@@ -0,0 +1,86 @@
# GraphRNN: Generating Realistic Graphs with Deep Auto-regressive Model
This repository is the official PyTorch implementation of GraphRNN, a graph generative model using an auto-regressive model.

[Jiaxuan You](https://cs.stanford.edu/~jiaxuan/)\*, [Rex Ying](https://cs.stanford.edu/people/rexy/)\*, [Xiang Ren](http://www-bcf.usc.edu/~xiangren/), [William L. Hamilton](https://stanford.edu/~wleif/), [Jure Leskovec](https://cs.stanford.edu/people/jure/index.html), [GraphRNN: Generating Realistic Graphs with Deep Auto-regressive Model](https://arxiv.org/abs/1802.08773) (ICML 2018)

## Installation
Install PyTorch following the instructions on the [official website](https://pytorch.org/). The code has been tested with PyTorch 0.2.0 and 0.4.0.
```bash
conda install pytorch torchvision cuda90 -c pytorch
```
Then install the other dependencies.
```bash
pip install -r requirements.txt
```

## Test run
```bash
python main.py
```

## Code description
For the GraphRNN model:
`main.py` is the main executable file, and specific arguments are set in `args.py`.
`train.py` includes the training iterations and calls `model.py` and `data.py`.
`create_graphs.py` is where we prepare the target graph datasets.
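
For example, a minimal sketch of how the dataset preparation entry point can be used (assuming the repository root is on the Python path; the printout is only illustrative):
```python
from args import Args
import create_graphs

args = Args()                          # defaults to graph_type='grid'
graphs = create_graphs.create(args)    # list of networkx graphs for args.graph_type
print(len(graphs), 'graphs; first one has', graphs[0].number_of_nodes(), 'nodes')
```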

For baseline models:
* B-A and E-R models are implemented in `baselines/baseline_simple.py`.
* [Kronecker graph model](https://cs.stanford.edu/~jure/pubs/kronecker-jmlr10.pdf) is implemented in the SNAP software, which can be found at `https://github.com/snap-stanford/snap/tree/master/examples/krongen` (for generating Kronecker graphs) and `https://github.com/snap-stanford/snap/tree/master/examples/kronfit` (for learning the parameters of the model).
* MMSB is implemented using the Edward library (http://edwardlib.org/), and is located in `baselines`.
* We implemented the DeepGMG model based on the instructions of their [paper](https://arxiv.org/abs/1803.03324) in `main_DeepGMG.py`.
* We implemented the GraphVAE model based on the instructions of their [paper](https://arxiv.org/abs/1802.03480) in `baselines/graphvae`.

Parameter setting:
To adjust the hyper-parameters and input arguments of the model, modify the fields of `args.py`
accordingly.
For example, `args.cuda` controls which GPU is used to train the model, and `args.graph_type`
specifies which dataset is used to train the generative model. See the documentation in `args.py`
for more detailed descriptions of all fields.
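
For example, the relevant fields near the top of the `Args` constructor look like this (excerpt from `args.py`; the alternatives in the comments are the ones shipped in that file):
```python
### Which CUDA GPU device is used for training
self.cuda = 1
### Which GraphRNN model variant is used.
self.note = 'GraphRNN_RNN'        # or 'GraphRNN_MLP'
### Which dataset is used to train the model
self.graph_type = 'grid'          # e.g. 'enzymes', 'DD', 'citeseer', 'barabasi', ...
```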

## Outputs
There are several different types of outputs, each saved into a different directory under a path prefix. The path prefix is set at `args.dir_input`. Suppose that this field is set to `./`:
* `./graphs` contains the pickle files of training, test and generated graphs. Each contains a list of networkx objects.
* `./eval_results` contains the evaluation of MMD scores in txt format.
* `./model_save` stores the model checkpoints.
* `./nll` saves the log-likelihood for generated graphs as sequences.
* `./figures` is used to save visualizations (see Visualization of graphs section).

## Evaluation
The evaluation is done in `evaluate.py`, where the user can choose which settings to evaluate.
To evaluate how close the generated graphs are to the ground truth set, we use MMD (maximum mean discrepancy) to calculate the divergence between two _sets of distributions_ related to
the ground truth and generated graphs.
Two types of distributions are chosen here: the degree distribution and the clustering coefficient distribution.
Both are implemented in `eval/stats.py`, using the Python `multiprocessing`
module. One can easily extend the evaluation to compute MMD for other distributions of graph statistics.
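
As a rough, self-contained illustration of the MMD idea only (a sketch: the helper names, the 50-bin histogram, and the kernel bandwidth below are arbitrary choices, not the settings used in `eval/mmd.py` or `eval/stats.py`):
```python
import networkx as nx
import numpy as np
import pyemd
from scipy.linalg import toeplitz

def degree_hist(G, n_bins=50):
    # normalized degree histogram, padded/truncated to a fixed length
    h = np.array(nx.degree_histogram(G), dtype=float)
    h = np.pad(h, (0, max(0, n_bins - len(h))))[:n_bins]
    return h / max(h.sum(), 1e-12)

def gaussian_emd_kernel(x, y, sigma=1.0):
    # ground distance between histogram bins i and j is |i - j|
    d = toeplitz(range(len(x))).astype(float)
    emd = pyemd.emd(x, y, d)
    return np.exp(-emd * emd / (2 * sigma * sigma))

def mmd(set_real, set_pred, kernel=gaussian_emd_kernel):
    # biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    k_rr = np.mean([kernel(a, b) for a in set_real for b in set_real])
    k_pp = np.mean([kernel(a, b) for a in set_pred for b in set_pred])
    k_rp = np.mean([kernel(a, b) for a in set_real for b in set_pred])
    return k_rr + k_pp - 2 * k_rp

real = [degree_hist(nx.barabasi_albert_graph(50, 2)) for _ in range(8)]
pred = [degree_hist(nx.erdos_renyi_graph(50, 0.08)) for _ in range(8)]
print('degree MMD^2:', mmd(real, pred))
```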

We also compute the orbit counts for each graph, represented as a high-dimensional data point. We then compute the MMD
between the two _sets of sampled points_ using ORCA (see http://www.biolab.si/supp/orca/orca.html) at `eval/orca`.
One first needs to compile ORCA by running
```bash
g++ -O2 -std=c++11 -o orca orca.cpp
```
in the directory `eval/orca`
(the binary file already included in the repo works on Ubuntu).

To evaluate, run
```bash
python evaluate.py
```
Arguments specific to evaluation are specified in the class
`evaluate.Args_evaluate`. Note that the field `Args_evaluate.dataset_name_all` must only contain
datasets that have already been trained, which is done by setting `args.graph_type` to each of the datasets and running
`python main.py`.

## Visualization of graphs
The training, testing and generated graphs are saved at `./graphs/`.
One can visualize the generated graphs using the function `utils.load_graph_list`, which loads the
list of graphs from a pickle file, and `utils.draw_graph_list`, which plots the graphs using
networkx.
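
A minimal sketch (the path below is only an example; actual file names follow the `args.fname_pred` pattern plus the epoch and sampling index):
```python
from utils import load_graph_list, draw_graph_list

# load a list of networkx graphs saved during training/generation (example path)
graphs = load_graph_list('./graphs/GraphRNN_RNN_grid_4_128_pred_3000_1.dat')
# draw the first 16 graphs in a 4x4 grid and save the figure
draw_graph_list(graphs[0:16], row=4, col=4, fname='./figures/grid_pred_3000')
```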


## Misc
Jesse Bettencourt and Harris Chan have made a great [slide deck](https://duvenaud.github.io/learn-discrete/slides/graphrnn.pdf) introducing GraphRNN in Prof. David Duvenaud's seminar course [Learning Discrete Latent Structure](https://duvenaud.github.io/learn-discrete/).


__init__.py (+0, -0)


analysis.py (+216, -0)

@@ -0,0 +1,216 @@
# this file is used to plot images
from main import *

args = Args()
print(args.graph_type, args.note)
# epoch = 16000
epoch = 3000
sample_time = 3


def find_nearest_idx(array,value):
idx = (np.abs(array-value)).argmin()
return idx

# for baseline model
for num_layers in range(4,5):
# give file name and figure name
fname_real = args.graph_save_path + args.fname_real + str(0)
fname_pred = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time)
figname = args.figure_save_path + args.fname + str(epoch) +'_'+str(sample_time)

# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + '_real_' + str(True) + '_' + str(num_layers)
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + '_pred_' + str(True) + '_' + str(num_layers)
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + '_' + str(num_layers)
print(fname_real)
print(fname_pred)


# load data
graph_real_list = load_graph_list(fname_real + '.dat')
shuffle(graph_real_list)
graph_pred_list_raw = load_graph_list(fname_pred + '.dat')
graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))])
graph_pred_len_list_raw = np.array([len(graph_pred_list_raw[i]) for i in range(len(graph_pred_list_raw))])

graph_pred_list = graph_pred_list_raw
graph_pred_len_list = graph_pred_len_list_raw


# # select samples
# graph_pred_list = []
# graph_pred_len_list = []
# for value in graph_real_len_list:
# pred_idx = find_nearest_idx(graph_pred_len_list_raw, value)
# graph_pred_list.append(graph_pred_list_raw[pred_idx])
# graph_pred_len_list.append(graph_pred_len_list_raw[pred_idx])
# # delete
# graph_pred_len_list_raw=np.delete(graph_pred_len_list_raw, pred_idx)
# del graph_pred_list_raw[pred_idx]
# if len(graph_pred_list)==200:
# break
# graph_pred_len_list = np.array(graph_pred_len_list)



# # select pred data within certain range
# len_min = np.amin(graph_real_len_list)
# len_max = np.amax(graph_real_len_list)
# pred_index = np.where((graph_pred_len_list>=len_min)&(graph_pred_len_list<=len_max))
# # print(pred_index[0])
# graph_pred_list = [graph_pred_list[i] for i in pred_index[0]]
# graph_pred_len_list = graph_pred_len_list[pred_index[0]]



# real_order = np.argsort(graph_real_len_list)
# pred_order = np.argsort(graph_pred_len_list)
real_order = np.argsort(graph_real_len_list)[::-1]
pred_order = np.argsort(graph_pred_len_list)[::-1]
# print(real_order)
# print(pred_order)
graph_real_list = [graph_real_list[i] for i in real_order]
graph_pred_list = [graph_pred_list[i] for i in pred_order]

# shuffle(graph_real_list)
# shuffle(graph_pred_list)
print('real average nodes', sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))])/len(graph_real_list))
print('pred average nodes', sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))])/len(graph_pred_list))
print('num of real graphs', len(graph_real_list))
print('num of pred graphs', len(graph_pred_list))


# # draw all graphs
# for iter in range(8):
# print('iter', iter)
# graph_list = []
# for i in range(8):
# index = 8 * iter + i
# # graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
# # graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
# graph_list.append(graph_real_list[index])
# graph_list.append(graph_pred_list[index])
# print('real', graph_real_list[index].number_of_nodes())
# print('pred', graph_pred_list[index].number_of_nodes())
#
# draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter))

# draw all graphs
for iter in range(8):
print('iter', iter)
graph_list = []
for i in range(8):
index = 32 * iter + i
# graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
# graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
# graph_list.append(graph_real_list[index])
graph_list.append(graph_pred_list[index])
# print('real', graph_real_list[index].number_of_nodes())
print('pred', graph_pred_list[index].number_of_nodes())

draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_pred')

# draw all graphs
for iter in range(8):
print('iter', iter)
graph_list = []
for i in range(8):
index = 16 * iter + i
# graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
# graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
graph_list.append(graph_real_list[index])
# graph_list.append(graph_pred_list[index])
print('real', graph_real_list[index].number_of_nodes())
# print('pred', graph_pred_list[index].number_of_nodes())

draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_real')

#
# # for new model
# elif args.note == 'GraphRNN_structure' and args.is_flexible==False:
# for num_layers in range(4,5):
# # give file name and figure name
# # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# # str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
# # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# # str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
#
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + \
# str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
# print(fname_real)
# # load data
# graph_real_list = load_graph_list(fname_real+'.dat')
# graph_pred_list = load_graph_list(fname_pred+'.dat')
#
# graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))])
# graph_pred_len_list = np.array([len(graph_pred_list[i]) for i in range(len(graph_pred_list))])
# real_order = np.argsort(graph_real_len_list)[::-1]
# pred_order = np.argsort(graph_pred_len_list)[::-1]
# # print(real_order)
# # print(pred_order)
# graph_real_list = [graph_real_list[i] for i in real_order]
# graph_pred_list = [graph_pred_list[i] for i in pred_order]
#
# shuffle(graph_pred_list)
#
#
# print('real average nodes',
# sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))]) / len(graph_real_list))
# print('pred average nodes',
# sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))]) / len(graph_pred_list))
# print('num of graphs', len(graph_real_list))
#
# # draw all graphs
# for iter in range(2):
# print('iter', iter)
# graph_list = []
# for i in range(8):
# index = 8*iter + i
# graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index]))
# graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index]))
# graph_list.append(graph_real_list[index])
# graph_list.append(graph_pred_list[index])
# print('real', graph_real_list[index].number_of_nodes())
# print('pred', graph_pred_list[index].number_of_nodes())
# draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter))
#
#
# # for new model
# elif args.note == 'GraphRNN_structure' and args.is_flexible==True:
# for num_layers in range(4,5):
# graph_real_list = []
# graph_pred_list = []
# epoch_end = 30000
# for epoch in [epoch_end-500*(8-i) for i in range(8)]:
# # give file name and figure name
# fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
# fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
#
# # load data
# graph_real_list += load_graph_list(fname_real+'.dat')
# graph_pred_list += load_graph_list(fname_pred+'.dat')
# print('num of graphs', len(graph_real_list))
#
# figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
# str(epoch) + str(args.sample_when_validate) + '_' + str(num_layers) + '_dilation_' + str(args.is_dilation) + '_flexible_' + str(args.is_flexible) + '_bn_' + str(args.is_bn) + '_lr_' + str(args.lr)
#
# # draw all graphs
# for iter in range(1):
# print('iter', iter)
# graph_list = []
# for i in range(8):
# index = 8*iter + i
# graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index]))
# graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index]))
# graph_list.append(graph_real_list[index])
# graph_list.append(graph_pred_list[index])
# draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter))

args.py (+110, -0)

@@ -0,0 +1,110 @@

### program configuration
class Args():
def __init__(self):
### if clean tensorboard
self.clean_tensorboard = False
### Which CUDA GPU device is used for training
self.cuda = 1

### Which GraphRNN model variant is used.
# The simple version of Graph RNN
# self.note = 'GraphRNN_MLP'
# The dependent Bernoulli sequence version of GraphRNN
self.note = 'GraphRNN_RNN'

## for comparison, removing the BFS component
# self.note = 'GraphRNN_MLP_nobfs'
# self.note = 'GraphRNN_RNN_nobfs'

### Which dataset is used to train the model
# self.graph_type = 'DD'
# self.graph_type = 'caveman'
# self.graph_type = 'caveman_small'
# self.graph_type = 'caveman_small_single'
# self.graph_type = 'community4'
self.graph_type = 'grid'
# self.graph_type = 'grid_small'
# self.graph_type = 'ladder_small'

# self.graph_type = 'enzymes'
# self.graph_type = 'enzymes_small'
# self.graph_type = 'barabasi'
# self.graph_type = 'barabasi_small'
# self.graph_type = 'citeseer'
# self.graph_type = 'citeseer_small'

# self.graph_type = 'barabasi_noise'
# self.noise = 10
#
# if self.graph_type == 'barabasi_noise':
# self.graph_type = self.graph_type+str(self.noise)

# if none, then auto calculate
self.max_num_node = None # max number of nodes in a graph
self.max_prev_node = None # max number of previous nodes the model looks back at

### network config
## GraphRNN
if 'small' in self.graph_type:
self.parameter_shrink = 2
else:
self.parameter_shrink = 1
self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
self.hidden_size_rnn_output = 16 # hidden size for output RNN
self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
self.embedding_size_rnn_output = 8 # the embedding size for output rnn
self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)

self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
self.test_batch_size = 32
self.test_total_size = 1000
self.num_layers = 4

### training config
self.num_workers = 4 # num workers to load data, default 4
self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size
self.epochs_test_start = 100
self.epochs_test = 100
self.epochs_log = 100
self.epochs_save = 100

self.lr = 0.003
self.milestones = [400, 1000]
self.lr_rate = 0.3

self.sample_time = 2 # sample time in each time step, when validating

### output config
# self.dir_input = "/dfs/scratch0/jiaxuany0/"
self.dir_input = "./"
self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
self.graph_save_path = self.dir_input+'graphs/'
self.figure_save_path = self.dir_input+'figures/'
self.timing_save_path = self.dir_input+'timing/'
self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
self.nll_save_path = self.dir_input+'nll/'


self.load = False # if load model, default lr is very low
self.load_epoch = 3000
self.save = True


### baseline config
# self.generator_baseline = 'Gnp'
self.generator_baseline = 'BA'

# self.metric_baseline = 'general'
# self.metric_baseline = 'degree'
self.metric_baseline = 'clustering'


### filenames to save intermediate and final outputs
self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline


baselines/baseline_simple.py (+275, -0)

@@ -0,0 +1,275 @@
from main import *
from scipy.linalg import toeplitz
import pyemd
import scipy.optimize as opt

def Graph_generator_baseline_train_rulebased(graphs,generator='BA'):
graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))]
parameter = {}
for i in range(len(graph_nodes)):
nodes = graph_nodes[i]
edges = graph_edges[i]
# based on rule, calculate optimal parameter
if generator=='BA':
# BA optimal: nodes = n; edges = (n-m)*m
n = nodes
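# solve edges = (n - m)*m, i.e. m**2 - n*m + edges = 0, taking the smaller root for m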
m = (n - np.sqrt(n**2-4*edges))/2
parameter_temp = [n,m,1]
if generator=='Gnp':
# Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p
n = nodes
p = float(edges)/((n-1)*n/2)
parameter_temp = [n,p,1]
# update parameter list
if nodes not in parameter.keys():
parameter[nodes] = parameter_temp
else:
count = parameter[nodes][-1]
parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))]
parameter[nodes][-1] = count+1
# print(parameter)
return parameter

def Graph_generator_baseline(graph_train, pred_num=1000, generator='BA'):
graph_nodes = [graph_train[i].number_of_nodes() for i in range(len(graph_train))]
graph_edges = [graph_train[i].number_of_edges() for i in range(len(graph_train))]
repeat = pred_num//len(graph_train)
graph_pred = []
for i in range(len(graph_nodes)):
nodes = graph_nodes[i]
edges = graph_edges[i]
# based on rule, calculate optimal parameter
if generator=='BA':
# BA optimal: nodes = n; edges = (n-m)*m
n = nodes
m = int((n - np.sqrt(n**2-4*edges))/2)
for j in range(repeat):
graph_pred.append(nx.barabasi_albert_graph(n,m))
if generator=='Gnp':
# Gnp optimal: nodes = n; edges = ((n-1)*n/2)*p
n = nodes
p = float(edges)/((n-1)*n/2)
for j in range(repeat):
graph_pred.append(nx.fast_gnp_random_graph(n, p))
return graph_pred

def emd_distance(x, y, distance_scaling=1.0):
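# earth mover's distance between two 1-D histograms; the ground distance between bins i and j is |i - j| / distance_scaling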
support_size = max(len(x), len(y))
d_mat = toeplitz(range(support_size)).astype(np.float)
distance_mat = d_mat / distance_scaling

# convert histogram values x and y to float, and make them equal len
x = x.astype(np.float)
y = y.astype(np.float)
if len(x) < len(y):
x = np.hstack((x, [0.0] * (support_size - len(x))))
elif len(y) < len(x):
y = np.hstack((y, [0.0] * (support_size - len(y))))

emd = pyemd.emd(x, y, distance_mat)
return emd

# def Loss(x,args):
# '''
#
# :param x: 1-D array, parameters to be optimized
# :param args: tuple (n, G, generator, metric).
# n: n for pred graph;
# G: real graph in networkx format;
# generator: 'BA', 'Gnp', 'Powerlaw';
# metric: 'degree', 'clustering'
# :return: Loss: emd distance
# '''
# # get argument
# generator = args[2]
# metric = args[3]
#
# # get real and pred graphs
# G_real = args[1]
# if generator=='BA':
# G_pred = nx.barabasi_albert_graph(args[0],int(np.rint(x)))
# if generator=='Gnp':
# G_pred = nx.fast_gnp_random_graph(args[0],x)
#
# # define metric
# if metric == 'degree':
# G_real_hist = np.array(nx.degree_histogram(G_real))
# G_real_hist = G_real_hist / np.sum(G_real_hist)
# G_pred_hist = np.array(nx.degree_histogram(G_pred))
# G_pred_hist = G_pred_hist/np.sum(G_pred_hist)
# if metric == 'clustering':
# G_real_hist, _ = np.histogram(
# np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False)
# G_real_hist = G_real_hist / np.sum(G_real_hist)
# G_pred_hist, _ = np.histogram(
# np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False)
# G_pred_hist = G_pred_hist / np.sum(G_pred_hist)
#
# loss = emd_distance(G_real_hist,G_pred_hist)
# return loss

def Loss(x,n,G_real,generator,metric):
'''

:param x: 1-D array, parameters to be optimized
:param
n: n for pred graph;
G: real graph in networkx format;
generator: 'BA', 'Gnp', 'Powerlaw';
metric: 'degree', 'clustering'
:return: Loss: emd distance
'''
# get argument

# get real and pred graphs
if generator=='BA':
G_pred = nx.barabasi_albert_graph(n,int(np.rint(x)))
if generator=='Gnp':
G_pred = nx.fast_gnp_random_graph(n,x)

# define metric
if metric == 'degree':
G_real_hist = np.array(nx.degree_histogram(G_real))
G_real_hist = G_real_hist / np.sum(G_real_hist)
G_pred_hist = np.array(nx.degree_histogram(G_pred))
G_pred_hist = G_pred_hist/np.sum(G_pred_hist)
if metric == 'clustering':
G_real_hist, _ = np.histogram(
np.array(list(nx.clustering(G_real).values())), bins=50, range=(0.0, 1.0), density=False)
G_real_hist = G_real_hist / np.sum(G_real_hist)
G_pred_hist, _ = np.histogram(
np.array(list(nx.clustering(G_pred).values())), bins=50, range=(0.0, 1.0), density=False)
G_pred_hist = G_pred_hist / np.sum(G_pred_hist)

loss = emd_distance(G_real_hist,G_pred_hist)
return loss

def optimizer_brute(x_min, x_max, x_step, n, G_real, generator, metric):
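# brute-force 1-D grid search: evaluate Loss on np.arange(x_min, x_max, x_step) and return the minimizer (m for 'BA', p for 'Gnp')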
loss_all = []
x_list = np.arange(x_min,x_max,x_step)
for x_test in x_list:
loss_all.append(Loss(x_test,n,G_real,generator,metric))
x_optim = x_list[np.argmin(np.array(loss_all))]
return x_optim

def Graph_generator_baseline_train_optimizationbased(graphs,generator='BA',metric='degree'):
graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
parameter = {}
for i in range(len(graph_nodes)):
print('graph ',i)
nodes = graph_nodes[i]
if generator=='BA':
n = nodes
m = optimizer_brute(1,10,1, nodes, graphs[i], generator, metric)
parameter_temp = [n,m,1]
elif generator=='Gnp':
n = nodes
p = optimizer_brute(1e-6,1,0.01, nodes, graphs[i], generator, metric)
## if use evolution
# result = opt.differential_evolution(Loss,bounds=[(0,1)],args=(nodes, graphs[i], generator, metric),maxiter=1000)
# p = result.x
parameter_temp = [n, p, 1]

# update parameter list
if nodes not in parameter.keys():
parameter[nodes] = parameter_temp
else:
count = parameter[nodes][2]
parameter[nodes] = [(parameter[nodes][i]*count+parameter_temp[i])/(count+1) for i in range(len(parameter[nodes]))]
parameter[nodes][2] = count+1
print(parameter)
return parameter



def Graph_generator_baseline_test(graph_nodes, parameter, generator='BA'):
graphs = []
for i in range(len(graph_nodes)):
nodes = graph_nodes[i]
if not nodes in parameter.keys():
nodes = min(parameter.keys(), key=lambda k: abs(k - nodes))
if generator=='BA':
n = int(parameter[nodes][0])
m = int(np.rint(parameter[nodes][1]))
print(n,m)
graph = nx.barabasi_albert_graph(n,m)
if generator=='Gnp':
n = int(parameter[nodes][0])
p = parameter[nodes][1]
print(n,p)
graph = nx.fast_gnp_random_graph(n,p)
graphs.append(graph)
return graphs


if __name__ == '__main__':
args = Args()

print('File name prefix', args.fname)
### load datasets
graphs = []
# synthetic graphs
if args.graph_type=='ladder':
graphs = []
for i in range(100, 201):
graphs.append(nx.ladder_graph(i))
args.max_prev_node = 10
if args.graph_type=='tree':
graphs = []
for i in range(2,5):
for j in range(3,5):
graphs.append(nx.balanced_tree(i,j))
args.max_prev_node = 256
if args.graph_type=='caveman':
graphs = []
for i in range(5,10):
for j in range(5,25):
graphs.append(nx.connected_caveman_graph(i, j))
args.max_prev_node = 50
if args.graph_type=='grid':
graphs = []
for i in range(10,20):
for j in range(10,20):
graphs.append(nx.grid_2d_graph(i,j))
args.max_prev_node = 40
if args.graph_type=='barabasi':
graphs = []
for i in range(100,200):
graphs.append(nx.barabasi_albert_graph(i,2))
args.max_prev_node = 130
# real graphs
if args.graph_type == 'enzymes':
graphs= Graph_load_batch(min_num_nodes=10, name='ENZYMES')
args.max_prev_node = 25
if args.graph_type == 'protein':
graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full')
args.max_prev_node = 80
if args.graph_type == 'DD':
graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',node_attributes=False,graph_labels=True)
args.max_prev_node = 230


graph_nodes = [graphs[i].number_of_nodes() for i in range(len(graphs))]
graph_edges = [graphs[i].number_of_edges() for i in range(len(graphs))]

args.max_num_node = max(graph_nodes)

# show graphs statistics
print('total graph num: {}'.format(len(graphs)))
print('max number node: {}'.format(args.max_num_node))
print('max previous node: {}'.format(args.max_prev_node))

# start baseline generation method

generator = args.generator_baseline
metric = args.metric_baseline
print(args.fname_baseline + '.dat')

if metric=='general':
parameter = Graph_generator_baseline_train_rulebased(graphs,generator=generator)
else:
parameter = Graph_generator_baseline_train_optimizationbased(graphs,generator=generator,metric=metric)
graphs_generated = Graph_generator_baseline_test(graph_nodes, parameter,generator)

save_graph_list(graphs_generated,args.fname_baseline + '.dat')

baselines/graphvae/data.py (+58, -0)

@@ -0,0 +1,58 @@
import networkx as nx
import numpy as np
import torch

class GraphAdjSampler(torch.utils.data.Dataset):
def __init__(self, G_list, max_num_nodes, features='id'):
self.max_num_nodes = max_num_nodes
self.adj_all = []
self.len_all = []
self.feature_all = []

for G in G_list:
adj = nx.to_numpy_matrix(G)
# the diagonal entries are 1 since they denote node probability
self.adj_all.append(
np.asarray(adj) + np.identity(G.number_of_nodes()))
self.len_all.append(G.number_of_nodes())
if features == 'id':
self.feature_all.append(np.identity(max_num_nodes))
elif features == 'deg':
degs = np.sum(np.array(adj), 1)
degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()], 0),
axis=1)
self.feature_all.append(degs)
elif features == 'struct':
degs = np.sum(np.array(adj), 1)
degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
'constant'),
axis=1)
clusterings = np.array(list(nx.clustering(G).values()))
clusterings = np.expand_dims(np.pad(clusterings,
[0, max_num_nodes - G.number_of_nodes()],
'constant'),
axis=1)
self.feature_all.append(np.hstack([degs, clusterings]))

def __len__(self):
return len(self.adj_all)

def __getitem__(self, idx):
adj = self.adj_all[idx]
num_nodes = adj.shape[0]
adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
adj_padded[:num_nodes, :num_nodes] = adj

adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2)
node_idx = 0
adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes,self.max_num_nodes)) ) == 1]
# the following 2 lines recover the upper triangle of the adj matrix
#recovered = np.zeros((self.max_num_nodes, self.max_num_nodes))
#recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes)) ) == 1] = adj_vectorized
#print(recovered)
return {'adj':adj_padded,
'adj_decoded':adj_vectorized,
'features':self.feature_all[idx].copy()}


baselines/graphvae/model.py (+208, -0)

@@ -0,0 +1,208 @@

import numpy as np
import scipy.optimize

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init

import model


class GraphVAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim, max_num_nodes, pool='sum'):
'''
Args:
input_dim: input feature dimension for node.
hidden_dim: hidden dim for 2-layer gcn.
latent_dim: dimension of the latent representation of graph.
'''
super(GraphVAE, self).__init__()
self.conv1 = model.GraphConv(input_dim=input_dim, output_dim=hidden_dim)
self.bn1 = nn.BatchNorm1d(input_dim)
self.conv2 = model.GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
self.bn2 = nn.BatchNorm1d(input_dim)
self.act = nn.ReLU()

output_dim = max_num_nodes * (max_num_nodes + 1) // 2
#self.vae = model.MLP_VAE_plain(hidden_dim, latent_dim, output_dim)
self.vae = model.MLP_VAE_plain(input_dim * input_dim, latent_dim, output_dim)
#self.feature_mlp = model.MLP_plain(latent_dim, latent_dim, output_dim)

self.max_num_nodes = max_num_nodes
for m in self.modules():
if isinstance(m, model.GraphConv):
m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()

self.pool = pool

def recover_adj_lower(self, l):
# NOTE: Assumes 1 per minibatch
adj = torch.zeros(self.max_num_nodes, self.max_num_nodes)
adj[torch.triu(torch.ones(self.max_num_nodes, self.max_num_nodes)) == 1] = l
return adj

def recover_full_adj_from_lower(self, lower):
diag = torch.diag(torch.diag(lower, 0))
return lower + torch.transpose(lower, 0, 1) - diag

def edge_similarity_matrix(self, adj, adj_recon, matching_features,
matching_features_recon, sim_func):
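# S[i, j, a, b] holds the affinity between node pair (i, j) of the input adjacency
# and node pair (a, b) of the reconstruction; diagonal entries (i == j, a == b)
# additionally compare node features via sim_func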
S = torch.zeros(self.max_num_nodes, self.max_num_nodes,
self.max_num_nodes, self.max_num_nodes)
for i in range(self.max_num_nodes):
for j in range(self.max_num_nodes):
if i == j:
for a in range(self.max_num_nodes):
S[i, i, a, a] = adj[i, i] * adj_recon[a, a] * \
sim_func(matching_features[i], matching_features_recon[a])
# with feature not implemented
# if input_features is not None:
else:
for a in range(self.max_num_nodes):
for b in range(self.max_num_nodes):
if b == a:
continue
S[i, j, a, b] = adj[i, j] * adj[i, i] * adj[j, j] * \
adj_recon[a, b] * adj_recon[a, a] * adj_recon[b, b]
return S

def mpm(self, x_init, S, max_iters=50):
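# max-pooling matching: iteratively refine the soft node-assignment matrix x (N x N);
# each entry combines its own affinity S[i, i, a, a] with the best-matching
# contribution from every other node j, then x is renormalized to unit norm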
x = x_init
for it in range(max_iters):
x_new = torch.zeros(self.max_num_nodes, self.max_num_nodes)
for i in range(self.max_num_nodes):
for a in range(self.max_num_nodes):
x_new[i, a] = x[i, a] * S[i, i, a, a]
pooled = [torch.max(x[j, :] * S[i, j, a, :])
for j in range(self.max_num_nodes) if j != i]
neigh_sim = sum(pooled)
x_new[i, a] += neigh_sim
norm = torch.norm(x_new)
x = x_new / norm
return x

def deg_feature_similarity(self, f1, f2):
return 1 / (abs(f1 - f2) + 1)

def permute_adj(self, adj, curr_ind, target_ind):
''' Permute adjacency matrix.
The target_ind (connectivity) should be permuted to the curr_ind position.
'''
# order curr_ind according to target ind
ind = np.zeros(self.max_num_nodes, dtype=np.int)
ind[target_ind] = curr_ind
adj_permuted = torch.zeros((self.max_num_nodes, self.max_num_nodes))
adj_permuted[:, :] = adj[ind, :]
adj_permuted[:, :] = adj_permuted[:, ind]
return adj_permuted

def pool_graph(self, x):
if self.pool == 'max':
out, _ = torch.max(x, dim=1, keepdim=False)
elif self.pool == 'sum':
out = torch.sum(x, dim=1, keepdim=False)
return out

def forward(self, input_features, adj):
#x = self.conv1(input_features, adj)
#x = self.bn1(x)
#x = self.act(x)
#x = self.conv2(x, adj)
#x = self.bn2(x)

# pool over all nodes
#graph_h = self.pool_graph(x)
graph_h = input_features.view(-1, self.max_num_nodes * self.max_num_nodes)
# vae
h_decode, z_mu, z_lsgms = self.vae(graph_h)
out = F.sigmoid(h_decode)
out_tensor = out.cpu().data
recon_adj_lower = self.recover_adj_lower(out_tensor)
recon_adj_tensor = self.recover_full_adj_from_lower(recon_adj_lower)

# set matching features be degree
out_features = torch.sum(recon_adj_tensor, 1)

adj_data = adj.cpu().data[0]
adj_features = torch.sum(adj_data, 1)

S = self.edge_similarity_matrix(adj_data, recon_adj_tensor, adj_features, out_features,
self.deg_feature_similarity)

# initialization strategies
init_corr = 1 / self.max_num_nodes
init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
#init_assignment = torch.FloatTensor(4, 4)
#init.uniform(init_assignment)
assignment = self.mpm(init_assignment, S)
#print('Assignment: ', assignment)

# matching
# use negative of the assignment score since the alg finds min cost flow
row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
print('row: ', row_ind)
print('col: ', col_ind)
# order row index according to col index
#adj_permuted = self.permute_adj(adj_data, row_ind, col_ind)
adj_permuted = adj_data
adj_vectorized = adj_permuted[torch.triu(torch.ones(self.max_num_nodes,self.max_num_nodes) )== 1].squeeze_()
adj_vectorized_var = Variable(adj_vectorized).cuda()

#print(adj)
#print('permuted: ', adj_permuted)
#print('recon: ', recon_adj_tensor)
adj_recon_loss = self.adj_recon_loss(adj_vectorized_var, out[0])
print('recon: ', adj_recon_loss)
print(adj_vectorized_var)
print(out[0])

loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
loss_kl /= self.max_num_nodes * self.max_num_nodes # normalize
print('kl: ', loss_kl)

loss = adj_recon_loss + loss_kl

return loss

def forward_test(self, input_features, adj):
self.max_num_nodes = 4
adj_data = torch.zeros(self.max_num_nodes, self.max_num_nodes)
adj_data[:4, :4] = torch.FloatTensor([[1,1,0,0], [1,1,1,0], [0,1,1,1], [0,0,1,1]])
adj_features = torch.Tensor([2,3,3,2])

adj_data1 = torch.zeros(self.max_num_nodes, self.max_num_nodes)
adj_data1 = torch.FloatTensor([[1,1,1,0], [1,1,0,1], [1,0,1,0], [0,1,0,1]])
adj_features1 = torch.Tensor([3,3,2,2])
S = self.edge_similarity_matrix(adj_data, adj_data1, adj_features, adj_features1,
self.deg_feature_similarity)

# initialization strategies
init_corr = 1 / self.max_num_nodes
init_assignment = torch.ones(self.max_num_nodes, self.max_num_nodes) * init_corr
#init_assignment = torch.FloatTensor(4, 4)
#init.uniform(init_assignment)
assignment = self.mpm(init_assignment, S)
#print('Assignment: ', assignment)

# matching
row_ind, col_ind = scipy.optimize.linear_sum_assignment(-assignment.numpy())
print('row: ', row_ind)
print('col: ', col_ind)

permuted_adj = self.permute_adj(adj_data, row_ind, col_ind)
print('permuted: ', permuted_adj)

adj_recon_loss = self.adj_recon_loss(permuted_adj, adj_data1)
print(adj_data1)
print('diff: ', adj_recon_loss)

def adj_recon_loss(self, adj_truth, adj_pred):
return F.binary_cross_entropy(adj_truth, adj_pred)


baselines/graphvae/train.py (+132, -0)

@@ -0,0 +1,132 @@

import argparse
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
from random import shuffle
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

import data
from baselines.graphvae.model import GraphVAE
from baselines.graphvae.data import GraphAdjSampler

CUDA = 2

LR_milestones = [500, 1000]

def build_model(args, max_num_nodes):
out_dim = max_num_nodes * (max_num_nodes + 1) // 2
if args.feature_type == 'id':
input_dim = max_num_nodes
elif args.feature_type == 'deg':
input_dim = 1
elif args.feature_type == 'struct':
input_dim = 2
model = GraphVAE(input_dim, 64, 256, max_num_nodes)
return model

def train(args, dataloader, model):
epoch = 1
optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=args.lr)

model.train()
for epoch in range(5000):
for batch_idx, data in enumerate(dataloader):
model.zero_grad()
features = data['features'].float()
adj_input = data['adj'].float()

features = Variable(features).cuda()
adj_input = Variable(adj_input).cuda()
loss = model(features, adj_input)
print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
loss.backward()

optimizer.step()
scheduler.step()
break

def arg_parse():
parser = argparse.ArgumentParser(description='GraphVAE arguments.')
io_parser = parser.add_mutually_exclusive_group(required=False)
io_parser.add_argument('--dataset', dest='dataset',
help='Input dataset.')

parser.add_argument('--lr', dest='lr', type=float,
help='Learning rate.')
parser.add_argument('--batch_size', dest='batch_size', type=int,
help='Batch size.')
parser.add_argument('--num_workers', dest='num_workers', type=int,
help='Number of workers to load data.')
parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \
training data.')
parser.add_argument('--feature', dest='feature_type',
help='Feature used for encoder. Can be: id, deg')

parser.set_defaults(dataset='grid',
feature_type='id',
lr=0.001,
batch_size=1,
num_workers=1,
max_num_nodes=-1)
return parser.parse_args()

def main():
prog_args = arg_parse()

os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
print('CUDA', CUDA)
### running log

if prog_args.dataset == 'enzymes':
graphs= data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
num_graphs_raw = len(graphs)
elif prog_args.dataset == 'grid':
graphs = []
for i in range(2,3):
for j in range(2,3):
graphs.append(nx.grid_2d_graph(i,j))
num_graphs_raw = len(graphs)

if prog_args.max_num_nodes == -1:
max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
else:
max_num_nodes = prog_args.max_num_nodes
# remove graphs with number of nodes greater than max_num_nodes
graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]

graphs_len = len(graphs)
print('Number of graphs removed due to upper-limit of number of nodes: ',
num_graphs_raw - graphs_len)
graphs_test = graphs[int(0.8 * graphs_len):]
#graphs_train = graphs[0:int(0.8*graphs_len)]
graphs_train = graphs

print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train)))
print('max number node: {}'.format(max_num_nodes))

dataset = GraphAdjSampler(graphs_train, max_num_nodes, features=prog_args.feature_type)
#sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
# [1.0 / len(dataset) for i in range(len(dataset))],
# num_samples=prog_args.batch_size,
# replacement=False)
dataset_loader = torch.utils.data.DataLoader(
dataset,
batch_size=prog_args.batch_size,
num_workers=prog_args.num_workers)
model = build_model(prog_args, max_num_nodes).cuda()
train(prog_args, dataset_loader, model)


if __name__ == '__main__':
main()

baselines/mmsb.py (+154, -0)

@@ -0,0 +1,154 @@
"""Stochastic block model."""

import argparse
import os
from time import time

import edward as ed
import networkx as nx
import numpy as np
import tensorflow as tf

from edward.models import Bernoulli, Multinomial, Beta, Dirichlet, PointMass, Normal
from observations import karate
from sklearn.metrics.cluster import adjusted_rand_score

import utils

CUDA = 2
ed.set_seed(int(time()))
#ed.set_seed(42)

# DATA
#X_data, Z_true = karate("data")

def disjoint_cliques_test_graph(num_cliques, clique_size):
G = nx.disjoint_union_all([nx.complete_graph(clique_size) for _ in range(num_cliques)])
return nx.to_numpy_matrix(G)

def mmsb(N, K, data):
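# mixed-membership stochastic block model fitted with Edward:
# gamma ~ Dirichlet block proportions, Pi ~ Beta block-block connectivity,
# Z ~ Multinomial one-hot block assignments; inference is MAP (EM-style) with
# PointMass variational families, run for n_iter Adam steps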
# sparsity
rho = 0.3
# MODEL
# probability of belonging to each of K blocks for each node
gamma = Dirichlet(concentration=tf.ones([K]))
# block connectivity
Pi = Beta(concentration0=tf.ones([K, K]), concentration1=tf.ones([K, K]))
# probability of belonging to each of K blocks for all nodes
Z = Multinomial(total_count=1.0, probs=gamma, sample_shape=N)
# adjacency
X = Bernoulli(probs=(1 - rho) * tf.matmul(Z, tf.matmul(Pi, tf.transpose(Z))))
# INFERENCE (EM algorithm)
qgamma = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([K]))))
qPi = PointMass(params=tf.nn.sigmoid(tf.Variable(tf.random_normal([K, K]))))
qZ = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([N, K]))))
#qgamma = Normal(loc=tf.get_variable("qgamma/loc", [K]),
# scale=tf.nn.softplus(
# tf.get_variable("qgamma/scale", [K])))
#qPi = Normal(loc=tf.get_variable("qPi/loc", [K, K]),
# scale=tf.nn.softplus(
# tf.get_variable("qPi/scale", [K, K])))
#qZ = Normal(loc=tf.get_variable("qZ/loc", [N, K]),
# scale=tf.nn.softplus(
# tf.get_variable("qZ/scale", [N, K])))
#inference = ed.KLqp({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data})
inference = ed.MAP({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data})
#inference.run()
n_iter = 6000
inference.initialize(optimizer=tf.train.AdamOptimizer(learning_rate=0.01), n_iter=n_iter)
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
info_dict = inference.update()
inference.print_progress(info_dict)
inference.finalize()
print('qgamma after: ', qgamma.mean().eval())
return qZ.mean().eval(), qPi.eval()

def arg_parse():
parser = argparse.ArgumentParser(description='MMSB arguments.')
parser.add_argument('--dataset', dest='dataset',
help='Input dataset.')
parser.add_argument('--K', dest='K', type=int,
help='Number of blocks.')
parser.add_argument('--samples-per-G', dest='samples', type=int,
help='Number of samples for every graph.')

parser.set_defaults(dataset='community',
K=4,
samples=1)
return parser.parse_args()

def graph_gen_from_blockmodel(B, Z):
n_blocks = len(B)
B = np.array(B)
Z = np.array(Z)
adj_prob = np.dot(Z, np.dot(B, np.transpose(Z)))
adj = np.random.binomial(1, adj_prob * 0.3)
return nx.from_numpy_matrix(adj)

if __name__ == '__main__':
prog_args = arg_parse()
os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
print('CUDA', CUDA)

X_dataset = []
#X_data = nx.to_numpy_matrix(nx.connected_caveman_graph(4, 7))
if prog_args.dataset == 'clique_test':
X_data = disjoint_cliques_test_graph(4, 7)
X_dataset.append(X_data)
elif prog_args.dataset == 'citeseer':
graphs = utils.citeseer_ego()
X_dataset = [nx.to_numpy_matrix(g) for g in graphs]
elif prog_args.dataset == 'community':
graphs = []
for i in range(2, 3):
for j in range(30, 81):
for k in range(10):
graphs.append(utils.caveman_special(i,j, p_edge=0.3))
X_dataset = [nx.to_numpy_matrix(g) for g in graphs]
elif prog_args.dataset == 'grid':
graphs = []
for i in range(10,20):
for j in range(10,20):
graphs.append(nx.grid_2d_graph(i,j))
X_dataset = [nx.to_numpy_matrix(g) for g in graphs]
elif prog_args.dataset.startswith('community'):
graphs = []
num_communities = int(prog_args.dataset[-1])
print('Creating dataset with ', num_communities, ' communities')
c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities)
for k in range(3000):
graphs.append(utils.n_community(c_sizes, p_inter=0.01))
X_dataset = [nx.to_numpy_matrix(g) for g in graphs]

print('Number of graphs: ', len(X_dataset))
K = prog_args.K # number of clusters
gen_graphs = []
for i in range(len(X_dataset)):
if i % 5 == 0:
print(i)
X_data = X_dataset[i]
N = X_data.shape[0] # number of vertices

Zp, B = mmsb(N, K, X_data)
#print("Block: ", B)
Z_pred = Zp.argmax(axis=1)
print("Result (label flip can happen):")
#print("prob: ", Zp)
print("Predicted")
print(Z_pred)
#print(Z_true)
#print("Adjusted Rand Index =", adjusted_rand_score(Z_pred, Z_true))
for j in range(prog_args.samples):
gen_graphs.append(graph_gen_from_blockmodel(B, Zp))

save_path = '/lfs/local/0/rexy/graph-generation/eval_results/mmsb/'
utils.save_graph_list(gen_graphs, os.path.join(save_path, prog_args.dataset + '.dat'))


create_graphs.py (+155, -0)

@@ -0,0 +1,155 @@
import networkx as nx
import numpy as np

from utils import *
from data import *

def create(args):
### load datasets
graphs=[]
# synthetic graphs
if args.graph_type=='ladder':
graphs = []
for i in range(100, 201):
graphs.append(nx.ladder_graph(i))
args.max_prev_node = 10
elif args.graph_type=='ladder_small':
graphs = []
for i in range(2, 11):
graphs.append(nx.ladder_graph(i))
args.max_prev_node = 10
elif args.graph_type=='tree':
graphs = []
for i in range(2,5):
for j in range(3,5):
graphs.append(nx.balanced_tree(i,j))
args.max_prev_node = 256
elif args.graph_type=='caveman':
# graphs = []
# for i in range(5,10):
# for j in range(5,25):
# for k in range(5):
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
graphs = []
for i in range(2, 3):
for j in range(30, 81):
for k in range(10):
graphs.append(caveman_special(i,j, p_edge=0.3))
args.max_prev_node = 100
elif args.graph_type=='caveman_small':
# graphs = []
# for i in range(2,5):
# for j in range(2,6):
# for k in range(10):
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
graphs = []
for i in range(2, 3):
for j in range(6, 11):
for k in range(20):
graphs.append(caveman_special(i, j, p_edge=0.8)) # default 0.8
args.max_prev_node = 20
elif args.graph_type=='caveman_small_single':
# graphs = []
# for i in range(2,5):
# for j in range(2,6):
# for k in range(10):
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
graphs = []
for i in range(2, 3):
for j in range(8, 9):
for k in range(100):
graphs.append(caveman_special(i, j, p_edge=0.5))
args.max_prev_node = 20
elif args.graph_type.startswith('community'):
num_communities = int(args.graph_type[-1])
print('Creating dataset with ', num_communities, ' communities')
c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities)
#c_sizes = [15] * num_communities
for k in range(3000):
graphs.append(n_community(c_sizes, p_inter=0.01))
args.max_prev_node = 80
elif args.graph_type=='grid':
graphs = []
for i in range(10,20):
for j in range(10,20):
graphs.append(nx.grid_2d_graph(i,j))
args.max_prev_node = 40
elif args.graph_type=='grid_small':
graphs = []
for i in range(2,5):
for j in range(2,6):
graphs.append(nx.grid_2d_graph(i,j))
args.max_prev_node = 15
elif args.graph_type=='barabasi':
graphs = []
for i in range(100,200):
for j in range(4,5):
for k in range(5):
graphs.append(nx.barabasi_albert_graph(i,j))
args.max_prev_node = 130
elif args.graph_type=='barabasi_small':
graphs = []
for i in range(4,21):
for j in range(3,4):
for k in range(10):
graphs.append(nx.barabasi_albert_graph(i,j))
args.max_prev_node = 20
elif args.graph_type=='grid_big':
graphs = []
for i in range(36, 46):
for j in range(36, 46):
graphs.append(nx.grid_2d_graph(i, j))
args.max_prev_node = 90

elif 'barabasi_noise' in args.graph_type:
graphs = []
for i in range(100,101):
for j in range(4,5):
for k in range(500):
graphs.append(nx.barabasi_albert_graph(i,j))
graphs = perturb_new(graphs,p=args.noise/10.0)
args.max_prev_node = 99

# real graphs
elif args.graph_type == 'enzymes':
graphs= Graph_load_batch(min_num_nodes=10, name='ENZYMES')
args.max_prev_node = 25
elif args.graph_type == 'enzymes_small':
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
graphs = []
for G in graphs_raw:
if G.number_of_nodes()<=20:
graphs.append(G)
args.max_prev_node = 15
elif args.graph_type == 'protein':
graphs = Graph_load_batch(min_num_nodes=20, name='PROTEINS_full')
args.max_prev_node = 80
elif args.graph_type == 'DD':
graphs = Graph_load_batch(min_num_nodes=100, max_num_nodes=500, name='DD',node_attributes=False,graph_labels=True)
args.max_prev_node = 230
elif args.graph_type == 'citeseer':
_, _, G = Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
G = nx.convert_node_labels_to_integers(G)
graphs = []
for i in range(G.number_of_nodes()):
G_ego = nx.ego_graph(G, i, radius=3)
if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400):
graphs.append(G_ego)
args.max_prev_node = 250
elif args.graph_type == 'citeseer_small':
_, _, G = Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
G = nx.convert_node_labels_to_integers(G)
graphs = []
for i in range(G.number_of_nodes()):
G_ego = nx.ego_graph(G, i, radius=1)
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
graphs.append(G_ego)
shuffle(graphs)
graphs = graphs[0:200]
args.max_prev_node = 15

return graphs



data.py (+1392, -0): file diff suppressed because it is too large


dataset/DD/DD_A.txt (+1686092, -0): file diff suppressed because it is too large


dataset/DD/DD_graph_indicator.txt (+334925, -0): file diff suppressed because it is too large


dataset/DD/DD_graph_labels.txt (+1178, -0): file diff suppressed because it is too large


dataset/DD/DD_node_labels.txt (+334925, -0): file diff suppressed because it is too large


dataset/DD/README.txt (+75, -0)

@@ -0,0 +1,75 @@
README for dataset DD


=== Usage ===

This folder contains the following comma separated text files
(replace DS by the name of the dataset):

n = total number of nodes
m = total number of edges
N = number of graphs

(1) DS_A.txt (m lines)
sparse (block diagonal) adjacency matrix for all graphs,
each line corresponds to (row, col) resp. (node_id, node_id)

(2) DS_graph_indicator.txt (n lines)
column vector of graph identifiers for all nodes of all graphs,
the value in the i-th line is the graph_id of the node with node_id i

(3) DS_graph_labels.txt (N lines)
class labels for all graphs in the dataset,
the value in the i-th line is the class label of the graph with graph_id i

(4) DS_node_labels.txt (n lines)
column vector of node labels,
the value in the i-th line corresponds to the node with node_id i

There are OPTIONAL files if the respective information is available:

(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)
labels for the edges in DS_A.txt

(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)
attributes for the edges in DS_A.txt

(7) DS_node_attributes.txt (n lines)
matrix of node attributes,
the comma separated values in the i-th line are the attribute vector of the node with node_id i

(8) DS_graph_attributes.txt (N lines)
regression values for all graphs in the dataset,
the value in the i-th line is the attribute of the graph with graph_id i
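
A rough sketch of how these files can be combined into one networkx graph per
graph_id (the repository's own loader is Graph_load_batch in data.py; the code
below is only illustrative, with node_ids treated as 1-based):

    import networkx as nx
    import numpy as np

    graph_ind = np.loadtxt('dataset/DD/DD_graph_indicator.txt', dtype=int)  # graph_ind[i-1] = graph_id of node i
    labels = np.loadtxt('dataset/DD/DD_graph_labels.txt', dtype=int)        # class label of each graph
    edges = np.loadtxt('dataset/DD/DD_A.txt', delimiter=',', dtype=int)     # (node_id, node_id) pairs

    G = nx.Graph()
    G.add_nodes_from(range(1, len(graph_ind) + 1))
    G.add_edges_from(map(tuple, edges))

    graphs = [G.subgraph(np.where(graph_ind == gid)[0] + 1).copy()
              for gid in range(1, len(labels) + 1)]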


=== Description ===

D&D is a dataset of 1178 protein structures (Dobson and Doig, 2003). Each protein is
represented by a graph, in which the nodes are amino acids and two nodes are connected
by an edge if they are less than 6 Angstroms apart. The prediction task is to classify
the protein structures into enzymes and non-enzymes.


=== Previous Use of the Dataset ===

Neumann, M., Garnett R., Bauckhage Ch., Kersting K.: Propagation Kernels: Efficient Graph
Kernels from Propagated Information. Under review at MLJ.

Neumann, M., Patricia, N., Garnett, R., Kersting, K.: Efficient Graph Kernels by
Randomization. In: P.A. Flach, T.D. Bie, N. Cristianini (eds.) ECML/PKDD, Notes in
Computer Science, vol. 7523, pp. 378-393. Springer (2012).

Shervashidze, N., Schweitzer, P., van Leeuwen, E., Mehlhorn, K., Borgwardt, K.:
Weisfeiler-Lehman Graph Kernels. Journal of Machine Learning Research 12, 2539-2561 (2011)


=== References ===

P. D. Dobson and A. J. Doig. Distinguishing enzyme structures from non-enzymes without
alignments. J. Mol. Biol., 330(4):771-783, Jul 2003.






dataset/ENZYMES/ENZYMES_A.txt (+74564, -0): file diff suppressed because it is too large


dataset/ENZYMES/ENZYMES_graph_indicator.txt (+19580, -0): file diff suppressed because it is too large


dataset/ENZYMES/ENZYMES_graph_labels.txt (+600, -0)

@@ -0,0 +1,600 @@
(600 graph class labels, one per line: lines 1-100 are class 6, lines 101-200 class 5, lines 201-300 class 1, lines 301-400 class 2, lines 401-500 class 3, and lines 501-600 class 4.)