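"""Train the GraphVAE baseline.

Builds a GraphVAE model, loads a graph dataset (grid graphs or ENZYMES),
and trains it with Adam plus a multi-step learning-rate schedule, logging
the loss to TensorBoard and periodically checkpointing the model.
"""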
import argparse
import os
import shutil
from time import strftime, gmtime

import networkx as nx
import torch
from tensorboard_logger import configure, log_value
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

import data
from baselines.graphvae.graphvae_model import GraphVAE
from baselines.graphvae.graphvae_data import GraphAdjSampler
from baselines.graphvae.args import GraphVAE_Args

CUDA = 0
vae_args = GraphVAE_Args()
# Epochs at which MultiStepLR multiplies the learning rate by gamma (0.1).
LR_milestones = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]


def build_model(args, max_num_nodes):
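    """Construct a GraphVAE sized for graphs of at most max_num_nodes nodes.

    The encoder input dimension depends on the node feature type; in
    completion mode the missing nodes are excluded from the input.
    """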
    # Number of entries in the upper triangle of the adjacency matrix,
    # i.e. the size of the decoder output (currently unused here).
    out_dim = max_num_nodes * (max_num_nodes + 1) // 2
    if args.feature_type == 'id':
        # One-hot node ids; in completion mode only the observed nodes are encoded.
        if vae_args.completion_mode:
            input_dim = max_num_nodes - vae_args.number_of_missing_nodes
        else:
            input_dim = max_num_nodes
    elif args.feature_type == 'deg':
        input_dim = 1
    elif args.feature_type == 'struct':
        input_dim = 2
    model = GraphVAE(input_dim, 16, 256, max_num_nodes,
                     vae_args.number_of_missing_nodes, vae_args.completion_mode)
    return model


def train(args, dataloader, model):
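    """Run the training loop: Adam + MultiStepLR for 5000 epochs.

    Logs the batch loss to TensorBoard and checkpoints the model
    every other epoch.
    """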
    optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=0.1)

    model.train()
    for epoch in range(5000):
        for batch_idx, data in enumerate(dataloader):
            model.zero_grad()
            features = data['features'].float().cuda()
            adj_input = data['adj'].float().cuda()

            loss = model(features, adj_input)
            print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss.item())
            loss.backward()
            optimizer.step()

            # Crude global step: assumes at most 32 batches per epoch.
            log_value('Training Loss', loss.item(), epoch * 32 + batch_idx)

        # LR_milestones are epoch indices, so step the scheduler once per
        # epoch, not once per batch.
        scheduler.step()
        if epoch % 2 == 0:
            fname = vae_args.model_save_path + 'GraphVAE' + str(epoch) + '.dat'
            torch.save(model.state_dict(), fname)


def arg_parse():
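    """Parse command-line arguments and fill in the defaults."""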
    parser = argparse.ArgumentParser(description='GraphVAE arguments.')
    io_parser = parser.add_mutually_exclusive_group(required=False)
    io_parser.add_argument('--dataset', dest='dataset',
                           help='Input dataset.')

    parser.add_argument('--lr', dest='lr', type=float,
                        help='Learning rate.')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        help='Batch size.')
    parser.add_argument('--batch_ratio', dest='batch_ratio', type=int,
                        help='Number of batches drawn per epoch.')
    parser.add_argument('--num_workers', dest='num_workers', type=int,
                        help='Number of workers to load data.')
    parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
                        help='Predefined maximum number of nodes in train/test graphs. '
                             '-1 if determined by training data.')
    parser.add_argument('--feature', dest='feature_type',
                        help='Feature used for encoder. Can be: id, deg, struct.')

    parser.set_defaults(dataset='grid',
                        feature_type='id',
                        lr=0.01,
                        batch_size=6,
                        batch_ratio=6,
                        num_workers=4,
                        max_num_nodes=-1)
    return parser.parse_args()
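# Example invocation (values shown are the defaults; the module path is an
# assumption and may differ in your checkout):
#   python -m baselines.graphvae.train --dataset grid --feature id \
#       --lr 0.01 --batch_size 6 --batch_ratio 6 --max_num_nodes -1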

def main():
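    """Load the dataset, build the model, and start training."""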
    prog_args = arg_parse()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
    print('CUDA', CUDA)

    if prog_args.dataset == 'enzymes':
        graphs = data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
        num_graphs_raw = len(graphs)
    elif prog_args.dataset == 'grid':
        # All factorizations of a 12-node grid: same node count, different shapes.
        graphs = []
        graphs.append(nx.grid_2d_graph(1, 12))
        graphs.append(nx.grid_2d_graph(2, 6))
        graphs.append(nx.grid_2d_graph(3, 4))
        graphs.append(nx.grid_2d_graph(4, 3))
        graphs.append(nx.grid_2d_graph(6, 2))
        graphs.append(nx.grid_2d_graph(12, 1))
        # Alternatives (disabled): all i x j grids for i, j in [2, 5], or the
        # factorizations of a 24-node grid (1x24, 2x12, 3x8, 4x6, 6x4, 8x3,
        # 12x2, 24x1).
        num_graphs_raw = len(graphs)

    if prog_args.max_num_nodes == -1:
        max_num_nodes = max(graph.number_of_nodes() for graph in graphs)
    else:
        max_num_nodes = prog_args.max_num_nodes
        # Remove graphs with more nodes than max_num_nodes.
        graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]

    graphs_len = len(graphs)
    print('Number of graphs removed due to upper-limit of number of nodes: ',
          num_graphs_raw - graphs_len)
    graphs_test = graphs[int(0.8 * graphs_len):]
    # Note: the full dataset is used for training; the 80/20 split above is
    # kept only for reference.
    graphs_train = graphs

    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number of nodes: {}'.format(max_num_nodes))

    dataset = GraphAdjSampler(graphs_train, max_num_nodes, vae_args.permutation_mode,
                              vae_args.bfs_mode, features=prog_args.feature_type)
    # Draw batch_size * batch_ratio graphs per epoch, uniformly with
    # replacement, i.e. batch_ratio batches per epoch.
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset) for i in range(len(dataset))],
        num_samples=prog_args.batch_size * prog_args.batch_ratio,
        replacement=True)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers,
        sampler=sample_strategy)

    model = build_model(prog_args, max_num_nodes).cuda()
    train(prog_args, dataset_loader, model)

if __name__ == '__main__':
    if not os.path.isdir(vae_args.model_save_path):
        os.makedirs(vae_args.model_save_path)
    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if vae_args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)
    main()