
train function

master
Yassaman Ommi, 4 years ago
commit 34fc982f1f
1 changed file with 50 additions and 0 deletions

GraphTransformer.py (+50, -0)

        mlp_output = self.MLP(gat_output).reshape(1, -1)

        return mlp_output

'''
now train
'''
# torch, torch.nn.functional as F, and time are assumed to be imported at the top of GraphTransformer.py.
def build_model(gcn_input, model_dim, head):
    # Build the Hydra graph-transformer model and move it to the GPU.
    model = Hydra(gcn_input, model_dim, head).cuda()
    return model


def fn(batch):
    # Collate function for the DataLoader: with batch_size=1, just unwrap the single sample.
    return batch[0]


def train_model(model, trainloader, epoch, print_every=100):
    optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

    model.train()
    start = time.time()
    temp = start
    total_loss = 0

    for i in range(epoch):
        for batch, data in enumerate(trainloader, 0):
            adj, features, true_links = data
            adj = torch.tensor(adj).cuda()
            features = torch.tensor(features).cuda()
            true_links = torch.tensor(true_links).cuda()
            # Debug: print(adj.shape), print(features.shape), print(true_links.shape)
            preds = model(features, adj)
            optim.zero_grad()
            # Link prediction is trained with binary cross-entropy against the true links.
            loss = F.binary_cross_entropy(preds.double(), true_links.double())
            loss.backward()
            optim.step()
            total_loss += loss.item()
            if (batch + 1) % print_every == 0:
                loss_avg = total_loss / print_every
                print("time = %dm, epoch %d, iter = %d, loss = %.3f, %ds per %d iters"
                      % ((time.time() - start) // 60, i + 1, batch + 1,
                         loss_avg, time.time() - temp, print_every))
                total_loss = 0
                temp = time.time()


# prepare data
# coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], [])
# Convert each graph to model inputs, dropping any graph for which prepare_graph_data returns None.
coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs]))
trainloader = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1)
model = build_model(3, 243, 9)            # gcn_input=3, model_dim=243, head=9
train_model(model, trainloader, 100, 10)  # 100 epochs, report every 10 iterations
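
prepare_graph_data and graphs are defined elsewhere in the repository. For reference only, a minimal hypothetical sketch of the tuple format the training loop above expects (a dense adjacency matrix, node features, and link-prediction targets, with None for graphs to skip) could look like the following; the function name, the constant node features, and the choice of adjacency entries as targets are assumptions, not the project's actual preprocessing:

import numpy as np
import networkx as nx

def prepare_graph_data_sketch(g, feature_dim=3):
    # Hypothetical stand-in for prepare_graph_data: returns the
    # (adj, features, true_links) triple that train_model unpacks,
    # or None for graphs the pipeline should skip.
    if g.number_of_nodes() == 0:
        return None
    adj = nx.to_numpy_array(g, dtype=np.float32)  # dense adjacency matrix
    # Placeholder node features; feature_dim=3 mirrors build_model(3, ...).
    features = np.ones((g.number_of_nodes(), feature_dim), dtype=np.float32)
    # Assumed targets: the adjacency entries themselves, flattened to match
    # the model's reshape(1, -1) output.
    true_links = adj.reshape(1, -1)
    return adj, features, true_links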
