Browse Source

fix bug in evaluation

master
Ali Amiri 2 years ago
parent
commit
a14ad2caa8
1 changed files with 116 additions and 99 deletions
  1. 116
    99
      main_DeepGMG.py

+ 116
- 99
main_DeepGMG.py View File

@@ -21,9 +21,10 @@ class Args_DGMG():
### data config
# self.graph_type = 'caveman_small'
# self.graph_type = 'grid_small'
self.graph_type = 'IMDBMULTI'
# self.graph_type = 'ENZYMES'
# self.graph_type = 'ladder_small'
# self.graph_type = 'enzymes_small'
self.graph_type = 'enzymes_small'
# self.graph_type = 'protein'
# self.graph_type = 'barabasi_small'
# self.graph_type = 'citeseer_small'

@@ -445,81 +446,81 @@ def train_DGMG_nll(args, dataset_train, dataset_test, model, max_iter=1000):

def test_DGMG_2(args, model, test_graph, is_fast=False):
model.eval()
graph_num = args.test_graph_num

graphs_generated = []
# graphs_generated = []
# for i in range(graph_num):
# NOTE: when starting loop, we assume a node has already been generated
node_neighbor = [[]] # list of lists (first node is zero)
node_embedding = [
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden
try:
node_embedding = [
Variable(torch.ones(1, args.node_embedding_size)).cuda()] # list of torch tensors, each size: 1*hidden

node_max = len(test_graph.nodes())
node_count = 1
while node_count <= node_max:
# 1 message passing
# do 2 times message passing
node_embedding = message_passing(node_neighbor, node_embedding, model)
node_max = len(test_graph.nodes())
node_count = 1
while node_count <= node_max:
# 1 message passing
# do 2 times message passing
node_embedding = message_passing(node_neighbor, node_embedding, model)

# 2 graph embedding and new node embedding
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)
init_embedding = calc_init_embedding(node_embedding_cat, model)
# 2 graph embedding and new node embedding
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)
init_embedding = calc_init_embedding(node_embedding_cat, model)

# 3 f_addnode
p_addnode = model.f_an(graph_embedding)
a_addnode = sample_tensor(p_addnode)
# 3 f_addnode
p_addnode = model.f_an(graph_embedding)
a_addnode = sample_tensor(p_addnode)

if a_addnode.data[0][0] == 1:
# add node
node_neighbor.append([])
node_embedding.append(init_embedding)
if is_fast:
node_embedding_cat = torch.cat(node_embedding, dim=0)
else:
break
if a_addnode.data[0][0] == 1:
# add node
node_neighbor.append([])
node_embedding.append(init_embedding)
if is_fast:
node_embedding_cat = torch.cat(node_embedding, dim=0)
else:
break

edge_count = 0
while edge_count < args.max_num_node:
if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)
edge_count = 0
while edge_count < args.max_num_node:
if not is_fast:
node_embedding = message_passing(node_neighbor, node_embedding, model)
node_embedding_cat = torch.cat(node_embedding, dim=0)
graph_embedding = calc_graph_embedding(node_embedding_cat, model)

# 4 f_addedge
p_addedge = model.f_ae(graph_embedding)
a_addedge = sample_tensor(p_addedge)
# 4 f_addedge
p_addedge = model.f_ae(graph_embedding)
a_addedge = sample_tensor(p_addedge)

if a_addedge.data[0][0] == 1:
# 5 f_nodes
# excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
a_node = gumbel_softmax(p_node, temperature=0.01)
_, a_node_id = a_node.topk(1)
a_node_id = int(a_node_id.data[0][0])
# add edge
if a_addedge.data[0][0] == 1:
# 5 f_nodes
# excluding the last node (which is the new node)
node_new_embedding_cat = node_embedding_cat[-1, :].expand(node_embedding_cat.size(0) - 1,
node_embedding_cat.size(1))
s_node = model.f_s(torch.cat((node_embedding_cat[0:-1, :], node_new_embedding_cat), dim=1))
p_node = F.softmax(s_node.permute(1, 0))
a_node = gumbel_softmax(p_node, temperature=0.01)
_, a_node_id = a_node.topk(1)
a_node_id = int(a_node_id.data[0][0])
# add edge

node_neighbor[-1].append(a_node_id)
node_neighbor[a_node_id].append(len(node_neighbor) - 1)
else:
break
node_neighbor[-1].append(a_node_id)
node_neighbor[a_node_id].append(len(node_neighbor) - 1)
else:
break

edge_count += 1
node_count += 1
edge_count += 1
node_count += 1

# clear node_neighbor and build it again
node_neighbor = []
for n in range(node_max):
temp_neighbor = [k for k in test_graph.edge[n]]
node_neighbor.append(temp_neighbor)
# clear node_neighbor and build it again
node_neighbor = []
for n in range(node_max):
temp_neighbor = [k for k in test_graph.edge[n]]
node_neighbor.append(temp_neighbor)

# now add the last node for real
# 1 message passing
# do 2 times message passing

# now add the last node for real
# 1 message passing
# do 2 times message passing
try:
node_embedding = message_passing(node_neighbor, node_embedding, model)

# 2 graph embedding and new node embedding
@@ -703,33 +704,43 @@ if __name__ == '__main__':
args.max_prev_node = 20

if args.graph_type == 'enzymes_small':
graphs_raw = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
graphs_raw = Graph_load_batch(min_num_nodes=1, name='ENZYMES')
graphs = []
for G in graphs_raw:
if G.number_of_nodes() <= 20:
graphs.append(G)
args.max_prev_node = 15
args.max_prev_node = 10

if args.graph_type == 'citeseer_small':
_, _, G = Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
G = nx.convert_node_labels_to_integers(G)
graphs = []
for i in range(G.number_of_nodes()):
G_ego = nx.ego_graph(G, i, radius=1)
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
graphs.append(G_ego)
shuffle(graphs)
graphs = graphs[0:200]
args.max_prev_node = 15
else:
graphs, num_classes = load_data(args.graph_type, True)
small_graphs = []
for i in range(len(graphs)):
if graphs[i].number_of_nodes() < 13:
small_graphs.append(graphs[i])
graphs = small_graphs
args.max_prev_node = 12
if args.graph_type == 'protein':
graphs_raw = Graph_load_batch(min_num_nodes=1, name='PROTEINS_full')
graphs = []
for G in graphs_raw:
if G.number_of_nodes() <= 15:
graphs.append(G)
args.max_prev_node = 10

else:
if args.graph_type == 'citeseer_small':
_, _, G = Graph_load(dataset='citeseer')
G = max(nx.connected_component_subgraphs(G), key=len)
G = nx.convert_node_labels_to_integers(G)
graphs = []
for i in range(G.number_of_nodes()):
G_ego = nx.ego_graph(G, i, radius=1)
if (G_ego.number_of_nodes() >= 4) and (G_ego.number_of_nodes() <= 20):
graphs.append(G_ego)
shuffle(graphs)
graphs = graphs[0:200]
args.max_prev_node = 15
else:
graphs, num_classes = load_data(args.graph_type, True)
small_graphs = []
for i in range(len(graphs)):
if graphs[i].number_of_nodes() < 16:
small_graphs.append(graphs[i])
graphs = small_graphs
args.max_prev_node = 21

# remove self loops
for graph in graphs:
@@ -741,6 +752,9 @@ if __name__ == '__main__':
random.seed(123)
shuffle(graphs)
graphs_len = len(graphs)

graph_statistics(graphs)

graphs_test = graphs[int(0.8 * graphs_len):]
graphs_train = graphs[0:int(0.8 * graphs_len)]

@@ -751,19 +765,22 @@ if __name__ == '__main__':
print('max previous node: {}'.format(args.max_prev_node))

### train
# train_DGMG(args, graphs_train, model)

fname = args.model_save_path + args.fname + 'model_' + str(args.load_epoch) + '.dat'
model.load_state_dict(torch.load(fname))

all_tests = list()
all_ret_test = list()
for test_graph in graphs_test:
test_graph = nx.convert_node_labels_to_integers(test_graph)
original_graph = test_graph.copy()
test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1])
ret_test = test_DGMG_2(args, model, test_graph)
all_tests.append(original_graph)
all_ret_test.append(ret_test)
labels, results = calc_lable_result(original_graph, ret_test)
evaluate(labels, results)
train_DGMG(args, graphs_train, model)

# fname = args.model_save_path + args.fname + 'model_' + str(9) + '.dat'
# model.load_state_dict(torch.load(fname))
#
# all_tests = list()
# all_ret_test = list()
# iter_count = 0
# for test_graph in graphs_test:
# test_graph = nx.convert_node_labels_to_integers(test_graph)
# original_graph = test_graph.copy()
# test_graph.remove_node(test_graph.nodes()[len(test_graph.nodes()) - 1])
# ret_test = test_DGMG_2(args, model, test_graph)
# all_tests.append(original_graph)
# all_ret_test.append(ret_test)
# iter_count += 1
# print(iter_count)
# labels, results = calc_lable_result(all_tests, all_ret_test)
# evaluate(labels, results)

Loading…
Cancel
Save