| mlp_output = self.MLP(gat_output).reshape(1,-1) | mlp_output = self.MLP(gat_output).reshape(1,-1) | ||||
| return mlp_output | return mlp_output | ||||
| '''' | |||||
| hala train | |||||
| '''' | |||||
| def build_model(gcn_input, model_dim, head): | |||||
| model = Hydra(gcn_input, model_dim, head).cuda() | |||||
| return model | |||||
| def fn(batch): | |||||
| return batch[0] | |||||
| def train_model(model, trainloader, epoch, print_every=100): | |||||
| optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9) | |||||
| model.train() | |||||
| start = time.time() | |||||
| temp = start | |||||
| total_loss = 0 | |||||
| for i in range(epoch): | |||||
| for batch, data in enumerate(trainloader, 0): | |||||
| adj, features, true_links = data | |||||
| adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda() | |||||
| # print(adj.shape) | |||||
| # print(features.shape) | |||||
| # print(true_links.shape) | |||||
| preds = model(features, adj) | |||||
| optim.zero_grad() | |||||
| loss = F.binary_cross_entropy(preds.double(), true_links.double()) | |||||
| loss.backward() | |||||
| optim.step() | |||||
| total_loss += loss.item() | |||||
| if (i + 1) % print_every == 0: | |||||
| loss_avg = total_loss / print_every | |||||
| print("time = %dm, epoch %d, iter = %d, loss = %.3f,\ | |||||
| %ds per %d iters" % ((time.time() - start) // 60,\ | |||||
| epoch + 1, i + 1, loss_avg, time.time() - temp,\ | |||||
| print_every)) | |||||
| total_loss = 0 | |||||
| temp = time.time() | |||||
| # prepare data | |||||
| # coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], []) | |||||
| coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) | |||||
| trainloader = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1) | |||||
| model = build_model(3, 243, 9) | |||||
| train_model(model, trainloader, 100, 10) |