You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

args.py 5.1KB

4 years ago
(Lines 1–119)
  1. ### program configuration
  2. class Args():
  3. def __init__(self):
  4. ### if clean tensorboard
  5. self.clean_tensorboard = False
  6. ### Which CUDA GPU device is used for training
  7. self.cuda = 1
  8. ### Which GraphRNN model variant is used.
  9. # The simple version of Graph RNN
  10. # self.note = 'GraphRNN_MLP'
  11. # The dependent Bernoulli sequence version of GraphRNN
  12. self.initial_note = 'GraphRNN_RNN'
  13. # self.note = 'GraphRNN_MLP_nobfs'
  14. self.note = 'GraphRNN_RNN'
  15. self.training_mode = 'TrainForCompletion'
  16. self.number_of_missing_nodes = 3
  17. self.loss = 'GraphRNN_loss' # 'matrix_completion_loss' or 'graph_matching_loss' or 'GraphRNN_loss', 'Normal'
  18. ## for comparison, removing the BFS compoenent
  19. # self.note = 'GraphRNN_MLP_nobfs'
  20. # self.note = 'GraphRNN_RNN_nobfs'
  21. ### Which dataset is used to train the model
  22. # self.graph_type = 'DD'
  23. # self.graph_type = 'caveman'
  24. # self.graph_type = 'caveman_small'
  25. # self.graph_type = 'caveman_small_single'
  26. # self.graph_type = 'community4'
  27. # self.graph_type = 'grid'
  28. self.graph_type = 'COLLAB'
  29. # self.graph_type = 'ladder_small'
  30. # self.graph_type = 'enzymes'
  31. # self.graph_type = 'enzymes_small'
  32. # self.graph_type = 'barabasi'
  33. # self.graph_type = 'barabasi_small'
  34. # self.graph_type = 'citeseer'
  35. # self.graph_type = 'citeseer_small'
  36. # self.graph_type = 'barabasi_noise'
  37. # self.noise = 10
  38. #
  39. # if self.graph_type == 'barabasi_noise':
  40. # self.graph_type = self.graph_type+str(self.noise)
  41. # if none, then auto calculate
  42. self.max_num_node = None # max number of nodes in a graph
  43. self.max_prev_node = None # max previous node that looks back
  44. ### network config
  45. ## GraphRNN
  46. if 'small' in self.graph_type:
  47. self.parameter_shrink = 2
  48. else:
  49. self.parameter_shrink = 1
  50. self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
  51. self.hidden_size_rnn_output = 16 # hidden size for output RNN
  52. self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
  53. self.embedding_size_rnn_output = 8 # the embedding size for output rnn
  54. self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)
  55. self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
  56. self.test_batch_size = 32
  57. self.test_total_size = 1000
  58. self.num_layers = 4
  59. ### training config
  60. self.num_workers = 4 # num workers to load data, default 4
  61. self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
  62. self.epochs = 3000 # now one epoch means self.batch_ratio x batch_size
  63. self.epochs_test_start = 100
  64. self.epochs_test = 100
  65. self.epochs_log = 50
  66. self.epochs_save = 1000
  67. self.lr = 0.003
  68. self.milestones = [400, 1000]
  69. self.lr_rate = 0.3
  70. self.sample_time = 2 # sample time in each time step, when validating
  71. ### output config
  72. # self.dir_input = "/dfs/scratch0/jiaxuany0/"
  73. self.dir_input = "./"
  74. self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
  75. self.tensorboard_path = self.dir_input + 'tensorboard/'
  76. self.graph_save_path = self.dir_input+'graphs/'
  77. self.figure_save_path = self.dir_input+'figures/'
  78. self.timing_save_path = self.dir_input+'timing/'
  79. self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
  80. self.nll_save_path = self.dir_input+'nll/'
  81. self.load = False # if load model, default lr is very low
  82. self.load_epoch = 3000
  83. self.save = True
  84. ### baseline config
  85. # self.generator_baseline = 'Gnp'
  86. self.generator_baseline = 'BA'
  87. # self.metric_baseline = 'general'
  88. # self.metric_baseline = 'degree'
  89. self.metric_baseline = 'clustering'
  90. self.graph_completion_string = self.training_mode + "_number_of_missing_nodes_" + str(self.number_of_missing_nodes)
  91. ### filenames to save intemediate and final outputs
  92. self.initial_fname = self.initial_note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  93. self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
  94. self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
  95. self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
  96. self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
  97. self.fname_baseline = self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline
  98. self.fname_GraphRNN = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'