You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

train.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import argparse
  2. import matplotlib.pyplot as plt
  3. import networkx as nx
  4. import numpy as np
  5. import os
  6. from random import shuffle
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.init as init
  10. from torch.autograd import Variable
  11. import torch.nn.functional as F
  12. from torch import optim
  13. from torch.optim.lr_scheduler import MultiStepLR
  14. import data
  15. from baselines.graphvae.model import GraphVAE
  16. from baselines.graphvae.data import GraphAdjSampler
# Index of the GPU to use; exported via CUDA_VISIBLE_DEVICES in main().
CUDA = 2
# Epoch milestones at which MultiStepLR multiplies the learning rate by gamma.
LR_milestones = [500, 1000]
  19. def build_model(args, max_num_nodes):
  20. out_dim = max_num_nodes * (max_num_nodes + 1) // 2
  21. if args.feature_type == 'id':
  22. input_dim = max_num_nodes
  23. elif args.feature_type == 'deg':
  24. input_dim = 1
  25. elif args.feature_type == 'struct':
  26. input_dim = 2
  27. model = GraphVAE(input_dim, 64, 256, max_num_nodes)
  28. return model
  29. def train(args, dataloader, model):
  30. epoch = 1
  31. optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
  32. scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=args.lr)
  33. model.train()
  34. for epoch in range(5000):
  35. for batch_idx, data in enumerate(dataloader):
  36. model.zero_grad()
  37. features = data['features'].float()
  38. adj_input = data['adj'].float()
  39. features = Variable(features).cuda()
  40. adj_input = Variable(adj_input).cuda()
  41. loss = model(features, adj_input)
  42. print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
  43. loss.backward()
  44. optimizer.step()
  45. scheduler.step()
  46. break
  47. def arg_parse():
  48. parser = argparse.ArgumentParser(description='GraphVAE arguments.')
  49. io_parser = parser.add_mutually_exclusive_group(required=False)
  50. io_parser.add_argument('--dataset', dest='dataset',
  51. help='Input dataset.')
  52. parser.add_argument('--lr', dest='lr', type=float,
  53. help='Learning rate.')
  54. parser.add_argument('--batch_size', dest='batch_size', type=int,
  55. help='Batch size.')
  56. parser.add_argument('--num_workers', dest='num_workers', type=int,
  57. help='Number of workers to load data.')
  58. parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
  59. help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \
  60. training data.')
  61. parser.add_argument('--feature', dest='feature_type',
  62. help='Feature used for encoder. Can be: id, deg')
  63. parser.set_defaults(dataset='grid',
  64. feature_type='id',
  65. lr=0.001,
  66. batch_size=1,
  67. num_workers=1,
  68. max_num_nodes=-1)
  69. return parser.parse_args()
  70. def main():
  71. prog_args = arg_parse()
  72. os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
  73. print('CUDA', CUDA)
  74. ### running log
  75. if prog_args.dataset == 'enzymes':
  76. graphs= data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
  77. num_graphs_raw = len(graphs)
  78. elif prog_args.dataset == 'grid':
  79. graphs = []
  80. for i in range(2,3):
  81. for j in range(2,3):
  82. graphs.append(nx.grid_2d_graph(i,j))
  83. num_graphs_raw = len(graphs)
  84. if prog_args.max_num_nodes == -1:
  85. max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
  86. else:
  87. max_num_nodes = prog_args.max_num_nodes
  88. # remove graphs with number of nodes greater than max_num_nodes
  89. graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]
  90. graphs_len = len(graphs)
  91. print('Number of graphs removed due to upper-limit of number of nodes: ',
  92. num_graphs_raw - graphs_len)
  93. graphs_test = graphs[int(0.8 * graphs_len):]
  94. #graphs_train = graphs[0:int(0.8*graphs_len)]
  95. graphs_train = graphs
  96. print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train)))
  97. print('max number node: {}'.format(max_num_nodes))
  98. dataset = GraphAdjSampler(graphs_train, max_num_nodes, features=prog_args.feature_type)
  99. #sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
  100. # [1.0 / len(dataset) for i in range(len(dataset))],
  101. # num_samples=prog_args.batch_size,
  102. # replacement=False)
  103. dataset_loader = torch.utils.data.DataLoader(
  104. dataset,
  105. batch_size=prog_args.batch_size,
  106. num_workers=prog_args.num_workers)
  107. model = build_model(prog_args, max_num_nodes).cuda()
  108. train(prog_args, dataset_loader, model)
# Standard script guard: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()