You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

args.py 5.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. ### program configuration
  2. class GraphVAE_Args():
  3. def __init__(self):
  4. ### if clean tensorboard
  5. self.clean_tensorboard = False
  6. ### Which CUDA GPU device is used for training
  7. self.cuda = 1
  8. ### Which GraphRNN model variant is used.
  9. # The simple version of Graph RNN
  10. # self.note = 'GraphRNN_MLP'
  11. # The dependent Bernoulli sequence version of GraphRNN
  12. self.initial_note = 'GraphRNN_MLP'
  13. # self.note = 'GraphRNN_MLP_nobfs'
  14. self.note = 'GraphRNN_MLP'
  15. self.training_mode = 'TrainForCompletion'
  16. self.number_of_missing_nodes = 1
  17. self.graph_matching_mode = False
  18. self.permutation_mode = True
  19. self.bfs_mode = True
  20. self.completion_mode = False
  21. self.loss = 'Normal' # 'matrix_completion_loss' or 'graph_matching_loss' or 'GraphRNN_loss', 'Normal'
  22. ## for comparison, removing the BFS compoenent
  23. # self.note = 'GraphRNN_MLP_nobfs'
  24. # self.note = 'GraphRNN_RNN_nobfs'
  25. ### Which dataset is used to train the model
  26. # self.graph_type = 'DD'
  27. # self.graph_type = 'caveman'
  28. # self.graph_type = 'caveman_small'
  29. # self.graph_type = 'caveman_small_single'
  30. # self.graph_type = 'community4'
  31. # self.graph_type = 'grid'
  32. self.graph_type = 'grid_small'
  33. # self.graph_type = 'ladder_small'
  34. # self.graph_type = 'enzymes'
  35. # self.graph_type = 'enzymes_small'
  36. # self.graph_type = 'barabasi'
  37. # self.graph_type = 'barabasi_small'
  38. # self.graph_type = 'citeseer'
  39. # self.graph_type = 'citeseer_small'
  40. # self.graph_type = 'barabasi_noise'
  41. # self.noise = 10
  42. #
  43. # if self.graph_type == 'barabasi_noise':
  44. # self.graph_type = self.graph_type+str(self.noise)
  45. # if none, then auto calculate
  46. self.max_num_node = None # max number of nodes in a graph
  47. self.max_prev_node = None # max previous node that looks back
  48. ### network config
  49. ## GraphRNN
  50. if 'small' in self.graph_type:
  51. self.parameter_shrink = 2
  52. else:
  53. self.parameter_shrink = 1
  54. self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
  55. self.hidden_size_rnn_output = 16 # hidden size for output RNN
  56. self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
  57. self.embedding_size_rnn_output = 8 # the embedding size for output rnn
  58. self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)
  59. self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
  60. self.test_batch_size = 32
  61. self.test_total_size = 1000
  62. self.num_layers = 4
  63. ### training config
  64. self.num_workers = 4 # num workers to load data, default 4
  65. self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
  66. self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size
  67. self.epochs_test_start = 100
  68. self.epochs_test = 100
  69. self.epochs_log = 100
  70. self.epochs_save = 100
  71. self.lr = 0.003
  72. self.milestones = [400, 1000]
  73. self.lr_rate = 0.3
  74. self.sample_time = 2 # sample time in each time step, when validating
  75. ### output config
  76. # self.dir_input = "/dfs/scratch0/jiaxuany0/"
  77. self.dir_input = "./"
  78. self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
  79. self.tensorboard_path = self.dir_input + 'tensorboard/'
  80. self.graph_save_path = self.dir_input+'graphs/'
  81. self.figure_save_path = self.dir_input+'figures/'
  82. self.timing_save_path = self.dir_input+'timing/'
  83. self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
  84. self.nll_save_path = self.dir_input+'nll/'
  85. self.load = False # if load model, default lr is very low
  86. self.load_epoch = 3000
  87. self.save = True
  88. ### baseline config
  89. # self.generator_baseline = 'Gnp'
  90. self.generator_baseline = 'BA'
  91. # self.metric_baseline = 'general'
  92. # self.metric_baseline = 'degree'
  93. self.metric_baseline = 'clustering'
  94. self.graph_completion_string = self.training_mode + "_number_of_missing_nodes_" + str(self.number_of_missing_nodes)
  95. ### filenames to save intemediate and final outputs
  96. self.initial_fname = self.initial_note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  97. self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  98. self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
  99. self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
  100. self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
  101. self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline
  102. self.fname_GraphRNN = "GraphRNN_MLP" + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'