|
|
@@ -11,7 +11,7 @@ import gc |
|
|
|
|
|
|
|
|
|
|
|
class Trainer: |
|
|
|
def __init__(self, data_loaders, itemnum, parameter, user_train_num, user_train): |
|
|
|
def __init__(self, data_loaders, itemnum, parameter,user_train_num,user_train): |
|
|
|
# print(user_train) |
|
|
|
self.parameter = parameter |
|
|
|
# data loader |
|
|
@@ -124,14 +124,12 @@ class Trainer: |
|
|
|
item_arg_all = np.argmax(variance_candidates_all, axis=1) |
|
|
|
|
|
|
|
# negitems = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index]]] for index,user in enumerate(users)} |
|
|
|
negitems0 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for |
|
|
|
index, user in enumerate(users)} |
|
|
|
negitems1 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for |
|
|
|
index, user in enumerate(users)} |
|
|
|
negitems0 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for index,user in enumerate(users)} |
|
|
|
negitems1 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for index,user in enumerate(users)} |
|
|
|
|
|
|
|
############################### |
|
|
|
for i in users: |
|
|
|
self.final_negative_items[i] = [negitems0[i], negitems1[i]] |
|
|
|
self.final_negative_items[i] = [negitems0[i],negitems1[i]] |
|
|
|
############################### |
|
|
|
|
|
|
|
# update Mu |
|
|
@@ -149,7 +147,7 @@ class Trainer: |
|
|
|
negitems_mu_candidates[i] = self.Mu_idx[i] |
|
|
|
|
|
|
|
negitems_mu = {} |
|
|
|
negitems_mu = {user: self.candidate_cur[user][negitems_mu_candidates[user]] for user in users} |
|
|
|
negitems_mu = {user:self.candidate_cur[user][negitems_mu_candidates[user]] for user in users} |
|
|
|
|
|
|
|
task = np.array(train_task[2]) |
|
|
|
task = np.tile(task, reps=(1, self.S1 * (1 + self.S2_div_S1), 1)) |
|
|
@@ -169,7 +167,7 @@ class Trainer: |
|
|
|
print("nan happend__2 in ratings_mu_candidates") |
|
|
|
ratings_mu_candidates = self.MetaTL.fast_forward(task, users) |
|
|
|
ratings_mu_candidates = ratings_mu_candidates / self.temperature |
|
|
|
ratings_mu_candidates = ratings_mu_candidates + 100 |
|
|
|
ratings_mu_candidates = ratings_mu_candidates + 10 |
|
|
|
ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape( |
|
|
|
np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1]) |
|
|
|
|
|
|
@@ -181,7 +179,7 @@ class Trainer: |
|
|
|
else: |
|
|
|
user_set.add(i) |
|
|
|
cache_arg = np.random.choice(self.S1 * (1 + self.S2_div_S1), self.S1, |
|
|
|
p=ratings_mu_candidates[cnt], replace=False) |
|
|
|
p=ratings_mu_candidates[cnt], replace=False) |
|
|
|
self.Mu_idx[i] = np.array(self.Mu_idx[i])[cache_arg].tolist() |
|
|
|
cnt += 1 |
|
|
|
|
|
|
@@ -404,7 +402,7 @@ class Trainer: |
|
|
|
"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( |
|
|
|
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], |
|
|
|
temp['Hits@1'])) |
|
|
|
with open('results.txt', 'a') as f: |
|
|
|
with open('results2.txt', 'a') as f: |
|
|
|
f.writelines( |
|
|
|
"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format( |
|
|
|
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], |
|
|
@@ -415,7 +413,7 @@ class Trainer: |
|
|
|
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( |
|
|
|
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], |
|
|
|
temp['Hits@1'])) |
|
|
|
with open("results.txt", 'a') as f: |
|
|
|
with open("results2.txt", 'a') as f: |
|
|
|
f.writelines( |
|
|
|
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( |
|
|
|
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], |