| @@ -209,3 +209,53 @@ class Hydra(nn.Module): | |||
| mlp_output = self.MLP(gat_output).reshape(1,-1) | |||
| return mlp_output | |||
| '''' | |||
| hala train | |||
| '''' | |||
| def build_model(gcn_input, model_dim, head): | |||
| model = Hydra(gcn_input, model_dim, head).cuda() | |||
| return model | |||
| def fn(batch): | |||
| return batch[0] | |||
| def train_model(model, trainloader, epoch, print_every=100): | |||
| optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9) | |||
| model.train() | |||
| start = time.time() | |||
| temp = start | |||
| total_loss = 0 | |||
| for i in range(epoch): | |||
| for batch, data in enumerate(trainloader, 0): | |||
| adj, features, true_links = data | |||
| adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda() | |||
| # print(adj.shape) | |||
| # print(features.shape) | |||
| # print(true_links.shape) | |||
| preds = model(features, adj) | |||
| optim.zero_grad() | |||
| loss = F.binary_cross_entropy(preds.double(), true_links.double()) | |||
| loss.backward() | |||
| optim.step() | |||
| total_loss += loss.item() | |||
| if (i + 1) % print_every == 0: | |||
| loss_avg = total_loss / print_every | |||
| print("time = %dm, epoch %d, iter = %d, loss = %.3f,\ | |||
| %ds per %d iters" % ((time.time() - start) // 60,\ | |||
| epoch + 1, i + 1, loss_avg, time.time() - temp,\ | |||
| print_every)) | |||
| total_loss = 0 | |||
| temp = time.time() | |||
| # prepare data | |||
| # coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], []) | |||
| coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) | |||
| trainloader = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1) | |||
| model = build_model(3, 243, 9) | |||
| train_model(model, trainloader, 100, 10) | |||