You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_train.py 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import argparse
  2. from time import strftime, gmtime
  3. import matplotlib.pyplot as plt
  4. import networkx as nx
  5. import numpy as np
  6. import os
  7. import shutil
  8. from random import shuffle
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.init as init
  12. from tensorboard_logger import configure, log_value
  13. from torch.autograd import Variable
  14. import torch.nn.functional as F
  15. from torch import optim
  16. from torch.optim.lr_scheduler import MultiStepLR
  17. import data
  18. # from baselines.graphvae.graphvae_args import Graph_VAE_Args
  19. from baselines.graphvae.graphvae_model import GraphVAE
  20. from baselines.graphvae.graphvae_data import GraphAdjSampler
  21. from baselines.graphvae.args import GraphVAE_Args
  22. CUDA = 0
  23. vae_args = GraphVAE_Args()
  24. LR_milestones = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
  25. def build_model(args, max_num_nodes):
  26. out_dim = max_num_nodes * (max_num_nodes + 1) // 2
  27. if args.feature_type == 'id':
  28. if vae_args.completion_mode:
  29. input_dim = max_num_nodes-vae_args.number_of_missing_nodes
  30. else:
  31. input_dim = max_num_nodes
  32. elif args.feature_type == 'deg':
  33. input_dim = 1
  34. elif args.feature_type == 'struct':
  35. input_dim = 2
  36. model = GraphVAE(input_dim, 16, 256, max_num_nodes, vae_args.number_of_missing_nodes,
  37. vae_args.completion_mode)
  38. return model
  39. def train(args, dataloader, model):
  40. epoch = 1
  41. optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
  42. scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=0.1)
  43. model.train()
  44. # A = Graph_VAE_Args()
  45. for epoch in range(5000):
  46. for batch_idx, data in enumerate(dataloader):
  47. model.zero_grad()
  48. features = data['features'].float()
  49. # print("features")
  50. # print(features)
  51. # print("************************************************")
  52. adj_input = data['adj'].float()
  53. # print("adj_input")
  54. # print(adj_input)
  55. # print("************************************************")
  56. features = Variable(features).cuda()
  57. adj_input = Variable(adj_input).cuda()
  58. loss = model(features, adj_input)
  59. print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
  60. loss.backward()
  61. optimizer.step()
  62. scheduler.step()
  63. if epoch % 2 == 0:
  64. fname = vae_args.model_save_path + "GraphVAE" + str(epoch) + '.dat'
  65. torch.save(model.state_dict(), fname)
  66. log_value("Training Loss", loss.data, epoch * 32 + batch_idx)
  67. # break
  68. def arg_parse():
  69. parser = argparse.ArgumentParser(description='GraphVAE arguments.')
  70. io_parser = parser.add_mutually_exclusive_group(required=False)
  71. io_parser.add_argument('--dataset', dest='dataset',
  72. help='Input dataset.')
  73. parser.add_argument('--lr', dest='lr', type=float,
  74. help='Learning rate.')
  75. parser.add_argument('--batch_size', dest='batch_size', type=int,
  76. help='Batch size.')
  77. parser.add_argument('--batch_ratio', dest='batch_ratio', type=int,
  78. help='Batch ratio.')
  79. parser.add_argument('--num_workers', dest='num_workers', type=int,
  80. help='Number of workers to load data.')
  81. parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
  82. help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \
  83. training data.')
  84. parser.add_argument('--feature', dest='feature_type',
  85. help='Feature used for encoder. Can be: id, deg')
  86. parser.set_defaults(dataset='grid',
  87. feature_type='id',
  88. lr=0.01,
  89. batch_size=6,
  90. batch_ratio=6,
  91. num_workers=4,
  92. max_num_nodes=-1)
  93. return parser.parse_args()
  94. def main():
  95. prog_args = arg_parse()
  96. os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
  97. print('CUDA', CUDA)
  98. ### running log
  99. if prog_args.dataset == 'enzymes':
  100. graphs= data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
  101. num_graphs_raw = len(graphs)
  102. elif prog_args.dataset == 'grid':
  103. graphs = []
  104. # for i in range(2,6):
  105. # for j in range(2,6):
  106. # graphs.append(nx.grid_2d_graph(i,j))
  107. # *********************************
  108. # graphs.append(nx.grid_2d_graph(2, 2))
  109. # graphs.append(nx.grid_2d_graph(2, 3))
  110. graphs.append(nx.grid_2d_graph(1, 12))
  111. graphs.append(nx.grid_2d_graph(2, 6))
  112. graphs.append(nx.grid_2d_graph(3, 4))
  113. graphs.append(nx.grid_2d_graph(4, 3))
  114. graphs.append(nx.grid_2d_graph(6, 2))
  115. graphs.append(nx.grid_2d_graph(12, 1))
  116. # *********************************
  117. # graphs.append(nx.grid_2d_graph(1, 24))
  118. # graphs.append(nx.grid_2d_graph(2, 12))
  119. # graphs.append(nx.grid_2d_graph(3, 8))
  120. # graphs.append(nx.grid_2d_graph(4, 6))
  121. # graphs.append(nx.grid_2d_graph(6, 4))
  122. # graphs.append(nx.grid_2d_graph(8, 3))
  123. # graphs.append(nx.grid_2d_graph(12, 2))
  124. # graphs.append(nx.grid_2d_graph(24, 1))
  125. num_graphs_raw = len(graphs)
  126. if prog_args.max_num_nodes == -1:
  127. max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  128. else:
  129. max_num_nodes = prog_args.max_num_nodes
  130. # remove graphs with number of nodes greater than max_num_nodes
  131. graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]
  132. graphs_len = len(graphs)
  133. print('Number of graphs removed due to upper-limit of number of nodes: ',
  134. num_graphs_raw - graphs_len)
  135. graphs_test = graphs[int(0.8 * graphs_len):]
  136. #graphs_train = graphs[0:int(0.8*graphs_len)]
  137. graphs_train = graphs
  138. print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train)))
  139. print('max number node: {}'.format(max_num_nodes))
  140. dataset = GraphAdjSampler(graphs_train, max_num_nodes,vae_args.permutation_mode, vae_args.bfs_mode,
  141. features=prog_args.feature_type)
  142. #sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
  143. # [1.0 / len(dataset) for i in range(len(dataset))],
  144. # num_samples=prog_args.batch_size,
  145. # replacement=False)
  146. sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
  147. num_samples=prog_args.batch_size*prog_args.batch_ratio,
  148. replacement=True)
  149. dataset_loader = torch.utils.data.DataLoader(
  150. dataset,
  151. batch_size=prog_args.batch_size,
  152. num_workers=prog_args.num_workers,
  153. sampler=sample_strategy)
  154. model = build_model(prog_args, max_num_nodes).cuda()
  155. train(prog_args, dataset_loader, model)
  156. if __name__ == '__main__':
  157. if not os.path.isdir(vae_args.model_save_path):
  158. os.makedirs(vae_args.model_save_path)
  159. # configure(vae_args.tensorboard_path, flush_secs=5)
  160. time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
  161. if vae_args.clean_tensorboard:
  162. if os.path.isdir("tensorboard"):
  163. shutil.rmtree("tensorboard")
  164. configure("tensorboard/run"+time, flush_secs=5)
  165. main()