You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

args.py 5.6KB

Lines 1–131
  1. ### program configuration
  2. class GraphVAE_Args():
  3. def __init__(self):
  4. ### if clean tensorboard
  5. self.clean_tensorboard = False
  6. ### Which CUDA GPU device is used for training
  7. self.cuda = 1
  8. ### Which GraphRNN model variant is used.
  9. # The simple version of Graph RNN
  10. # self.note = 'GraphRNN_MLP'
  11. # The dependent Bernoulli sequence version of GraphRNN
  12. self.initial_note = 'GraphRNN_MLP'
  13. # self.note = 'GraphRNN_MLP_nobfs'
  14. self.note = 'GraphRNN_MLP'
  15. self.training_mode = 'TrainForCompletion'
  16. self.number_of_missing_nodes = 1
  17. self.graph_matching_mode = False
  18. self.permutation_mode = True
  19. self.bfs_mode = False
  20. self.graph_pooling_mode = False
  21. self.bfs_mode_with_arbitrary_node_deleted = False
  22. self.completion_mode_small_parameter_size = False
  23. self.reconstruction=False
  24. self.diffpoolGCN = False
  25. self.GRAN = True
  26. self.GRAN_arbitrary_node = True
  27. self.GRAN_linear = False
  28. self.different_graph_sizes = True
  29. self.loss = 'Normal' # 'matrix_completion_loss' or 'graph_matching_loss' or 'GraphRNN_loss', 'Normal'
  30. ## for comparison, removing the BFS compoenent
  31. # self.note = 'GraphRNN_MLP_nobfs'
  32. # self.note = 'GraphRNN_RNN_nobfs'
  33. ### Which dataset is used to train the model
  34. # self.graph_type = 'DD'
  35. # self.graph_type = 'caveman'
  36. # self.graph_type = 'caveman_small'
  37. # self.graph_type = 'caveman_small_single'
  38. # self.graph_type = 'community4'
  39. # self.graph_type = 'grid'
  40. self.graph_type = 'grid_small'
  41. # self.graph_type = 'ladder_small'
  42. # self.graph_type = 'enzymes'
  43. # self.graph_type = 'enzymes_small'
  44. # self.graph_type = 'barabasi'
  45. # self.graph_type = 'barabasi_small'
  46. # self.graph_type = 'citeseer'
  47. # self.graph_type = 'citeseer_small'
  48. # self.graph_type = 'barabasi_noise'
  49. # self.noise = 10
  50. #
  51. # if self.graph_type == 'barabasi_noise':
  52. # self.graph_type = self.graph_type+str(self.noise)
  53. # if none, then auto calculate
  54. self.max_num_node = None # max number of nodes in a graph
  55. self.max_prev_node = None # max previous node that looks back
  56. ### network config
  57. ## GraphRNN
  58. if 'small' in self.graph_type:
  59. self.parameter_shrink = 2
  60. else:
  61. self.parameter_shrink = 1
  62. self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
  63. self.hidden_size_rnn_output = 16 # hidden size for output RNN
  64. self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
  65. self.embedding_size_rnn_output = 8 # the embedding size for output rnn
  66. self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)
  67. self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
  68. self.test_batch_size = 32
  69. self.test_total_size = 1000
  70. self.num_layers = 4
  71. ### training config
  72. self.num_workers = 4 # num workers to load data, default 4
  73. self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
  74. self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size
  75. self.epochs_test_start = 100
  76. self.epochs_test = 100
  77. self.epochs_log = 100
  78. self.epochs_save = 100
  79. self.lr = 0.003
  80. self.milestones = [400, 1000]
  81. self.lr_rate = 0.3
  82. self.sample_time = 2 # sample time in each time step, when validating
  83. ### output config
  84. # self.dir_input = "/dfs/scratch0/jiaxuany0/"
  85. self.dir_input = "./"
  86. self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
  87. self.tensorboard_path = self.dir_input + 'tensorboard/'
  88. self.graph_save_path = self.dir_input+'graphs/'
  89. self.figure_save_path = self.dir_input+'figures/'
  90. self.timing_save_path = self.dir_input+'timing/'
  91. self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
  92. self.nll_save_path = self.dir_input+'nll/'
  93. self.load = False # if load model, default lr is very low
  94. self.load_epoch = 3000
  95. self.save = True
  96. ### baseline config
  97. # self.generator_baseline = 'Gnp'
  98. self.generator_baseline = 'BA'
  99. # self.metric_baseline = 'general'
  100. # self.metric_baseline = 'degree'
  101. self.metric_baseline = 'clustering'
  102. self.graph_completion_string = self.training_mode + "_number_of_missing_nodes_" + str(self.number_of_missing_nodes)
  103. ### filenames to save intemediate and final outputs
  104. self.initial_fname = self.initial_note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  105. self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  106. self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
  107. self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
  108. self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
  109. self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline
  110. self.fname_GraphRNN = "GraphRNN_MLP" + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'