def prepare_graph_data(graph, max_size=40, min_size=10):
    """Build one link-prediction training example from a graph.

    A random node is removed from `graph`; the function returns the
    remaining graph's zero-padded adjacency matrix, its feature matrix,
    and the removed node's true-link row (also zero-padded to `max_size`).

    Parameters
    ----------
    graph : networkx.Graph
        Input graph; node labels may be arbitrary (they are relabeled
        to consecutive integers internally).
    max_size : int
        Graphs with >= max_size nodes are rejected; outputs are padded
        to this size.
    min_size : int
        Graphs with < min_size nodes are rejected.

    Returns
    -------
    tuple or None
        (padded adjacency of remaining graph, feature matrix of
        remaining graph, padded true-link row of the removed node),
        or None when the graph size is outside [min_size, max_size).
    """
    n_nodes = len(graph.nodes())
    if n_nodes >= max_size or n_nodes < min_size:
        return None
    relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
    choice = np.random.choice(list(relabeled_graph.nodes()))
    # Drop the chosen node; subgraph keeps all edges among the survivors.
    remaining_graph = relabeled_graph.subgraph(
        [node for node in relabeled_graph.nodes() if node != choice])
    # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is
    # the supported equivalent and returns a plain 2-D ndarray.
    remaining_graph_adj = nx.to_numpy_array(remaining_graph)
    pad = max_size - len(remaining_graph)
    remaining_graph_adj = np.pad(remaining_graph_adj, [(0, pad), (0, pad)])
    # Row `choice` of the FULL adjacency = the removed node's true links.
    # With to_numpy_array this row is already 1-D, so no matrix->array
    # unwrapping is needed.
    removed_node_row = nx.to_numpy_array(relabeled_graph)[choice]
    removed_node_row = np.pad(
        removed_node_row, [(0, max_size - len(removed_node_row))])
    return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
"""
| Layers: | |||
| @@ -210,9 +210,8 @@ class Hydra(nn.Module): | |||
| return mlp_output | |||
"""
Train
"""
| # train | |||
def build_model(gcn_input, model_dim, head, device="cuda"):
    """Construct a Hydra model and move it to the target device.

    Parameters
    ----------
    gcn_input : int
        Input feature dimension passed through to Hydra.
    model_dim : int
        Model (hidden) dimension passed through to Hydra.
    head : int
        Number of attention heads passed through to Hydra.
    device : str or torch.device, optional
        Device to place the model on. Defaults to "cuda", preserving
        the original hard-coded `.cuda()` behavior while allowing CPU
        use for testing/debugging.

    Returns
    -------
    Hydra
        The constructed model, already moved to `device`.
    """
    model = Hydra(gcn_input, model_dim, head).to(device)
    return model
| @@ -223,34 +222,34 @@ def fn(batch): | |||
def train_model(model, trainloader, epoch, print_every=100):
    """Train `model` for `epoch` epochs with BCE loss on CUDA.

    Parameters
    ----------
    model : nn.Module
        Model taking (features, adj) and returning link predictions.
    trainloader : iterable
        Yields (adj, features, true_links) batches.
        NOTE(review): batches appear to be array-like (wrapped with
        torch.tensor here) — confirm against the collate fn.
    epoch : int
        Number of epochs to run.
    print_every : int, optional
        Print running stats every `print_every` epochs.

    Returns
    -------
    None
        Trains in place; progress is printed periodically.
    """
    optim = torch.optim.Adam(model.parameters(),
                             lr=0.001, betas=(0.9, 0.98), eps=1e-9)
    model.train()
    start = time.time()
    temp = start           # start of the current reporting window
    total_loss = 0
    for i in range(epoch):
        for batch, data in enumerate(trainloader, 0):
            adj, features, true_links = data
            adj = torch.tensor(adj).cuda()
            features = torch.tensor(features).cuda()
            true_links = torch.tensor(true_links).cuda()
            preds = model(features, adj)
            optim.zero_grad()
            loss = F.binary_cross_entropy(preds.double(), true_links.double())
            loss.backward()
            optim.step()
            total_loss += loss.item()
        if (i + 1) % print_every == 0:
            # Average per epoch in this window (total_loss spans
            # print_every epochs' worth of batches).
            loss_avg = total_loss / print_every
            # BUG FIX: the original printed `epoch + 1` (the constant
            # total) as the epoch field; report the running indices.
            print("time = %dm, epoch %d, iter = %d, loss = %.3f,\
            %ds per %d iters" % ((time.time() - start) // 60,
                                 i + 1, batch + 1, loss_avg,
                                 time.time() - temp, print_every))
            # Reset window counters so each report covers only the
            # last print_every epochs (missing in the pre-diff copy).
            total_loss = 0
            temp = time.time()
| # prepare data | |||