You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graphvae_train.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import argparse
  2. import matplotlib.pyplot as plt
  3. import networkx as nx
  4. import numpy as np
  5. import os
  6. from random import shuffle
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.init as init
  10. from torch.autograd import Variable
  11. import torch.nn.functional as F
  12. from torch import optim
  13. from torch.optim.lr_scheduler import MultiStepLR
  14. import data
  15. from baselines.graphvae.util import load_data
  16. from main_baselines.graphvae.graphvae_model import GraphVAE
  17. from main_baselines.graphvae.graphvae_data import GraphAdjSampler
  18. from baselines.graphvae.graphvae_train import graph_statistics
  19. CUDA = 0
  20. LR_milestones = [500, 1000]
  21. def build_model(args, max_num_nodes):
  22. out_dim = max_num_nodes * (max_num_nodes + 1) // 2
  23. if args.feature_type == 'id':
  24. input_dim = max_num_nodes
  25. elif args.feature_type == 'deg':
  26. input_dim = 1
  27. elif args.feature_type == 'struct':
  28. input_dim = 2
  29. model = GraphVAE(input_dim, 64, 256, max_num_nodes)
  30. return model
  31. def train(args, dataloader, model):
  32. epoch = 1
  33. optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
  34. scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=args.lr)
  35. model.train()
  36. for epoch in range(5000):
  37. for batch_idx, data in enumerate(dataloader):
  38. model.zero_grad()
  39. features = data['features'].float()
  40. adj_input = data['adj'].float()
  41. features = Variable(features).cuda()
  42. adj_input = Variable(adj_input).cuda()
  43. loss = model(features, adj_input)
  44. print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
  45. loss.backward()
  46. optimizer.step()
  47. scheduler.step()
  48. # break
  49. if epoch % 2 == 0:
  50. fname = model_save_path + "GraphVAE" + str(epoch) + '.dat'
  51. torch.save(model.state_dict(), fname)
  52. def arg_parse():
  53. parser = argparse.ArgumentParser(description='GraphVAE arguments.')
  54. io_parser = parser.add_mutually_exclusive_group(required=False)
  55. io_parser.add_argument('--dataset', dest='dataset',
  56. help='Input dataset.')
  57. parser.add_argument('--lr', dest='lr', type=float,
  58. help='Learning rate.')
  59. parser.add_argument('--batch_size', dest='batch_size', type=int,
  60. help='Batch size.')
  61. parser.add_argument('--num_workers', dest='num_workers', type=int,
  62. help='Number of workers to load data.')
  63. parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
  64. help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \
  65. training data.')
  66. parser.add_argument('--feature', dest='feature_type',
  67. help='Feature used for encoder. Can be: id, deg')
  68. parser.set_defaults(dataset='IMDBBINARY',
  69. feature_type='id',
  70. lr=0.001,
  71. batch_size=1,
  72. num_workers=1,
  73. max_num_nodes=-1)
  74. return parser.parse_args()
def main():
    """Entry point: parse args, load and filter graphs, build and train a GraphVAE.

    Reads the module-level CUDA constant, and train() reads the module-level
    model_save_path set in the __main__ guard.
    """
    prog_args = arg_parse()
    # Restrict visible GPUs before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
    print('CUDA', CUDA)
    ### running log
    # ---- dataset selection ----
    if prog_args.dataset == 'enzymes':
        # Project loader; presumably drops graphs under 10 nodes — confirm.
        graphs = data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
        num_graphs_raw = len(graphs)
    elif prog_args.dataset == 'grid':
        # Synthetic 2D grids: one copy of every i x j grid with i, j in {2, 3}.
        graphs = []
        for z in range(1):
            for i in range(2, 4):
                for j in range(2, 4):
                    graphs.append(nx.grid_2d_graph(i, j))
        num_graphs_raw = len(graphs)
    else:
        # Generic benchmark loader (default dataset is IMDBBINARY);
        # num_classes is returned but unused here.
        graphs, num_classes = load_data(prog_args.dataset, True)
        num_graphs_raw = len(graphs)
    # ---- determine max_num_nodes and filter graphs ----
    if prog_args.max_num_nodes == -1:
        max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
        small_graphs_size = 0
        if prog_args.dataset != 'grid_small' and prog_args.dataset != 'grid':
            # Keep only graphs with fewer than 13 nodes, then recompute the
            # maximum over the filtered set and print statistics.
            small_graphs = []
            for i in range(len(graphs)):
                if graphs[i].number_of_nodes() < 13:
                    # (commented-out alternative kept from the original:
                    # restrict to power-of-two sizes 8/16/32/64/128/256)
                    small_graphs_size += 1
                    small_graphs.append(graphs[i])
            graphs = small_graphs
            print(len(graphs))
            max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
            graph_statistics(graphs)
    else:
        max_num_nodes = prog_args.max_num_nodes
        # remove graphs with number of nodes greater than max_num_nodes
        graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]
    graphs_len = len(graphs)
    print('Number of graphs removed due to upper-limit of number of nodes: ',
          num_graphs_raw - graphs_len)
    # Positional 80/20 split; graphs_test is currently unused below.
    graphs_test = graphs[int(0.8 * graphs_len):]
    # graphs_train = graphs[0:int(0.8*graphs_len)]
    # graphs_train = graphs
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(max_num_nodes))
    # Dataset is expected to yield dicts with 'features' and 'adj' entries
    # (see train()); features depend on prog_args.feature_type.
    dataset = GraphAdjSampler(graphs_train, max_num_nodes, features=prog_args.feature_type)
    # sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
    #     [1.0 / len(dataset) for i in range(len(dataset))],
    #     num_samples=prog_args.batch_size,
    #     replacement=False)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers)
    # Model is moved to GPU here; train() assumed CUDA inputs as well.
    model = build_model(prog_args, max_num_nodes).cuda()
    train(prog_args, dataset_loader, model)
if __name__ == '__main__':
    # Checkpoint directory; defined as a module global because train()
    # reads model_save_path directly when saving state dicts.
    model_save_path = './model_save/'
    if not os.path.isdir(model_save_path):
        os.makedirs(model_save_path)
    main()