| gets a graph as an input | gets a graph as an input | ||||
| returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] | returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] | ||||
| ''' | ''' | ||||
| if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: | |||||
| return None | |||||
| relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph) | |||||
| choice = np.random.choice(list(relabeled_graph.nodes())) | |||||
| remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))) | |||||
| remaining_graph_adj = nx.to_numpy_matrix(remaining_graph) | |||||
| graph_length = len(remaining_graph) | |||||
| remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)]) | |||||
| removed_node = nx.to_numpy_matrix(relabeled_graph)[choice] | |||||
| removed_node_row = np.asarray(removed_node)[0] | |||||
| removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))]) | |||||
| return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row | |||||
| if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size: | |||||
| return None | |||||
| relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph) | |||||
| choice = np.random.choice(list(relabeled_graph.nodes())) | |||||
| remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))) | |||||
| remaining_graph_adj = nx.to_numpy_matrix(remaining_graph) | |||||
| graph_length = len(remaining_graph) | |||||
| remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)]) | |||||
| removed_node = nx.to_numpy_matrix(relabeled_graph)[choice] | |||||
| removed_node_row = np.asarray(removed_node)[0] | |||||
| removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))]) | |||||
| return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row | |||||
| """" | """" | ||||
| Layers: | Layers: | ||||
| return mlp_output | return mlp_output | ||||
| """" | |||||
| hala train | |||||
| """" | |||||
| # train | |||||
def build_model(gcn_input, model_dim, head):
    """Instantiate a Hydra model and move it to the GPU.

    Parameters
    ----------
    gcn_input : int
        Input feature dimension forwarded to ``Hydra``.
    model_dim : int
        Model (hidden) dimension forwarded to ``Hydra``.
    head : int
        Number of attention heads forwarded to ``Hydra``.

    Returns
    -------
    Hydra
        A freshly constructed model resident on CUDA.
    """
    return Hydra(gcn_input, model_dim, head).cuda()
def train_model(model, trainloader, epoch, print_every=100):
    """Train ``model`` on link-prediction batches for ``epoch`` epochs.

    Parameters
    ----------
    model : torch.nn.Module
        The model to optimize; assumed to already live on CUDA
        (see ``build_model``) — TODO confirm with caller.
    trainloader : iterable
        Yields ``(adj, features, true_links)`` batches; each item is
        converted to a CUDA tensor before the forward pass.
    epoch : int
        Total number of passes over ``trainloader``.
    print_every : int, optional
        Report the running average loss every ``print_every`` epochs.

    Notes
    -----
    Uses Adam (lr=0.001, betas=(0.9, 0.98), eps=1e-9) and binary
    cross-entropy between the predicted and true link vectors.
    """
    optim = torch.optim.Adam(model.parameters(), lr=0.001,
                             betas=(0.9, 0.98), eps=1e-9)
    model.train()
    start = time.time()
    temp = start          # wall-clock anchor for the current reporting window
    total_loss = 0        # loss accumulated since the last report
    for i in range(epoch):
        for batch, data in enumerate(trainloader, 0):
            adj, features, true_links = data
            # as_tensor avoids the copy + UserWarning that torch.tensor()
            # raises when the loader already yields tensors.
            adj = torch.as_tensor(adj).cuda()
            features = torch.as_tensor(features).cuda()
            true_links = torch.as_tensor(true_links).cuda()
            preds = model(features, adj)
            optim.zero_grad()
            # double() on both sides keeps BCE's dtype check happy when the
            # model emits float32 but the labels arrive as float64.
            loss = F.binary_cross_entropy(preds.double(), true_links.double())
            loss.backward()
            optim.step()
            total_loss += loss.item()
        # Report once per qualifying EPOCH (the condition depends only on i,
        # so printing inside the batch loop would repeat it every batch).
        if (i + 1) % print_every == 0:
            loss_avg = total_loss / print_every
            # i + 1 is the CURRENT epoch; the original printed `epoch + 1`,
            # i.e. the constant total, which never changed between reports.
            print("time = %dm, epoch %d / %d, loss = %.3f,"
                  " %ds per %d epochs"
                  % ((time.time() - start) // 60, i + 1, epoch,
                     loss_avg, time.time() - temp, print_every))
            total_loss = 0       # reset the reporting window
            temp = time.time()
| # prepare data | # prepare data |