Browse Source

train function

master
Yassaman Ommi 3 years ago
parent
commit
80fdf6aef4
1 changed files with 41 additions and 42 deletions
  1. 41
    42
      GraphTransformer.py

+ 41
- 42
GraphTransformer.py View File

@@ -101,18 +101,18 @@ def prepare_graph_data(graph, max_size=40, min_size=10):
gets a graph as an input
returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2]
'''
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
return None
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
choice = np.random.choice(list(relabeled_graph.nodes()))
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
graph_length = len(remaining_graph)
remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
removed_node_row = np.asarray(removed_node)[0]
removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
return None
relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
choice = np.random.choice(list(relabeled_graph.nodes()))
remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
graph_length = len(remaining_graph)
remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
removed_node_row = np.asarray(removed_node)[0]
removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row

""""
Layers:
@@ -210,9 +210,8 @@ class Hydra(nn.Module):

return mlp_output

""""
hala train
""""
# train

def build_model(gcn_input, model_dim, head):
model = Hydra(gcn_input, model_dim, head).cuda()
return model
@@ -223,34 +222,34 @@ def fn(batch):


def train_model(model, trainloader, epoch, print_every=100):
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

model.train()
start = time.time()
temp = start
total_loss = 0

for i in range(epoch):
for batch, data in enumerate(trainloader, 0):
adj, features, true_links = data
adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()
# print(adj.shape)
# print(features.shape)
# print(true_links.shape)
preds = model(features, adj)
optim.zero_grad()
loss = F.binary_cross_entropy(preds.double(), true_links.double())
loss.backward()
optim.step()
total_loss += loss.item()
if (i + 1) % print_every == 0:
loss_avg = total_loss / print_every
print("time = %dm, epoch %d, iter = %d, loss = %.3f,\
%ds per %d iters" % ((time.time() - start) // 60,\
epoch + 1, i + 1, loss_avg, time.time() - temp,\
print_every))
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

model.train()
start = time.time()
temp = start
total_loss = 0
temp = time.time()

for i in range(epoch):
for batch, data in enumerate(trainloader, 0):
adj, features, true_links = data
adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()
# print(adj.shape)
# print(features.shape)
# print(true_links.shape)
preds = model(features, adj)
optim.zero_grad()
loss = F.binary_cross_entropy(preds.double(), true_links.double())
loss.backward()
optim.step()
total_loss += loss.item()
if (i + 1) % print_every == 0:
loss_avg = total_loss / print_every
print("time = %dm, epoch %d, iter = %d, loss = %.3f,\
%ds per %d iters" % ((time.time() - start) // 60,\
epoch + 1, i + 1, loss_avg, time.time() - temp,\
print_every))
total_loss = 0
temp = time.time()


# prepare data

Loading…
Cancel
Save