        self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)

    def rank_predict(self, data, x, ranks):
        # query_idx is the idx of positive score
        query_idx = x.shape[0] - 1
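        # NOTE: minimal sketch of the ranking-metric update, inferred from how
        # `data`, `ranks`, and the metric keys (MRR / NDCG@k / Hits@k) are consumed
        # in eval() below; it assumes numpy is imported as `np` at the top of the file.
        _, idx = torch.sort(x, descending=True)
        rank = list(idx.cpu().numpy()).index(query_idx) + 1
        ranks.append(rank)
        for k in (1, 5, 10):
            if rank <= k:
                data['Hits@{}'.format(k)] += 1
                data['NDCG@{}'.format(k)] += 1 / np.log2(rank + 1)
        data['MRR'] += 1.0 / rank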
|
|
|
|
|
|
|
|
    def do_one_step(self, task, iseval=False, curr_rel=''):
        # run the model on one sampled task: a training step (with backprop) or a
        # pure forward pass during evaluation
        loss, p_score, n_score = 0, 0, 0
        if not iseval:
            self.optimizer.zero_grad()
            p_score, n_score = self.MetaTL(task, iseval, curr_rel)
            y = torch.Tensor([1]).to(self.device)
            loss = self.MetaTL.loss_func(p_score, n_score, y)
            loss.backward()
            self.optimizer.step()
        elif curr_rel != '':
            p_score, n_score = self.MetaTL(task, iseval, curr_rel)
            y = torch.Tensor([1]).to(self.device)
            loss = self.MetaTL.loss_func(p_score, n_score, y)
        return loss, p_score, n_score
|
|
|
|
|
|
|
|
    def train(self):
        # main training loop; `self.epoch` (the number of training steps) is an
        # assumed attribute name, set elsewhere in the class
        for e in range(self.epoch):
            # sample one batch from data_loader
            train_task, curr_rel = self.train_data_loader.next_batch()
            loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

            # print the loss on specific epoch
            # if e % self.print_epoch == 0:
            #     loss_num = loss.item()
            #     print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))

            # do evaluation on specific epoch
            if e % self.eval_epoch == 0 and e != 0:
                loss_num = loss.item()
                print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))

                print('Epoch {} Validating...'.format(e))
                valid_data = self.eval(istest=False, epoch=e)
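
    def eval(self, istest=False, epoch=None):
        # NOTE: minimal sketch of the evaluation setup: the signature follows the
        # call in train(); `self.test_data_loader` / `self.dev_data_loader` are
        # assumed attribute names, and the metric dict keys mirror the result
        # prints below.
        self.MetaTL.eval()
        data_loader = self.test_data_loader if istest else self.dev_data_loader
        data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0,
                'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
        ranks = []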
        t = 0
        temp = dict()
        total_loss = 0
        while True:
            # sample all the eval tasks
            eval_task, curr_rel = data_loader.next_one_on_eval()
            # 'EOT' is assumed to be the sentinel that next_one_on_eval() returns
            # once every evaluation task has been served
            if eval_task == 'EOT':
                break
            t += 1

            loss, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
            total_loss += loss

            x = torch.cat([n_score, p_score], 1).squeeze()
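
            # score the held-out positive against its sampled negatives and keep
            # running per-metric averages for reporting; this call and the `temp`
            # averaging are a sketch inferred from rank_predict() above and the
            # prints below
            self.rank_predict(data, x, ranks)
            for k in data.keys():
                temp[k] = data[k] / t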

        if istest:
            print("TEST: \t test_loss: ", total_loss.item())
            print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
                temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']), "\n")
            with open('results.txt', 'a') as f:
                f.writelines("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
                    temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
        else:
            print("VALID: \t validation_loss: ", total_loss.item())
            print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
                temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
            with open("results.txt", 'a') as f: