|
|
@@ -101,18 +101,18 @@ def prepare_graph_data(graph, max_size=40, min_size=10): |
|
|
|
gets a graph as an input |
|
|
|
returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] |
|
|
|
''' |
|
|
|
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: |
|
|
|
return None |
|
|
|
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph) |
|
|
|
choice = np.random.choice(list(relabeled_graph.nodes())) |
|
|
|
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))) |
|
|
|
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph) |
|
|
|
graph_length = len(remaining_graph) |
|
|
|
remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)]) |
|
|
|
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice] |
|
|
|
removed_node_row = np.asarray(removed_node)[0] |
|
|
|
removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))]) |
|
|
|
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row |
|
|
|
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: |
|
|
|
return None |
|
|
|
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph) |
|
|
|
choice = np.random.choice(list(relabeled_graph.nodes())) |
|
|
|
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))) |
|
|
|
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph) |
|
|
|
graph_length = len(remaining_graph) |
|
|
|
remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)]) |
|
|
|
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice] |
|
|
|
removed_node_row = np.asarray(removed_node)[0] |
|
|
|
removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))]) |
|
|
|
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row |
|
|
|
|
|
|
|
"""" |
|
|
|
Layers: |
|
|
@@ -210,9 +210,8 @@ class Hydra(nn.Module): |
|
|
|
|
|
|
|
return mlp_output |
|
|
|
|
|
|
|
"""" |
|
|
|
hala train |
|
|
|
"""" |
|
|
|
# train |
|
|
|
|
|
|
|
def build_model(gcn_input, model_dim, head): |
|
|
|
model = Hydra(gcn_input, model_dim, head).cuda() |
|
|
|
return model |
|
|
@@ -223,34 +222,34 @@ def fn(batch): |
|
|
|
|
|
|
|
|
|
|
|
def train_model(model, trainloader, epoch, print_every=100): |
|
|
|
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9) |
|
|
|
|
|
|
|
model.train() |
|
|
|
start = time.time() |
|
|
|
temp = start |
|
|
|
total_loss = 0 |
|
|
|
|
|
|
|
for i in range(epoch): |
|
|
|
for batch, data in enumerate(trainloader, 0): |
|
|
|
adj, features, true_links = data |
|
|
|
adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda() |
|
|
|
# print(adj.shape) |
|
|
|
# print(features.shape) |
|
|
|
# print(true_links.shape) |
|
|
|
preds = model(features, adj) |
|
|
|
optim.zero_grad() |
|
|
|
loss = F.binary_cross_entropy(preds.double(), true_links.double()) |
|
|
|
loss.backward() |
|
|
|
optim.step() |
|
|
|
total_loss += loss.item() |
|
|
|
if (i + 1) % print_every == 0: |
|
|
|
loss_avg = total_loss / print_every |
|
|
|
print("time = %dm, epoch %d, iter = %d, loss = %.3f,\ |
|
|
|
%ds per %d iters" % ((time.time() - start) // 60,\ |
|
|
|
epoch + 1, i + 1, loss_avg, time.time() - temp,\ |
|
|
|
print_every)) |
|
|
|
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9) |
|
|
|
|
|
|
|
model.train() |
|
|
|
start = time.time() |
|
|
|
temp = start |
|
|
|
total_loss = 0 |
|
|
|
temp = time.time() |
|
|
|
|
|
|
|
for i in range(epoch): |
|
|
|
for batch, data in enumerate(trainloader, 0): |
|
|
|
adj, features, true_links = data |
|
|
|
adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda() |
|
|
|
# print(adj.shape) |
|
|
|
# print(features.shape) |
|
|
|
# print(true_links.shape) |
|
|
|
preds = model(features, adj) |
|
|
|
optim.zero_grad() |
|
|
|
loss = F.binary_cross_entropy(preds.double(), true_links.double()) |
|
|
|
loss.backward() |
|
|
|
optim.step() |
|
|
|
total_loss += loss.item() |
|
|
|
if (i + 1) % print_every == 0: |
|
|
|
loss_avg = total_loss / print_every |
|
|
|
print("time = %dm, epoch %d, iter = %d, loss = %.3f,\ |
|
|
|
%ds per %d iters" % ((time.time() - start) // 60,\ |
|
|
|
epoch + 1, i + 1, loss_avg, time.time() - temp,\ |
|
|
|
print_every)) |
|
|
|
total_loss = 0 |
|
|
|
temp = time.time() |
|
|
|
|
|
|
|
|
|
|
|
# prepare data |