Browse Source

make DeepGMG more suitbale for our testing

master
Ali Amiri 4 years ago
parent
commit
e83a57354d
1 changed files with 154 additions and 267 deletions
  1. 154
    267
      main_DeepGMG.py

+ 154
- 267
main_DeepGMG.py View File

# an implementation for "Learning Deep Generative Models of Graphs" # an implementation for "Learning Deep Generative Models of Graphs"
import os

from main import * from main import *



class Args_DGMG(): class Args_DGMG():
def __init__(self): def __init__(self):
### CUDA ### CUDA
self.cuda = 2
self.cuda = 0


### model type ### model type
self.note = 'Baseline_DGMG' # do GCN after adding each edge
self.note = 'Baseline_DGMG' # do GCN after adding each edge
# self.note = 'Baseline_DGMG_fast' # do GCN only after adding each node # self.note = 'Baseline_DGMG_fast' # do GCN only after adding each node


### data config ### data config
self.graph_type = 'caveman_small'
# self.graph_type = 'grid_small'
# self.graph_type = 'caveman_small'
self.graph_type = 'grid_small'
# self.graph_type = 'ladder_small' # self.graph_type = 'ladder_small'
# self.graph_type = 'enzymes_small' # self.graph_type = 'enzymes_small'
# self.graph_type = 'barabasi_small' # self.graph_type = 'barabasi_small'
self.node_embedding_size = 64 self.node_embedding_size = 64
self.test_graph_num = 200 self.test_graph_num = 200



### training config ### training config
self.epochs = 2000 # now one epoch means self.batch_ratio x batch_size
self.load_epoch = 2000
self.epochs_test_start = 100
self.epochs_test = 100
self.epochs_log = 100
self.epochs_save = 100
self.epochs = 100 # now one epoch means self.batch_ratio x batch_size
self.load_epoch = 100
self.epochs_test_start = 10
self.epochs_test = 10
self.epochs_log = 10
self.epochs_save = 10
if 'fast' in self.note: if 'fast' in self.note:
self.is_fast = True self.is_fast = True
else: else:
self.figure_prediction_save_path = 'figures_prediction/' self.figure_prediction_save_path = 'figures_prediction/'
self.nll_save_path = 'nll/' self.nll_save_path = 'nll/'



self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) self.fname = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size)
self.fname_pred = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_pred_' self.fname_pred = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_pred_'
self.fname_train = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_train_' self.fname_train = self.note + '_' + self.graph_type + '_' + str(self.node_embedding_size) + '_train_'
self.save = True self.save = True




def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast = False):
def train_DGMG_epoch(epoch, args, model, dataset, optimizer, scheduler, is_fast=False):
model.train() model.train()
graph_num = len(dataset) graph_num = len(dataset)
order = list(range(graph_num)) order = list(range(graph_num))
shuffle(order) shuffle(order)



loss_addnode = 0 loss_addnode = 0
loss_addedge = 0 loss_addedge = 0
loss_node = 0 loss_node = 0
order_mapping = dict(zip(graph.nodes(), node_order)) order_mapping = dict(zip(graph.nodes(), node_order))
graph = nx.relabel_nodes(graph, order_mapping, copy=True) graph = nx.relabel_nodes(graph, order_mapping, copy=True)



# NOTE: when starting loop, we assume a node has already been generated # NOTE: when starting loop, we assume a node has already been generated
node_count = 1 node_count = 1
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
node_embedding = [
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden


loss = 0 loss = 0
while node_count<=graph.number_of_nodes():
node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list() # list of lists (first node is zero)
node_neighbor_new = graph.subgraph(list(range(node_count+1))).adjacency_list()[-1] # list of new node's neighbors
while node_count <= graph.number_of_nodes():
node_neighbor = graph.subgraph(
list(range(node_count))).adjacency_list() # list of lists (first node is zero)
node_neighbor_new = graph.subgraph(list(range(node_count + 1))).adjacency_list()[
-1] # list of new node's neighbors


# 1 message passing # 1 message passing
# do 2 times message passing # do 2 times message passing
if is_fast: if is_fast:
node_embedding_cat = torch.cat(node_embedding, dim=0) node_embedding_cat = torch.cat(node_embedding, dim=0)
# calc loss # calc loss
loss_addnode_step = F.binary_cross_entropy(p_addnode,Variable(torch.ones((1,1))).cuda())
loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.ones((1, 1))).cuda())
# loss_addnode_step.backward(retain_graph=True) # loss_addnode_step.backward(retain_graph=True)
loss += loss_addnode_step loss += loss_addnode_step
loss_addnode += loss_addnode_step.data loss_addnode += loss_addnode_step.data
loss_addnode += loss_addnode_step.data loss_addnode += loss_addnode_step.data
break break



edge_count = 0 edge_count = 0
while edge_count<=len(node_neighbor_new):
while edge_count <= len(node_neighbor_new):
if not is_fast: if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model) node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0) node_embedding_cat = torch.cat(node_embedding, dim=0)


# 5 f_nodes # 5 f_nodes
# excluding the last node (which is the new node) # excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1))
p_node = F.softmax(s_node.permute(1,0))
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
# get ground truth # get ground truth
a_node = torch.zeros((1,p_node.size(1)))
a_node = torch.zeros((1, p_node.size(1)))
# print('node_neighbor_new',node_neighbor_new, edge_count) # print('node_neighbor_new',node_neighbor_new, edge_count)
a_node[0,node_neighbor_new[edge_count]] = 1
a_node[0, node_neighbor_new[edge_count]] = 1
a_node = Variable(a_node).cuda() a_node = Variable(a_node).cuda()
# add edge # add edge
node_neighbor[-1].append(node_neighbor_new[edge_count]) node_neighbor[-1].append(node_neighbor_new[edge_count])
node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor)-1)
node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor) - 1)
# calc loss # calc loss
loss_node_step = F.binary_cross_entropy(p_node,a_node)
loss_node_step = F.binary_cross_entropy(p_node, a_node)
# loss_node_step.backward(retain_graph=True) # loss_node_step.backward(retain_graph=True)
loss += loss_node_step loss += loss_node_step
loss_node += loss_node_step.data loss_node += loss_node_step.data


loss_all = loss_addnode + loss_addedge + loss_node loss_all = loss_addnode + loss_addedge + loss_node


if epoch % args.epochs_log==0:
if epoch % args.epochs_log == 0:
print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format( print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
epoch, args.epochs,loss_all[0], args.graph_type, args.node_embedding_size))

epoch, args.epochs, loss_all, args.graph_type, args.node_embedding_size))


# loss_sum += loss.data[0]*x.size(0) # loss_sum += loss.data[0]*x.size(0)
# return loss_sum # return loss_sum




def test_DGMG_epoch(args, model, is_fast=False):
model.eval()
graph_num = args.test_graph_num



def train_DGMG_forward_epoch(args, model, dataset, is_fast = False):
model.train()
graph_num = len(dataset)
order = list(range(graph_num))
shuffle(order)


loss_addnode = 0
loss_addedge = 0
loss_node = 0
for i in order:
model.zero_grad()

graph = dataset[i]
# do random ordering: relabel nodes
node_order = list(range(graph.number_of_nodes()))
shuffle(node_order)
order_mapping = dict(zip(graph.nodes(), node_order))
graph = nx.relabel_nodes(graph, order_mapping, copy=True)


graphs_generated = []
for i in range(graph_num):
# NOTE: when starting loop, we assume a node has already been generated # NOTE: when starting loop, we assume a node has already been generated
node_count = 1
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden


loss = 0
while node_count<=graph.number_of_nodes():
node_neighbor = graph.subgraph(list(range(node_count))).adjacency_list() # list of lists (first node is zero)
node_neighbor_new = graph.subgraph(list(range(node_count+1))).adjacency_list()[-1] # list of new node's neighbors
node_neighbor = [[]] # list of lists (first node is zero)
node_embedding = [
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden


node_count = 1
while node_count <= args.max_num_node:
# 1 message passing # 1 message passing
# do 2 times message passing # do 2 times message passing
node_embedding = message_passing(node_neighbor, node_embedding, model) node_embedding = message_passing(node_neighbor, node_embedding, model)


# 3 f_addnode # 3 f_addnode
p_addnode = model.f_an(graph_embedding) p_addnode = model.f_an(graph_embedding)
if node_count < graph.number_of_nodes():
a_addnode = sample_tensor(p_addnode)
# print(a_addnode.data[0][0])
if a_addnode.data[0][0] == 1:
# print('add node')
# add node # add node
node_neighbor.append([]) node_neighbor.append([])
node_embedding.append(init_embedding) node_embedding.append(init_embedding)
if is_fast: if is_fast:
node_embedding_cat = torch.cat(node_embedding, dim=0) node_embedding_cat = torch.cat(node_embedding, dim=0)
# calc loss
loss_addnode_step = F.binary_cross_entropy(p_addnode,Variable(torch.ones((1,1))).cuda())
# loss_addnode_step.backward(retain_graph=True)
loss += loss_addnode_step
loss_addnode += loss_addnode_step.data
else: else:
# calc loss
loss_addnode_step = F.binary_cross_entropy(p_addnode, Variable(torch.zeros((1, 1))).cuda())
# loss_addnode_step.backward(retain_graph=True)
loss += loss_addnode_step
loss_addnode += loss_addnode_step.data
break break



edge_count = 0 edge_count = 0
while edge_count<=len(node_neighbor_new):
while edge_count < args.max_num_node:
if not is_fast: if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model) node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0) node_embedding_cat = torch.cat(node_embedding, dim=0)


# 4 f_addedge # 4 f_addedge
p_addedge = model.f_ae(graph_embedding) p_addedge = model.f_ae(graph_embedding)
a_addedge = sample_tensor(p_addedge)
# print(a_addedge.data[0][0])


if edge_count < len(node_neighbor_new):
# calc loss
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.ones((1, 1))).cuda())
# loss_addedge_step.backward(retain_graph=True)
loss += loss_addedge_step
loss_addedge += loss_addedge_step.data

if a_addedge.data[0][0] == 1:
# print('add edge')
# 5 f_nodes # 5 f_nodes
# excluding the last node (which is the new node) # excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1))
p_node = F.softmax(s_node.permute(1,0))
# get ground truth
a_node = torch.zeros((1,p_node.size(1)))
# print('node_neighbor_new',node_neighbor_new, edge_count)
a_node[0,node_neighbor_new[edge_count]] = 1
a_node = Variable(a_node).cuda()
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
a_node = gumbel_softmax(p_node, temperature=0.01)
_, a_node_id = a_node.topk(1)
a_node_id = int(a_node_id.data[0][0])
# add edge # add edge
node_neighbor[-1].append(node_neighbor_new[edge_count])
node_neighbor[node_neighbor_new[edge_count]].append(len(node_neighbor)-1)
# calc loss
loss_node_step = F.binary_cross_entropy(p_node,a_node)
# loss_node_step.backward(retain_graph=True)
loss += loss_node_step
loss_node += loss_node_step.data*p_node.size(1)

node_neighbor[-1].append(a_node_id)
node_neighbor[a_node_id].append(len(node_neighbor) - 1)
else: else:
# calc loss
loss_addedge_step = F.binary_cross_entropy(p_addedge, Variable(torch.zeros((1, 1))).cuda())
# loss_addedge_step.backward(retain_graph=True)
loss += loss_addedge_step
loss_addedge += loss_addedge_step.data
break break


edge_count += 1 edge_count += 1
node_count += 1 node_count += 1
# save graph
node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
graph = nx.from_dict_of_lists(node_neighbor_dict)
graphs_generated.append(graph)



loss_all = loss_addnode + loss_addedge + loss_node

# if epoch % args.epochs_log==0:
# print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, hidden: {}'.format(
# epoch, args.epochs,loss_all[0], args.graph_type, args.node_embedding_size))


return loss_all[0]/len(dataset)





return graphs_generated




def test_DGMG_epoch(args, model, is_fast=False):
def test_DGMG_2(args, model, test_graph, is_fast=False):
model.eval() model.eval()
graph_num = args.test_graph_num graph_num = args.test_graph_num


for i in range(graph_num): for i in range(graph_num):
# NOTE: when starting loop, we assume a node has already been generated # NOTE: when starting loop, we assume a node has already been generated
node_neighbor = [[]] # list of lists (first node is zero) node_neighbor = [[]] # list of lists (first node is zero)
node_embedding = [Variable(torch.ones(1,args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
node_embedding = [
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden


node_max = len(test_graph.nodes())
node_count = 1 node_count = 1
while node_count<=args.max_num_node:
while node_count <= node_max:
# 1 message passing # 1 message passing
# do 2 times message passing # do 2 times message passing
node_embedding = message_passing(node_neighbor, node_embedding, model) node_embedding = message_passing(node_neighbor, node_embedding, model)
# 3 f_addnode # 3 f_addnode
p_addnode = model.f_an(graph_embedding) p_addnode = model.f_an(graph_embedding)
a_addnode = sample_tensor(p_addnode) a_addnode = sample_tensor(p_addnode)
# print(a_addnode.data[0][0])
if a_addnode.data[0][0]==1:
# print('add node')

if a_addnode.data[0][0] == 1:
# add node # add node
node_neighbor.append([]) node_neighbor.append([])
node_embedding.append(init_embedding) node_embedding.append(init_embedding)
break break


edge_count = 0 edge_count = 0
while edge_count<args.max_num_node:
while edge_count < args.max_num_node:
if not is_fast: if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model) node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0) node_embedding_cat = torch.cat(node_embedding, dim=0)
# 4 f_addedge # 4 f_addedge
p_addedge = model.f_ae(graph_embedding) p_addedge = model.f_ae(graph_embedding)
a_addedge = sample_tensor(p_addedge) a_addedge = sample_tensor(p_addedge)
# print(a_addedge.data[0][0])


if a_addedge.data[0][0]==1:
# print('add edge')
if a_addedge.data[0][0] == 1:
# 5 f_nodes # 5 f_nodes
# excluding the last node (which is the new node) # excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1,:].expand(node_embedding_cat.size(0)-1,node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1,:],node_new_embedding_cat),dim=1))
p_node = F.softmax(s_node.permute(1,0))
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
a_node = gumbel_softmax(p_node, temperature=0.01) a_node = gumbel_softmax(p_node, temperature=0.01)
_, a_node_id = a_node.topk(1) _, a_node_id = a_node.topk(1)
a_node_id = int(a_node_id.data[0][0]) a_node_id = int(a_node_id.data[0][0])
# add edge # add edge

node_neighbor[-1].append(a_node_id) node_neighbor[-1].append(a_node_id)
node_neighbor[a_node_id].append(len(node_neighbor)-1)
node_neighbor[a_node_id].append(len(node_neighbor) - 1)
else: else:
break break


edge_count += 1 edge_count += 1
node_count += 1 node_count += 1
# save graph
node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
graph = nx.from_dict_of_lists(node_neighbor_dict)
graphs_generated.append(graph)


return graphs_generated
# clear node_neighbor and build it again
node_neighbor = []
for n in range(node_max):
temp_neighbor = [k for k in test_graph.edge[n]]
node_neighbor.append(temp_neighbor)


# now add the last node for real
# 1 message passing
# do 2 times message passing
try:
node_embedding = message_passing(node_neighbor, node_embedding, model)


# 2 graph embedding and new node embedding
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)
init_embedding = calc_init_embedding(node_embedding_cat, model)


# 3 f_addnode
p_addnode = model.f_an(graph_embedding)
a_addnode = sample_tensor(p_addnode)


if a_addnode.data[0][0] == 1:
# add node
node_neighbor.append([])
node_embedding.append(init_embedding)
if is_fast:
node_embedding_cat = torch.cat(node_embedding, dim=0)


edge_count = 0
while edge_count < args.max_num_node:
if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)


# 4 f_addedge
p_addedge = model.f_ae(graph_embedding)
a_addedge = sample_tensor(p_addedge)


if a_addedge.data[0][0] == 1:
# 5 f_nodes
# excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
a_node = gumbel_softmax(p_node, temperature=0.01)
_, a_node_id = a_node.topk(1)
a_node_id = int(a_node_id.data[0][0])
# add edge


node_neighbor[-1].append(a_node_id)
node_neighbor[a_node_id].append(len(node_neighbor) - 1)
else:
break


edge_count += 1
node_count += 1
except:
print('error')
# save graph
node_neighbor_dict = dict(zip(list(range(len(node_neighbor))), node_neighbor))
graph = nx.from_dict_of_lists(node_neighbor_dict)
graphs_generated.append(graph)


return graphs_generated




########### train function for LSTM + VAE ########### train function for LSTM + VAE
train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast) train_DGMG_epoch(epoch, args, model, dataset_train, optimizer, scheduler, is_fast=args.is_fast)
time_end = tm.time() time_end = tm.time()
time_all[epoch - 1] = time_end - time_start time_all[epoch - 1] = time_end - time_start
# print('time used',time_all[epoch - 1])
print('time used', time_all[epoch - 1])
# test # test
if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start: if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
graphs = test_DGMG_epoch(args,model, is_fast=args.is_fast)
graphs = test_DGMG_epoch(args, model, is_fast=args.is_fast)
fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat' fname = args.graph_save_path + args.fname_pred + str(epoch) + '.dat'
save_graph_list(graphs, fname) save_graph_list(graphs, fname)
# print('test done, graphs saved') # print('test done, graphs saved')
np.save(args.timing_save_path + args.fname, time_all) np.save(args.timing_save_path + args.fname, time_all)









########### train function for LSTM + VAE
def train_DGMG_nll(args, dataset_train,dataset_test, model,max_iter=1000):
# check if load existing model
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
model.load_state_dict(torch.load(fname))

fname_output = args.nll_save_path + args.note + '_' + args.graph_type + '.csv'
with open(fname_output, 'w+') as f:
f.write('train,test\n')
# start main loop
for iter in range(max_iter):
nll_train = train_DGMG_forward_epoch(args, model, dataset_train, is_fast=args.is_fast)
nll_test = train_DGMG_forward_epoch(args, model, dataset_test, is_fast=args.is_fast)
print('train', nll_train, 'test', nll_test)
f.write(str(nll_train) + ',' + str(nll_test) + '\n')





if __name__ == '__main__': if __name__ == '__main__':
args = Args_DGMG() args = Args_DGMG()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
print('CUDA', args.cuda) print('CUDA', args.cuda)
print('File name prefix',args.fname)

print('File name prefix', args.fname)


graphs = [] graphs = []
for i in range(4, 10): for i in range(4, 10):
graphs.append(nx.ladder_graph(i)) graphs.append(nx.ladder_graph(i))
model = DGM_graphs(h_size = args.node_embedding_size).cuda()
model = DGM_graphs(h_size=args.node_embedding_size).cuda()


if args.graph_type == 'ladder_small':
graphs = []
for i in range(2, 11):
graphs.append(nx.ladder_graph(i))
args.max_prev_node = 10
# if args.graph_type == 'caveman_small':
# graphs = []
# for i in range(2, 5):
# for j in range(2, 6):
# for k in range(10):
# graphs.append(nx.relaxed_caveman_graph(i, j, p=0.1))
# args.max_prev_node = 20
if args.graph_type=='caveman_small':
graphs = []
for i in range(2, 3):
for j in range(6, 11):
for k in range(20):
graphs.append(caveman_special(i, j, p_edge=0.8))
args.max_prev_node = 20
if args.graph_type == 'grid_small': if args.graph_type == 'grid_small':
graphs = [] graphs = []
for i in range(2, 5):
for j in range(2, 6):
for i in range(2, 3):
for j in range(2, 4):
graphs.append(nx.grid_2d_graph(i, j)) graphs.append(nx.grid_2d_graph(i, j))
args.max_prev_node = 15
if args.graph_type == 'barabasi_small':
graphs = []
for i in range(4, 21):
for j in range(3, 4):
for k in range(10):
graphs.append(nx.barabasi_albert_graph(i, j))
args.max_prev_node = 20

if args.graph_type == 'enzymes_small':
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
graphs = []
for G in graphs_raw:
if G.number_of_nodes()<=20:
graphs.append(G)
args.max_prev_node = 15

if args.graph_type == 'citeseer_small':
_, _, G = Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
G = nx.convert_node_labels_to_integers(G)
graphs = []
for i in range(G.number_of_nodes()):
G_ego = nx.ego_graph(G, i, radius=1)
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
graphs.append(G_ego)
shuffle(graphs)
graphs = graphs[0:200]
args.max_prev_node = 15
args.max_prev_node = 6


# remove self loops # remove self loops
for graph in graphs: for graph in graphs:
# split datasets # split datasets
random.seed(123) random.seed(123)
shuffle(graphs) shuffle(graphs)
graphs_len = len(graphs)
graphs_test = graphs[int(0.8 * graphs_len):]
graphs_train = graphs[0:int(0.8 * graphs_len)]
# graphs_len = len(graphs)
# graphs_test = graphs[int(0.8 * graphs_len):]
# graphs_validate = graphs[int(0.7 * graphs_len):int(0.8 * graphs_len)]
# graphs_train = graphs[0:int(0.7 * graphs_len)]


args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
# args.max_num_node = 2000
# show graphs statistics
print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))

print('max number node: {}'.format(args.max_num_node)) print('max number node: {}'.format(args.max_num_node))
print('max previous node: {}'.format(args.max_prev_node)) print('max previous node: {}'.format(args.max_prev_node))
test_graph = nx.grid_2d_graph(2, 3)
test_graph.remove_node(test_graph.nodes()[5])
train_DGMG(args, graphs, model)


# save ground truth graphs
# save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
# save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
# print('train and test graphs saved')

## if use pre-saved graphs
# dir_input = "graphs/"
# fname_test = args.graph_save_path + args.fname_test + '0.dat'
# graphs = load_graph_list(fname_test, is_real=True)
# graphs_test = graphs[int(0.8 * graphs_len):]
# graphs_train = graphs[0:int(0.8 * graphs_len)]
# graphs_validate = graphs[0:int(0.2 * graphs_len)]

# print('train')
# for graph in graphs_validate:
# print(graph.number_of_nodes())
# print('test')
# for graph in graphs_test:
# print(graph.number_of_nodes())



### train
train_DGMG(args,graphs,model)

### calc nll
# train_DGMG_nll(args, graphs_validate,graphs_test, model,max_iter=1000)









# for j in range(1000):
# graph = graphs[0]
# # do random ordering: relabel nodes
# node_order = list(range(graph.number_of_nodes()))
# shuffle(node_order)
# order_mapping = dict(zip(graph.nodes(), node_order))
# graph = nx.relabel_nodes(graph, order_mapping, copy=True)
# print(graph.nodes())
test_graph = nx.convert_node_labels_to_integers(test_graph)
test_DGMG_2(args, model, test_graph)

Loading…
Cancel
Save