|
|
@@ -209,3 +209,53 @@ class Hydra(nn.Module): |
|
|
|
mlp_output = self.MLP(gat_output).reshape(1,-1) |
|
|
|
|
|
|
|
return mlp_output |
|
|
|
|
|
|
|
'''' |
|
|
|
hala train |
|
|
|
'''' |
|
|
|
def build_model(gcn_input, model_dim, head): |
|
|
|
model = Hydra(gcn_input, model_dim, head).cuda() |
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def fn(batch): |
|
|
|
return batch[0] |
|
|
|
|
|
|
|
|
|
|
|
def train_model(model, trainloader, epoch, print_every=100): |
|
|
|
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9) |
|
|
|
|
|
|
|
model.train() |
|
|
|
start = time.time() |
|
|
|
temp = start |
|
|
|
total_loss = 0 |
|
|
|
|
|
|
|
for i in range(epoch): |
|
|
|
for batch, data in enumerate(trainloader, 0): |
|
|
|
adj, features, true_links = data |
|
|
|
adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda() |
|
|
|
# print(adj.shape) |
|
|
|
# print(features.shape) |
|
|
|
# print(true_links.shape) |
|
|
|
preds = model(features, adj) |
|
|
|
optim.zero_grad() |
|
|
|
loss = F.binary_cross_entropy(preds.double(), true_links.double()) |
|
|
|
loss.backward() |
|
|
|
optim.step() |
|
|
|
total_loss += loss.item() |
|
|
|
if (i + 1) % print_every == 0: |
|
|
|
loss_avg = total_loss / print_every |
|
|
|
print("time = %dm, epoch %d, iter = %d, loss = %.3f,\ |
|
|
|
%ds per %d iters" % ((time.time() - start) // 60,\ |
|
|
|
epoch + 1, i + 1, loss_avg, time.time() - temp,\ |
|
|
|
print_every)) |
|
|
|
total_loss = 0 |
|
|
|
temp = time.time() |
|
|
|
|
|
|
|
|
|
|
|
# prepare data |
|
|
|
# coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], []) |
|
|
|
coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) |
|
|
|
trainloader = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1) |
|
|
|
model = build_model(3, 243, 9) |
|
|
|
train_model(model, trainloader, 100, 10) |