|
|
@@ -21,9 +21,10 @@ class Args_DGMG(): |
|
|
|
### data config |
|
|
|
# self.graph_type = 'caveman_small' |
|
|
|
# self.graph_type = 'grid_small' |
|
|
|
self.graph_type = 'IMDBMULTI' |
|
|
|
# self.graph_type = 'ENZYMES' |
|
|
|
# self.graph_type = 'ladder_small' |
|
|
|
# self.graph_type = 'enzymes_small' |
|
|
|
self.graph_type = 'enzymes_small' |
|
|
|
# self.graph_type = 'protein' |
|
|
|
# self.graph_type = 'barabasi_small' |
|
|
|
# self.graph_type = 'citeseer_small' |
|
|
|
|
|
|
@@ -445,81 +446,81 @@ def train_DGMG_nll(args, dataset_train, dataset_test, model, max_iter=1000): |
|
|
|
|
|
|
|
def test_DGMG_2(args, model, test_graph, is_fast=False): |
|
|
|
model.eval() |
|
|
|
graph_num = args.test_graph_num |
|
|
|
|
|
|
|
graphs_generated = [] |
|
|
|
# graphs_generated = [] |
|
|
|
# for i in range(graph_num): |
|
|
|
# NOTE: when starting loop, we assume a node has already been generated |
|
|
|
node_neighbor = [[]] # list of lists (first node is zero) |
|
|
|
node_embedding = [ |
|
|
|
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden |
|
|
|
try: |
|
|
|
node_embedding = [ |
|
|
|
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden |
|
|
|
|
|
|
|
node_max = len(test_graph.nodes()) |
|
|
|
node_count = 1 |
|
|
|
while node_count <= node_max: |
|
|
|
# 1 message passing |
|
|
|
# do 2 times message passing |
|
|
|
node_embedding = message_passing(node_neighbor, node_embedding, model) |
|
|
|
node_max = len(test_graph.nodes()) |
|
|
|
node_count = 1 |
|
|
|
while node_count <= node_max: |
|
|
|
# 1 message passing |
|
|
|
# do 2 times message passing |
|
|
|
node_embedding = message_passing(node_neighbor, node_embedding, model) |
|
|
|
|
|
|
|
# 2 graph embedding and new node embedding |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
graph_embedding = calc_graph_embedding(node_embedding_cat, model) |
|
|
|
init_embedding = calc_init_embedding(node_embedding_cat, model) |
|
|
|
# 2 graph embedding and new node embedding |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
graph_embedding = calc_graph_embedding(node_embedding_cat, model) |
|
|
|
init_embedding = calc_init_embedding(node_embedding_cat, model) |
|
|
|
|
|
|
|
# 3 f_addnode |
|
|
|
p_addnode = model.f_an(graph_embedding) |
|
|
|
a_addnode = sample_tensor(p_addnode) |
|
|
|
# 3 f_addnode |
|
|
|
p_addnode = model.f_an(graph_embedding) |
|
|
|
a_addnode = sample_tensor(p_addnode) |
|
|
|
|
|
|
|
if a_addnode.data[0][0] == 1: |
|
|
|
# add node |
|
|
|
node_neighbor.append([]) |
|
|
|
node_embedding.append(init_embedding) |
|
|
|
if is_fast: |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
else: |
|
|
|
break |
|
|
|
if a_addnode.data[0][0] == 1: |
|
|
|
# add node |
|
|
|
node_neighbor.append([]) |
|
|
|
node_embedding.append(init_embedding) |
|
|
|
if is_fast: |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
edge_count = 0 |
|
|
|
while edge_count < args.max_num_node: |
|
|
|
if not is_fast: |
|
|
|
node_embedding = message_passing(node_neighbor, node_embedding, model) |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
graph_embedding = calc_graph_embedding(node_embedding_cat, model) |
|
|
|
edge_count = 0 |
|
|
|
while edge_count < args.max_num_node: |
|
|
|
if not is_fast: |
|
|
|
node_embedding = message_passing(node_neighbor, node_embedding, model) |
|
|
|
node_embedding_cat = torch.cat(node_embedding, dim=0) |
|
|
|
graph_embedding = calc_graph_embedding(node_embedding_cat, model) |
|
|
|
|
|
|
|
# 4 f_addedge |
|
|
|
p_addedge = model.f_ae(graph_embedding) |
|
|
|
a_addedge = sample_tensor(p_addedge) |
|
|
|
# 4 f_addedge |
|
|
|
p_addedge = model.f_ae(graph_embedding) |
|
|
|
a_addedge = sample_tensor(p_addedge) |
|
|
|
|
|
|
|
if a_addedge.data[0][0] == 1: |
|
|
|
# 5 f_nodes |
|
|
|
# excluding the last node (which is the new node) |
|
|
|
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, |
|
|
|
node_embedding_cat.size(1)) |
|
|
|
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) |
|
|
|
p_node = F.softmax(s_node.permute(1, 0)) |
|
|
|
a_node = gumbel_softmax(p_node, temperature=0.01) |
|
|
|
_, a_node_id = a_node.topk(1) |
|
|
|
a_node_id = int(a_node_id.data[0][0]) |
|
|
|
# add edge |
|
|
|
if a_addedge.data[0][0] == 1: |
|
|
|
# 5 f_nodes |
|
|
|
# excluding the last node (which is the new node) |
|
|
|
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1, |
|
|
|
node_embedding_cat.size(1)) |
|
|
|
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1)) |
|
|
|
p_node = F.softmax(s_node.permute(1, 0)) |
|
|
|
a_node = gumbel_softmax(p_node, temperature=0.01) |
|
|
|
_, a_node_id = a_node.topk(1) |
|
|
|
a_node_id = int(a_node_id.data[0][0]) |
|
|
|
# add edge |
|
|
|
|
|
|
|
node_neighbor[-1].append(a_node_id) |
|
|
|
node_neighbor[a_node_id].append(len(node_neighbor) - 1) |
|
|
|
else: |
|
|
|
break |
|
|
|
node_neighbor[-1].append(a_node_id) |
|
|
|
node_neighbor[a_node_id].append(len(node_neighbor) - 1) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
edge_count += 1 |
|
|
|
node_count += 1 |
|
|
|
edge_count += 1 |
|
|
|
node_count += 1 |
|
|
|
|
|
|
|
# clear node_neighbor and build it again |
|
|
|
node_neighbor = [] |
|
|
|
for n in range(node_max): |
|
|
|
temp_neighbor = [k for k in test_graph.edge[n]] |
|
|
|
node_neighbor.append(temp_neighbor) |
|
|
|
# clear node_neighbor and build it again |
|
|
|
node_neighbor = [] |
|
|
|
for n in range(node_max): |
|
|
|
temp_neighbor = [k for k in test_graph.edge[n]] |
|
|
|
node_neighbor.append(temp_neighbor) |
|
|
|
|
|
|
|
# now add the last node for real |
|
|
|
# 1 message passing |
|
|
|
# do 2 times message passing |
|
|
|
|
|
|
|
# now add the last node for real |
|
|
|
# 1 message passing |
|
|
|
# do 2 times message passing |
|
|
|
try: |
|
|
|
node_embedding = message_passing(node_neighbor, node_embedding, model) |
|
|
|
|
|
|
|
# 2 graph embedding and new node embedding |
|
|
@@ -703,33 +704,43 @@ if __name__ == '__main__': |
|
|
|
args.max_prev_node = 20 |
|
|
|
|
|
|
|
if args.graph_type == 'enzymes_small': |
|
|
|
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES') |
|
|
|
graphs_raw = Graph_load_batch(min_num_nodes=1, name='ENZYMES') |
|
|
|
graphs = [] |
|
|
|
for G in graphs_raw: |
|
|
|
if G.number_of_nodes() <= 20: |
|
|
|
graphs.append(G) |
|
|
|
args.max_prev_node = 15 |
|
|
|
args.max_prev_node = 10 |
|
|
|
|
|
|
|
if args.graph_type == 'citeseer_small': |
|
|
|
_, _, G = Graph_load(dataset='citeseer') |
|
|
|
G = max(nx.connected_component_subgraphs(G), key=len) |
|
|
|
G = nx.convert_node_labels_to_integers(G) |
|
|
|
graphs = [] |
|
|
|
for i in range(G.number_of_nodes()): |
|
|
|
G_ego = nx.ego_graph(G, i, radius=1) |
|
|
|
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): |
|
|
|
graphs.append(G_ego) |
|
|
|
shuffle(graphs) |
|
|
|
graphs = graphs[0:200] |
|
|
|
args.max_prev_node = 15 |
|
|
|
else: |
|
|
|
graphs, num_classes = load_data(args.graph_type, True) |
|
|
|
small_graphs = [] |
|
|
|
for i in range(len(graphs)): |
|
|
|
if graphs[i].number_of_nodes() < 13: |
|
|
|
small_graphs.append(graphs[i]) |
|
|
|
graphs = small_graphs |
|
|
|
args.max_prev_node = 12 |
|
|
|
if args.graph_type == 'protein': |
|
|
|
graphs_raw = Graph_load_batch(min_num_nodes=1, name='PROTEINS_full') |
|
|
|
graphs = [] |
|
|
|
for G in graphs_raw: |
|
|
|
if G.number_of_nodes() <= 15: |
|
|
|
graphs.append(G) |
|
|
|
args.max_prev_node = 10 |
|
|
|
|
|
|
|
else: |
|
|
|
if args.graph_type == 'citeseer_small': |
|
|
|
_, _, G = Graph_load(dataset='citeseer') |
|
|
|
G = max(nx.connected_component_subgraphs(G), key=len) |
|
|
|
G = nx.convert_node_labels_to_integers(G) |
|
|
|
graphs = [] |
|
|
|
for i in range(G.number_of_nodes()): |
|
|
|
G_ego = nx.ego_graph(G, i, radius=1) |
|
|
|
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20): |
|
|
|
graphs.append(G_ego) |
|
|
|
shuffle(graphs) |
|
|
|
graphs = graphs[0:200] |
|
|
|
args.max_prev_node = 15 |
|
|
|
else: |
|
|
|
graphs, num_classes = load_data(args.graph_type, True) |
|
|
|
small_graphs = [] |
|
|
|
for i in range(len(graphs)): |
|
|
|
if graphs[i].number_of_nodes() < 16: |
|
|
|
small_graphs.append(graphs[i]) |
|
|
|
graphs = small_graphs |
|
|
|
args.max_prev_node = 21 |
|
|
|
|
|
|
|
# remove self loops |
|
|
|
for graph in graphs: |
|
|
@@ -741,6 +752,9 @@ if __name__ == '__main__': |
|
|
|
random.seed(123) |
|
|
|
shuffle(graphs) |
|
|
|
graphs_len = len(graphs) |
|
|
|
|
|
|
|
graph_statistics(graphs) |
|
|
|
|
|
|
|
graphs_test = graphs[int(0.8 * graphs_len):] |
|
|
|
graphs_train = graphs[0:int(0.8 * graphs_len)] |
|
|
|
|
|
|
@@ -751,19 +765,22 @@ if __name__ == '__main__': |
|
|
|
print('max previous node: {}'.format(args.max_prev_node)) |
|
|
|
|
|
|
|
### train |
|
|
|
# train_DGMG(args, graphs_train, model) |
|
|
|
|
|
|
|
fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat' |
|
|
|
model.load_state_dict(torch.load(fname)) |
|
|
|
|
|
|
|
all_tests = list() |
|
|
|
all_ret_test = list() |
|
|
|
for test_graph in graphs_test: |
|
|
|
test_graph = nx.convert_node_labels_to_integers(test_graph) |
|
|
|
original_graph = test_graph.copy() |
|
|
|
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) |
|
|
|
ret_test = test_DGMG_2(args, model, test_graph) |
|
|
|
all_tests.append(original_graph) |
|
|
|
all_ret_test.append(ret_test) |
|
|
|
labels, results = calc_lable_result(original_graph, ret_test) |
|
|
|
evaluate(labels, results) |
|
|
|
train_DGMG(args, graphs_train, model) |
|
|
|
|
|
|
|
# fname = args.model_save_path + args.fname + 'model_' + str(9) + '.dat' |
|
|
|
# model.load_state_dict(torch.load(fname)) |
|
|
|
# |
|
|
|
# all_tests = list() |
|
|
|
# all_ret_test = list() |
|
|
|
# iter_count = 0 |
|
|
|
# for test_graph in graphs_test: |
|
|
|
# test_graph = nx.convert_node_labels_to_integers(test_graph) |
|
|
|
# original_graph = test_graph.copy() |
|
|
|
# test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1]) |
|
|
|
# ret_test = test_DGMG_2(args, model, test_graph) |
|
|
|
# all_tests.append(original_graph) |
|
|
|
# all_ret_test.append(ret_test) |
|
|
|
# iter_count += 1 |
|
|
|
# print(iter_count) |
|
|
|
# labels, results = calc_lable_result(all_tests, all_ret_test) |
|
|
|
# evaluate(labels, results) |