Browse Source

change hyper parameters

RNN
mohamad maheri 2 years ago
parent
commit
e9474085ef
2 changed files with 11 additions and 13 deletions
  1. 2
    2
      main.py
  2. 9
    11
      trainer.py

+ 2
- 2
main.py View File

@@ -13,7 +13,7 @@ def get_params():
args.add_argument("-seed", "--seed", default=7, type=int)
args.add_argument("-K", "--K", default=3, type=int) #NUMBER OF SHOT

args.add_argument("-dim", "--embed_dim", default=128, type=int)
args.add_argument("-dim", "--embed_dim", default=256, type=int)
args.add_argument("-bs", "--batch_size", default=1024, type=int)
args.add_argument("-lr", "--learning_rate", default=0.001, type=float)

@@ -26,7 +26,7 @@ def get_params():
args.add_argument("-p", "--dropout_p", default=0.5, type=float)

args.add_argument("-gpu", "--device", default=0, type=int)
args.add_argument("--number_of_neg",default=1,type=int)
args.add_argument("--number_of_neg",default=2,type=int)




+ 9
- 11
trainer.py View File

@@ -11,7 +11,7 @@ import gc


class Trainer:
def __init__(self, data_loaders, itemnum, parameter, user_train_num, user_train):
def __init__(self, data_loaders, itemnum, parameter,user_train_num,user_train):
# print(user_train)
self.parameter = parameter
# data loader
@@ -124,14 +124,12 @@ class Trainer:
item_arg_all = np.argmax(variance_candidates_all, axis=1)

# negitems = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index]]] for index,user in enumerate(users)}
negitems0 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for
index, user in enumerate(users)}
negitems1 = {user: self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for
index, user in enumerate(users)}
negitems0 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for index,user in enumerate(users)}
negitems1 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for index,user in enumerate(users)}

###############################
for i in users:
self.final_negative_items[i] = [negitems0[i], negitems1[i]]
self.final_negative_items[i] = [negitems0[i],negitems1[i]]
###############################

# update Mu
@@ -149,7 +147,7 @@ class Trainer:
negitems_mu_candidates[i] = self.Mu_idx[i]

negitems_mu = {}
negitems_mu = {user: self.candidate_cur[user][negitems_mu_candidates[user]] for user in users}
negitems_mu = {user:self.candidate_cur[user][negitems_mu_candidates[user]] for user in users}

task = np.array(train_task[2])
task = np.tile(task, reps=(1, self.S1 * (1 + self.S2_div_S1), 1))
@@ -169,7 +167,7 @@ class Trainer:
print("nan happend__2 in ratings_mu_candidates")
ratings_mu_candidates = self.MetaTL.fast_forward(task, users)
ratings_mu_candidates = ratings_mu_candidates / self.temperature
ratings_mu_candidates = ratings_mu_candidates + 100
ratings_mu_candidates = ratings_mu_candidates + 10
ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape(
np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1])

@@ -181,7 +179,7 @@ class Trainer:
else:
user_set.add(i)
cache_arg = np.random.choice(self.S1 * (1 + self.S2_div_S1), self.S1,
p=ratings_mu_candidates[cnt], replace=False)
p=ratings_mu_candidates[cnt], replace=False)
self.Mu_idx[i] = np.array(self.Mu_idx[i])[cache_arg].tolist()
cnt += 1

@@ -404,7 +402,7 @@ class Trainer:
"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
temp['Hits@1']))
with open('results.txt', 'a') as f:
with open('results2.txt', 'a') as f:
f.writelines(
"TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
@@ -415,7 +413,7 @@ class Trainer:
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
temp['Hits@1']))
with open("results.txt", 'a') as f:
with open("results2.txt", 'a') as f:
f.writelines(
"VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],

Loading…
Cancel
Save