
Hyperparameter tuning, including neg_size

Branch: RNN
Author: mohamad maheri, 2 years ago
Parent commit: 372e75db43
2 changed files with 14 additions and 6 deletions:
  1. hyper_main.py  (+1, -1)
  2. main.py  (+13, -5)

hyper_main.py  (+1, -1)

@@ -105,7 +105,7 @@ def main(num_samples, gpus_per_trial=2):
         scheduler=scheduler,
         progress_reporter=reporter,
         log_to_file=True,
-        # resume=True,
+        resume=True,
         local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir",
         name="bpr_rnn",
     )
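
A note on the one-line change above: with Ray Tune's tune.run API, resume=True restores the experiment checkpointed under local_dir/name instead of launching a fresh sweep (and typically errors if there is nothing to restore). A minimal sketch of such a call with a stand-in trainable and shortened paths (train_fn and the config are illustrative, not from this repo):

from ray import tune

def train_fn(config):
    # Stand-in trainable; hyper_main.py defines the real one.
    for step in range(10):
        tune.report(loss=1.0 / (step + 1))

analysis = tune.run(
    train_fn,
    config={"lr": tune.loguniform(1e-4, 1e-1)},
    local_dir="./ray_local_dir",  # experiment state lives under local_dir/name
    name="bpr_rnn",
    resume=True,  # pick up the existing "bpr_rnn" experiment where it left off
)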

main.py  (+13, -5)

@@ -24,7 +24,9 @@ def get_params():
     args.add_argument("-m", "--margin", default=1, type=float)
     args.add_argument("-p", "--dropout_p", default=0.5, type=float)
 
-    args.add_argument("-gpu", "--device", default=1, type=int)
+    args.add_argument("-gpu", "--device", default=0, type=int)
+    args.add_argument("--number_of_neg", default=5, type=int)
+
 
 
     args = args.parse_args()
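
The new --number_of_neg flag is presumably the neg_size from the commit message: how many negative items the sampler draws per positive interaction (default 5). It is copied into params by the vars(args) loop in the next hunk and, via this commit, reaches WarpSampler; a sampling sketch follows the last hunk. Hypothetical invocation:

python main.py -gpu 0 --number_of_neg 10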
@@ -32,12 +34,16 @@ def get_params():
     for k, v in vars(args).items():
         params[k] = v
 
-    params['device'] = torch.device('cuda:'+str(args.device))
 
+
+    # params['device'] = torch.device('cuda:'+str(args.device))
+    params['device'] = args.device
+    # params['device'] = torch.device('cpu')
 
     return params, args
 
 if __name__ == '__main__':
+    print(torch.cuda.is_available())
     params, args = get_params()
 
     if params['seed'] is not None:
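
A note on the device change above: storing the bare integer index in params['device'] instead of a constructed torch.device presumably defers all CUDA handling to the process that actually uses the GPU, which is convenient when the same params dict is shipped to Ray Tune trials and to the sampler's worker processes. A minimal sketch of how the consuming side might rebuild the device (resolve_device is an illustrative helper, not from this repo):

import torch

def resolve_device(params):
    # Rebuild a torch.device from the integer index kept in params['device'],
    # falling back to CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        return torch.device('cuda:' + str(params['device']))
    return torch.device('cpu')

print(resolve_device({'device': 0}))  # cuda:0 with a visible GPU, otherwise cpu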
@@ -48,17 +54,19 @@ if __name__ == '__main__':
         np.random.seed(SEED)
         random.seed(SEED)
 
+    print("===============", torch.cuda.device_count(), "=======")
     user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K)
 
-    sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3)
+    sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3, params=params)
 
     sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
 
     sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)
 
-    trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params)
-    print("===============", torch.cuda.device_count(), "=======")
+
+    trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params)
+    print("===============", torch.cuda.device_count(), "=======")
     trainer.train()
 
     sampler.close()
 
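
Threading params into WarpSampler above presumably lets its worker processes read params['number_of_neg'] and emit several negatives per positive instead of a single one, matching the new CLI flag. A minimal sketch of that sampling step (sample_negatives and ts are illustrative names; the repo's actual sampler code is not part of this diff):

import numpy as np

def sample_negatives(itemnum, ts, number_of_neg, rng=np.random):
    # Draw number_of_neg item ids from [1, itemnum], rejecting anything in
    # ts (the user's known positives) so negatives are true negatives.
    negs = []
    while len(negs) < number_of_neg:
        candidate = rng.randint(1, itemnum + 1)
        if candidate not in ts:
            negs.append(candidate)
    return negs

print(sample_negatives(100, {3, 7}, 5))  # e.g. [41, 88, 12, 5, 67]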
