You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

args.py 4.4KB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. ### program configuration
  2. class Args():
  3. def __init__(self):
  4. ### if clean tensorboard
  5. self.clean_tensorboard = False
  6. ### Which CUDA GPU device is used for training
  7. self.cuda = 1
  8. ### Which GraphRNN model variant is used.
  9. # The simple version of Graph RNN
  10. # self.note = 'GraphRNN_MLP'
  11. # The dependent Bernoulli sequence version of GraphRNN
  12. self.note = 'GraphRNN_RNN'
  13. ## for comparison, removing the BFS compoenent
  14. # self.note = 'GraphRNN_MLP_nobfs'
  15. # self.note = 'GraphRNN_RNN_nobfs'
  16. ### Which dataset is used to train the model
  17. # self.graph_type = 'DD'
  18. # self.graph_type = 'caveman'
  19. # self.graph_type = 'caveman_small'
  20. # self.graph_type = 'caveman_small_single'
  21. # self.graph_type = 'community4'
  22. self.graph_type = 'grid'
  23. # self.graph_type = 'grid_small'
  24. # self.graph_type = 'ladder_small'
  25. # self.graph_type = 'enzymes'
  26. # self.graph_type = 'enzymes_small'
  27. # self.graph_type = 'barabasi'
  28. # self.graph_type = 'barabasi_small'
  29. # self.graph_type = 'citeseer'
  30. # self.graph_type = 'citeseer_small'
  31. # self.graph_type = 'barabasi_noise'
  32. # self.noise = 10
  33. #
  34. # if self.graph_type == 'barabasi_noise':
  35. # self.graph_type = self.graph_type+str(self.noise)
  36. # if none, then auto calculate
  37. self.max_num_node = None # max number of nodes in a graph
  38. self.max_prev_node = None # max previous node that looks back
  39. ### network config
  40. ## GraphRNN
  41. if 'small' in self.graph_type:
  42. self.parameter_shrink = 2
  43. else:
  44. self.parameter_shrink = 1
  45. self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
  46. self.hidden_size_rnn_output = 16 # hidden size for output RNN
  47. self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
  48. self.embedding_size_rnn_output = 8 # the embedding size for output rnn
  49. self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)
  50. self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
  51. self.test_batch_size = 32
  52. self.test_total_size = 1000
  53. self.num_layers = 4
  54. ### training config
  55. self.num_workers = 4 # num workers to load data, default 4
  56. self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
  57. self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size
  58. self.epochs_test_start = 100
  59. self.epochs_test = 100
  60. self.epochs_log = 100
  61. self.epochs_save = 100
  62. self.lr = 0.003
  63. self.milestones = [400, 1000]
  64. self.lr_rate = 0.3
  65. self.sample_time = 2 # sample time in each time step, when validating
  66. ### output config
  67. # self.dir_input = "/dfs/scratch0/jiaxuany0/"
  68. self.dir_input = "./"
  69. self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
  70. self.graph_save_path = self.dir_input+'graphs/'
  71. self.figure_save_path = self.dir_input+'figures/'
  72. self.timing_save_path = self.dir_input+'timing/'
  73. self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
  74. self.nll_save_path = self.dir_input+'nll/'
  75. self.load = False # if load model, default lr is very low
  76. self.load_epoch = 3000
  77. self.save = True
  78. ### baseline config
  79. # self.generator_baseline = 'Gnp'
  80. self.generator_baseline = 'BA'
  81. # self.metric_baseline = 'general'
  82. # self.metric_baseline = 'degree'
  83. self.metric_baseline = 'clustering'
  84. ### filenames to save intemediate and final outputs
  85. self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  86. self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
  87. self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
  88. self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
  89. self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline