from train import *  # assumed to supply os, shutil, torch, random, strftime/gmtime, configure, the samplers, models, and train(), as in the original GraphRNN train.py
from baselines.graphvae.graphvae_train import graph_statistics

if __name__ == '__main__':
    print("SALAMMMMMMMM")  # leftover startup/debug marker
    # All necessary arguments are defined in args.py
    args = Args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)
    # check that all necessary output directories exist
    for path in [args.model_save_path, args.graph_save_path, args.figure_save_path,
                 args.timing_save_path, args.figure_prediction_save_path,
                 args.nll_save_path]:
        if not os.path.isdir(path):
            os.makedirs(path)
    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
    if args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)
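    # Note: configure() is assumed to come from tensorboard_logger via the
    # wildcard import above (as in the original GraphRNN train.py); it opens a
    # run directory that train() then writes scalars into with log_value().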
    graphs = create_graphs.create(args)

    # optionally keep only graphs whose size matches the KronEM setting
    # (powers of two from 8 to 256 nodes)
    training_on_KronEM_data = True
    if training_on_KronEM_data:
        small_graphs = []
        for i in range(len(graphs)):
            if graphs[i].number_of_nodes() in (8, 16, 32, 64, 128, 256):
                small_graphs.append(graphs[i])
        graphs = small_graphs
    else:
        if args.graph_type == 'IMDBBINARY' or args.graph_type == 'IMDBMULTI':
            small_graphs = []
            for i in range(len(graphs)):
                if graphs[i].number_of_nodes() < 41:
                    small_graphs.append(graphs[i])
            graphs = small_graphs
        elif args.graph_type == 'COLLAB':
            small_graphs = []
            for i in range(len(graphs)):
                if 41 < graphs[i].number_of_nodes() < 52:
                    small_graphs.append(graphs[i])
            graphs = small_graphs
    # args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    graph_statistics(graphs)

    # split datasets
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    graphs_validate = graphs[0:int(0.2 * graphs_len)]
    print("**** Test graphs size:")
    print(len(graphs_test))
    print("**** Train graphs size:")
    print(len(graphs_train))
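    # Note: as in the original GraphRNN code, the validation set is the first
    # 20% of the shuffled list and therefore overlaps the training split;
    # only the last 20% (graphs_test) is truly held out.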
    # to use pre-saved graphs instead:
    # dir_input = "/dfs/scratch0/jiaxuany0/graphs/"
    # fname_test = dir_input + args.note + '_' + args.graph_type + '_' + str(args.num_layers) + '_' + str(
    #     args.hidden_size_rnn) + '_test_' + str(0) + '.dat'
    # graphs = load_graph_list(fname_test, is_real=True)
    # graphs_test = graphs[int(0.8 * graphs_len):]
    # graphs_train = graphs[0:int(0.8 * graphs_len)]
    # graphs_validate = graphs[int(0.2 * graphs_len):int(0.4 * graphs_len)]
    # average number of nodes per graph in the validation and test splits
    graph_validate_len = sum(graph.number_of_nodes() for graph in graphs_validate) / len(graphs_validate)
    print('graph_validate_len', graph_validate_len)

    graph_test_len = sum(graph.number_of_nodes() for graph in graphs_test) / len(graphs_test)
    print('graph_test_len', graph_test_len)
    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
    min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])
    # args.max_num_node = 2000

    # show graph statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number of nodes: {}'.format(args.max_num_node))
    print('max/min number of edges: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))

    # save ground truth graphs
    ## to recover the train and test sets after loading, slice manually (see note below)
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
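    # Note: both files above contain the *full* graph list, so after loading
    # you recover the splits with the same slicing used earlier, e.g.:
    #   graphs = load_graph_list(args.graph_save_path + args.fname_test + '0.dat', is_real=True)
    #   graphs_test = graphs[int(0.8 * len(graphs)):]
    #   graphs_train = graphs[0:int(0.8 * len(graphs))]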
    ### keep commented for normal training; uncomment for graph completion only
    # p = 0.5
    # for graph in graphs_train:
    #     for node in list(graph.nodes()):
    #         # print('node', node)
    #         if np.random.rand() > p:
    #             graph.remove_node(node)
    #     for edge in list(graph.edges()):
    #         # print('edge', edge)
    #         if np.random.rand() > p:
    #             graph.remove_edge(edge[0], edge[1])
    ### dataset initialization
    if 'nobfs' in args.note:
        print('nobfs')
        dataset = Graph_sequence_sampler_pytorch_nobfs(graphs_train, max_num_node=args.max_num_node)
        args.max_prev_node = args.max_num_node - 1
    # 'elif' (rather than a second 'if') so the default branch below cannot
    # silently overwrite the nobfs sampler chosen above
    elif 'barabasi_noise' in args.graph_type:
        print('barabasi_noise')
        dataset = Graph_sequence_sampler_pytorch_canonical(graphs_train, max_prev_node=args.max_prev_node)
        args.max_prev_node = args.max_num_node - 1
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
                                                 max_num_node=args.max_num_node)

    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset) for i in range(len(dataset))],
        num_samples=args.batch_size * args.batch_ratio,
        replacement=True)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                                 num_workers=args.num_workers,
                                                 sampler=sample_strategy)
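    # Note: with uniform weights and replacement=True, the sampler simply draws
    # args.batch_size * args.batch_ratio sequences uniformly at random, so
    # args.batch_ratio controls how many batches make up one training "epoch".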
    ### model initialization
    ## Graph RNN VAE model
    # lstm = LSTM_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_lstm,
    #                   hidden_size=args.hidden_size, num_layers=args.num_layers).cuda()
    if 'GraphRNN_VAE_conditional' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                        has_output=False).cuda()
        output = MLP_VAE_conditional_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                                           y_size=args.max_prev_node).cuda()
    elif 'GraphRNN_MLP' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                        has_output=False).cuda()
        output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                           y_size=args.max_prev_node).cuda()
    elif 'GraphRNN_RNN' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                        has_output=True, output_size=args.hidden_size_rnn_output).cuda()
        output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
                           hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True,
                           has_output=True, output_size=1).cuda()
    else:
        # fail fast instead of hitting an undefined-name error inside train()
        raise NotImplementedError('unsupported args.note: {}'.format(args.note))
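    # In all three variants, rnn is the graph-level GRU that consumes one
    # adjacency-vector row per step; output maps its hidden state to the next
    # row, either in one shot (the MLP and conditional-VAE heads) or one edge
    # at a time (the second, edge-level GRU of the GraphRNN_RNN variant),
    # following the GraphRNN paper (You et al., 2018).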
    ### start training
    train(args, dataset_loader, rnn, output)

    ### graph completion
    # train_graph_completion(args, dataset_loader, rnn, output)

    ### nll evaluation
    # train_nll(args, dataset_loader, dataset_loader, rnn, output, max_iter=200,
    #           graph_validate_len=graph_validate_len, graph_test_len=graph_test_len)