You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 6.6KB

(line-number gutter from the original file listing, lines 1–141; not part of the source)
  1. from train import *
  2. if __name__ == '__main__':
  3. # All necessary arguments are defined in args.py
  4. args = Args()
  5. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
  6. print('CUDA', args.cuda)
  7. print('File name prefix',args.fname)
  8. # check if necessary directories exist
  9. if not os.path.isdir(args.model_save_path):
  10. os.makedirs(args.model_save_path)
  11. if not os.path.isdir(args.graph_save_path):
  12. os.makedirs(args.graph_save_path)
  13. if not os.path.isdir(args.figure_save_path):
  14. os.makedirs(args.figure_save_path)
  15. if not os.path.isdir(args.timing_save_path):
  16. os.makedirs(args.timing_save_path)
  17. if not os.path.isdir(args.figure_prediction_save_path):
  18. os.makedirs(args.figure_prediction_save_path)
  19. if not os.path.isdir(args.nll_save_path):
  20. os.makedirs(args.nll_save_path)
  21. time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
  22. # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
  23. if args.clean_tensorboard:
  24. if os.path.isdir("tensorboard"):
  25. shutil.rmtree("tensorboard")
  26. configure("tensorboard/run"+time, flush_secs=5)
  27. graphs = create_graphs.create(args)
  28. # split datasets
  29. random.seed(123)
  30. shuffle(graphs)
  31. graphs_len = len(graphs)
  32. graphs_test = graphs[int(0.8 * graphs_len):]
  33. graphs_train = graphs[0:int(0.8*graphs_len)]
  34. graphs_validate = graphs[0:int(0.2*graphs_len)]
  35. # if use pre-saved graphs
  36. # dir_input = "/dfs/scratch0/jiaxuany0/graphs/"
  37. # fname_test = dir_input + args.note + '_' + args.graph_type + '_' + str(args.num_layers) + '_' + str(
  38. # args.hidden_size_rnn) + '_test_' + str(0) + '.dat'
  39. # graphs = load_graph_list(fname_test, is_real=True)
  40. # graphs_test = graphs[int(0.8 * graphs_len):]
  41. # graphs_train = graphs[0:int(0.8 * graphs_len)]
  42. # graphs_validate = graphs[int(0.2 * graphs_len):int(0.4 * graphs_len)]
  43. graph_validate_len = 0
  44. for graph in graphs_validate:
  45. graph_validate_len += graph.number_of_nodes()
  46. graph_validate_len /= len(graphs_validate)
  47. print('graph_validate_len', graph_validate_len)
  48. graph_test_len = 0
  49. for graph in graphs_test:
  50. graph_test_len += graph.number_of_nodes()
  51. graph_test_len /= len(graphs_test)
  52. print('graph_test_len', graph_test_len)
  53. args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  54. max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
  55. min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
  56. # args.max_num_node = 2000
  57. # show graphs statistics
  58. print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train)))
  59. print('max number node: {}'.format(args.max_num_node))
  60. print('max/min number edge: {}; {}'.format(max_num_edge,min_num_edge))
  61. print('max previous node: {}'.format(args.max_prev_node))
  62. # save ground truth graphs
  63. ## To get train and test set, after loading you need to manually slice
  64. save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
  65. save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
  66. print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
  67. ### comment when normal training, for graph completion only
  68. # p = 0.5
  69. # for graph in graphs_train:
  70. # for node in list(graph.nodes()):
  71. # # print('node',node)
  72. # if np.random.rand()>p:
  73. # graph.remove_node(node)
  74. # for edge in list(graph.edges()):
  75. # # print('edge',edge)
  76. # if np.random.rand()>p:
  77. # graph.remove_edge(edge[0],edge[1])
  78. ### dataset initialization
  79. if 'nobfs' in args.note:
  80. print('nobfs')
  81. dataset = Graph_sequence_sampler_pytorch_nobfs(graphs_train, max_num_node=args.max_num_node)
  82. args.max_prev_node = args.max_num_node-1
  83. if 'barabasi_noise' in args.graph_type:
  84. print('barabasi_noise')
  85. dataset = Graph_sequence_sampler_pytorch_canonical(graphs_train,max_prev_node=args.max_prev_node)
  86. args.max_prev_node = args.max_num_node - 1
  87. else:
  88. dataset = Graph_sequence_sampler_pytorch(graphs_train,max_prev_node=args.max_prev_node,max_num_node=args.max_num_node)
  89. sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
  90. num_samples=args.batch_size*args.batch_ratio, replacement=True)
  91. dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
  92. sampler=sample_strategy)
  93. ### model initialization
  94. ## Graph RNN VAE model
  95. # lstm = LSTM_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_lstm,
  96. # hidden_size=args.hidden_size, num_layers=args.num_layers).cuda()
  97. if 'GraphRNN_VAE_conditional' in args.note:
  98. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  99. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  100. has_output=False).cuda()
  101. output = MLP_VAE_conditional_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output, y_size=args.max_prev_node).cuda()
  102. elif 'GraphRNN_MLP' in args.note:
  103. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  104. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  105. has_output=False).cuda()
  106. output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output, y_size=args.max_prev_node).cuda()
  107. elif 'GraphRNN_RNN' in args.note:
  108. rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
  109. hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
  110. has_output=True, output_size=args.hidden_size_rnn_output).cuda()
  111. output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
  112. hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True,
  113. has_output=True, output_size=1).cuda()
  114. ### start training
  115. train(args, dataset_loader, rnn, output)
  116. ### graph completion
  117. # train_graph_completion(args,dataset_loader,rnn,output)
  118. ### nll evaluation
  119. # train_nll(args, dataset_loader, dataset_loader, rnn, output, max_iter = 200, graph_validate_len=graph_validate_len,graph_test_len=graph_test_len)