|
|
|
|
|
|
|
|
loss = loss.sum() |
|
|
loss = loss.sum() |
|
|
return loss |
|
|
return loss |
|
|
|
|
|
|
|
|
def top_loss(p_scores, n_values): |
|
|
|
|
|
p1 = p_scores[:, 0, None] |
|
|
|
|
|
p2 = p_scores[:, 1, None] |
|
|
|
|
|
|
|
|
|
|
|
num_neg = n_values.shape[1] |
|
|
|
|
|
half_index = int(num_neg / 2) |
|
|
|
|
|
|
|
|
|
|
|
d1 = torch.sub(p1, n_values[:, 0:half_index]) |
|
|
|
|
|
d2 = torch.sub(p2, n_values[:, half_index:]) |
|
|
|
|
|
# print("d1 shape:",d1.shape) |
|
|
|
|
|
|
|
|
def top_loss(p_scores, n_values,device): |
|
|
|
|
|
ratio = int(n_values.shape[1] / p_scores.shape[1]) |
|
|
|
|
|
temp_pvalues = torch.tensor([]).cuda(device=device) |
|
|
|
|
|
for i in range(p_scores.shape[1]): |
|
|
|
|
|
temp_pvalues = torch.cat((temp_pvalues, p_scores[:, i, None].expand(-1, ratio)), dim=1) |
|
|
|
|
|
|
|
|
# print("add shape:",torch.cat((d1,d2),dim=1).shape) |
|
|
|
|
|
t1 = torch.sigmoid(torch.cat((d1,d2),dim=1)) |
|
|
|
|
|
# print("t1 shape:",t1.shape) |
|
|
|
|
|
|
|
|
t1 = torch.sigmoid(torch.sub(n_values , temp_pvalues)) |
|
|
t2 = torch.sigmoid(torch.pow(n_values,2)) |
|
|
t2 = torch.sigmoid(torch.pow(n_values,2)) |
|
|
# print("t2 shape:",t2.shape) |
|
|
|
|
|
|
|
|
|
|
|
t3 = torch.add(t1,t2) |
|
|
|
|
|
# print("t3 shape:",t3.shape) |
|
|
|
|
|
|
|
|
|
|
|
loss = t3.sum() |
|
|
|
|
|
# print(loss.shape) |
|
|
|
|
|
# loss /= (n_values.shape[1] * p_scores.shape[0]) |
|
|
|
|
|
loss /= n_values.shape[1] |
|
|
|
|
|
|
|
|
t = torch.add(t1,t2) |
|
|
|
|
|
t = t.sum(dim=1) |
|
|
|
|
|
loss = t / n_values.shape[1] |
|
|
|
|
|
loss = loss.sum(dim=0) |
|
|
return loss |
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.embedding_learner = EmbeddingLearner() |
|
|
self.embedding_learner = EmbeddingLearner() |
|
|
# self.loss_func = nn.MarginRankingLoss(self.margin) |
|
|
# self.loss_func = nn.MarginRankingLoss(self.margin) |
|
|
self.loss_func = bpr_loss |
|
|
|
|
|
|
|
|
self.loss_func = top_loss |
|
|
|
|
|
|
|
|
self.rel_q_sharing = dict() |
|
|
self.rel_q_sharing = dict() |
|
|
|
|
|
|