@@ -41,7 +41,6 @@ class MetaLearner(nn.Module):
        return x.view(size[0], 1, 1, self.out_size)

class EmbeddingLearner(nn.Module):
    def __init__(self):
        super(EmbeddingLearner, self).__init__()

@@ -52,21 +51,58 @@ class EmbeddingLearner(nn.Module):
        n_score = score[:, pos_num:]
        return p_score, n_score

def bpr_loss(p_scores, n_values, device):
    # BPR: -log sigmoid(positive - negative), averaged over the negatives
    ratio = int(n_values.shape[1] / p_scores.shape[1])
    # tile each positive score so it lines up with its block of negatives;
    # a 2-D empty tensor makes torch.cat along dim=1 well-defined, and
    # device= keeps this working on CPU as well as CUDA
    temp_pvalues = torch.empty(p_scores.shape[0], 0, device=device)
    for i in range(p_scores.shape[1]):
        temp_pvalues = torch.cat((temp_pvalues, p_scores[:, i, None].expand(-1, ratio)), dim=1)
    d = torch.sub(temp_pvalues, n_values)
    t = F.logsigmoid(d)
    loss = -1 * (1.0 / n_values.shape[1]) * t.sum(dim=1)
    loss = loss.sum(dim=0)
    return loss
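
# A minimal shape-check sketch for bpr_loss (added commentary, not part of the
# original change). It assumes the shapes used elsewhere in this diff:
# p_scores is (batch, num_pos) and n_values is (batch, num_pos * ratio).
#
#   p = torch.randn(4, 2)                            # 2 positives per task
#   n = torch.randn(4, 10)                           # ratio = 5 negatives each
#   l = bpr_loss(p, n, device=torch.device("cpu"))   # scalar loss tensor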

def bpr_max_loss(p_scores, n_values, device):
    # weight each positive-negative difference by the softmax of the negative scores
    s = F.softmax(n_values, dim=1)
    ratio = int(n_values.shape[1] / p_scores.shape[1])
    temp_pvalues = torch.empty(p_scores.shape[0], 0, device=device)
    for i in range(p_scores.shape[1]):
        temp_pvalues = torch.cat((temp_pvalues, p_scores[:, i, None].expand(-1, ratio)), dim=1)
    d = torch.sigmoid(torch.sub(temp_pvalues, n_values))
    t = torch.mul(s, d)
    loss = -1 * torch.log(t.sum(dim=1))
    loss = loss.sum()
    return loss
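
# Commentary (added): bpr_max_loss follows the BPR-max idea of softmax-weighting
# each negative, so the hardest negatives dominate the loss. The published
# BPR-max loss (Hidasi & Karatzoglou, 2018) also adds a score-regularization
# term, which this implementation appears to omit.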

def top_loss(p_scores, n_values):
    p1 = p_scores[:, 0, None]
    p2 = p_scores[:, 1, None]
    num_neg = n_values.shape[1]
    half_index = int(num_neg / 2)
    # compare each of the two positives against its half of the negatives
    d1 = torch.sub(p1, n_values[:, 0:half_index])
    d2 = torch.sub(p2, n_values[:, half_index:])
    t1 = torch.sigmoid(torch.cat((d1, d2), dim=1))
    t2 = torch.sigmoid(torch.pow(n_values, 2))
    t3 = torch.add(t1, t2)
    loss = t3.sum()
    # loss /= (n_values.shape[1] * p_scores.shape[0])
    loss /= n_values.shape[1]
    return loss
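
# Commentary (added): top_loss resembles the TOP1 loss, sigma(neg - pos) +
# sigma(neg^2) averaged over negatives. Note that d1/d2 above are computed as
# (pos - neg); if TOP1 is the intent, the usual formulation subtracts in the
# other direction, so the sign of the ranking term may be worth double-checking.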

class MetaTL(nn.Module):
    def __init__(self, itemnum, parameter):
        super(MetaTL, self).__init__()

@@ -81,7 +117,8 @@ class MetaTL(nn.Module):
                                          num_hidden2=200, out_size=100, dropout_p=0)
        self.embedding_learner = EmbeddingLearner()
        # self.loss_func = nn.MarginRankingLoss(self.margin)
        self.loss_func = bpr_loss
        self.rel_q_sharing = dict()

@@ -118,9 +155,9 @@ class MetaTL(nn.Module):
            # sorted, indices = torch.sort(n_score, descending=True, dim=1)
            # n_values = sorted[:, 0:p_score.shape[1]]
            n_values = n_score
            # loss = self.loss_func(p_score, n_values, y)
            loss = self.loss_func(p_score, n_score, device=self.device)
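            # Commentary (added): retain_graph=True keeps the support-set graph
            # alive so that rel.grad (read just below) can drive the MAML-style
            # inner update of the relation embedding before the query-set pass.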
            loss.backward(retain_graph=True)
            grad_meta = rel.grad

@@ -28,7 +28,7 @@ class Trainer:
        self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)

    def rank_predict(self, data, x, ranks):
        # query_idx is the idx of the positive score
        query_idx = x.shape[0] - 1

@@ -54,13 +54,13 @@ class Trainer:
            self.optimizer.zero_grad()
            p_score, n_score = self.MetaTL(task, iseval, curr_rel)
            y = torch.Tensor([1]).to(self.device)  # only needed by MarginRankingLoss; unused by bpr_loss
            loss = self.MetaTL.loss_func(p_score, n_score, self.device)
            loss.backward()
            self.optimizer.step()
        elif curr_rel != '':
            p_score, n_score = self.MetaTL(task, iseval, curr_rel)
            y = torch.Tensor([1]).to(self.device)  # only needed by MarginRankingLoss; unused by bpr_loss
            loss = self.MetaTL.loss_func(p_score, n_score, self.device)
        return loss, p_score, n_score

    def train(self):
@@ -74,12 +74,12 @@ class Trainer:
            # sample one batch from data_loader
            train_task, curr_rel = self.train_data_loader.next_batch()
            loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

            # print the loss on specific epochs
            # if e % self.print_epoch == 0:
            #     loss_num = loss.item()
            #     print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))

            # do evaluation on specific epochs
            if e % self.eval_epoch == 0 and e != 0:
                loss_num = loss.item()
                print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))
                print('Epoch {} Validating...'.format(e))
                valid_data = self.eval(istest=False, epoch=e)

@@ -105,6 +105,7 @@ class Trainer:
        t = 0
        temp = dict()
        total_loss = 0
        while True:
            # sample all the eval tasks
            eval_task, curr_rel = data_loader.next_one_on_eval()
@@ -113,7 +114,8 @@ class Trainer:
                break
            t += 1

            loss, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
            total_loss += loss.item()  # .item() avoids retaining each step's graph
            x = torch.cat([n_score, p_score], 1).squeeze()
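            # Commentary (added): n_score columns come first in x, so the
            # positive lands at the last index, matching rank_predict's
            # query_idx = x.shape[0] - 1 convention above.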

@@ -132,12 +134,14 @@ class Trainer:
        if istest:
            print("TEST: \ttest_loss: ", total_loss)
            print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
                temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']), "\n")
            with open('results.txt', 'a') as f:
                f.writelines("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
                    temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
        else:
            print("VALID: \tvalidation_loss: ", total_loss)
            print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
                temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
            with open("results.txt", 'a') as f: