Browse Source

change hyper parameters

RNN
mohamad maheri 2 years ago
parent
commit
e9474085ef
2 changed files with 11 additions and 13 deletions
  1. 2
    2
      main.py
  2. 9
    11
      trainer.py

+ 2
- 2
main.py View File

args.add_argument("-seed", "--seed", default=7, type=int) args.add_argument("-seed", "--seed", default=7, type=int)
args.add_argument("-K", "--K", default=3, type=int) #NUMBER OF SHOT args.add_argument("-K", "--K", default=3, type=int) #NUMBER OF SHOT


args.add_argument("-dim", "--embed_dim", default=128, type=int)
args.add_argument("-dim", "--embed_dim", default=256, type=int)
args.add_argument("-bs", "--batch_size", default=1024, type=int) args.add_argument("-bs", "--batch_size", default=1024, type=int)
args.add_argument("-lr", "--learning_rate", default=0.001, type=float) args.add_argument("-lr", "--learning_rate", default=0.001, type=float)


args.add_argument("-p", "--dropout_p", default=0.5, type=float) args.add_argument("-p", "--dropout_p", default=0.5, type=float)


args.add_argument("-gpu", "--device", default=0, type=int) args.add_argument("-gpu", "--device", default=0, type=int)
args.add_argument("--number_of_neg",default=1,type=int)
args.add_argument("--number_of_neg",default=2,type=int)







+ 9
- 11
trainer.py View File





class Trainer: class Trainer:
def __init__(self, data_loaders, itemnum, parameter, user_train_num, user_train):
def __init__(self, data_loaders, itemnum, parameter,user_train_num,user_train):
# print(user_train) # print(user_train)
self.parameter = parameter self.parameter = parameter
# data loader # data loader
item_arg_all = np.argmax(variance_candidates_all, axis=1) item_arg_all = np.argmax(variance_candidates_all, axis=1)


# negitems = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index]]] for index,user in enumerate(users)} # negitems = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index]]] for index,user in enumerate(users)}
negitems0 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for
index, user in enumerate(users)}
negitems1 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for
index, user in enumerate(users)}
negitems0 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for index,user in enumerate(users)}
negitems1 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for index,user in enumerate(users)}


############################### ###############################
for i in users: for i in users:
self.final_negative_items[i] = [negitems0[i], negitems1[i]]
self.final_negative_items[i] = [negitems0[i],negitems1[i]]
############################### ###############################


# update Mu # update Mu
negitems_mu_candidates[i] = self.Mu_idx[i] negitems_mu_candidates[i] = self.Mu_idx[i]


negitems_mu = {} negitems_mu = {}
negitems_mu = {user: self.candidate_cur[user][negitems_mu_candidates[user]] for user in users}
negitems_mu = {user:self.candidate_cur[user][negitems_mu_candidates[user]] for user in users}


task = np.array(train_task[2]) task = np.array(train_task[2])
task = np.tile(task, reps=(1, self.S1 * (1 + self.S2_div_S1), 1)) task = np.tile(task, reps=(1, self.S1 * (1 + self.S2_div_S1), 1))
print("nan happend__2 in ratings_mu_candidates") print("nan happend__2 in ratings_mu_candidates")
ratings_mu_candidates = self.MetaTL.fast_forward(task, users) ratings_mu_candidates = self.MetaTL.fast_forward(task, users)
ratings_mu_candidates = ratings_mu_candidates / self.temperature ratings_mu_candidates = ratings_mu_candidates / self.temperature
ratings_mu_candidates = ratings_mu_candidates + 100
ratings_mu_candidates = ratings_mu_candidates + 10
ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape( ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape(
np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1]) np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1])


else: else:
user_set.add(i) user_set.add(i)
cache_arg = np.random.choice(self.S1 * (1 + self.S2_div_S1), self.S1, cache_arg = np.random.choice(self.S1 * (1 + self.S2_div_S1), self.S1,
p=ratings_mu_candidates[cnt], replace=False)
p=ratings_mu_candidates[cnt], replace=False)
self.Mu_idx[i] = np.array(self.Mu_idx[i])[cache_arg].tolist() self.Mu_idx[i] = np.array(self.Mu_idx[i])[cache_arg].tolist()
cnt += 1 cnt += 1


"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( "TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
temp['Hits@1'])) temp['Hits@1']))
with open('results.txt', 'a') as f:
with open('results2.txt', 'a') as f:
f.writelines( f.writelines(
"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format( "TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( "VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
temp['Hits@1'])) temp['Hits@1']))
with open("results.txt", 'a') as f:
with open("results2.txt", 'a') as f:
f.writelines( f.writelines(
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( "VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],

Loading…
Cancel
Save