
hyper parameter tuning including neg_size

RNN
mohamad maheri 2 years ago
commit 372e75db43
2 changed files with 14 additions and 6 deletions
  1. hyper_main.py  +1  -1
  2. main.py  +13  -5

hyper_main.py  +1  -1

  scheduler=scheduler,
  progress_reporter=reporter,
  log_to_file=True,
- # resume=True,
+ resume=True,
  local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir",
  name="bpr_rnn",
  )
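The one-line change here un-comments resume=True, so tune.run continues the experiment already stored under local_dir/name ("bpr_rnn") instead of starting a fresh sweep. Below is a minimal sketch of how these keyword arguments fit together in a Ray Tune call; the trainable function and search space are placeholders and not part of this commit, only the scheduler/reporter/log_to_file/resume/local_dir/name arguments mirror hyper_main.py.

# Minimal sketch, assuming a standard Ray Tune setup and that a previous run
# already exists under local_dir/name (resume=True picks up that state).
# train_fn and the config below are illustrative placeholders.
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

def train_fn(config):
    # Placeholder trainable: a real one would train the recommender and
    # report a validation metric each epoch via tune.report(...).
    tune.report(hit_rate=config["lr"])

scheduler = ASHAScheduler(metric="hit_rate", mode="max")
reporter = CLIReporter(metric_columns=["hit_rate"])

analysis = tune.run(
    train_fn,
    config={
        "lr": tune.loguniform(1e-4, 1e-1),         # hypothetical search space
        "number_of_neg": tune.choice([1, 5, 10]),  # hypothetical search space
    },
    scheduler=scheduler,
    progress_reporter=reporter,
    log_to_file=True,
    resume=True,  # reuse the experiment state saved by an earlier run
    local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir",
    name="bpr_rnn",
)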

main.py  +13  -5

      args.add_argument("-m", "--margin", default=1, type=float)
      args.add_argument("-p", "--dropout_p", default=0.5, type=float)

-     args.add_argument("-gpu", "--device", default=1, type=int)
+     args.add_argument("-gpu", "--device", default=0, type=int)
+     args.add_argument("--number_of_neg", default=5, type=int)

      args = args.parse_args()
      for k, v in vars(args).items():
          params[k] = v

-     params['device'] = torch.device('cuda:'+str(args.device))
+     # params['device'] = torch.device('cuda:'+str(args.device))
+     params['device'] = args.device
      # params['device'] = torch.device('cpu')

      return params, args
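In get_params the commit stops constructing a torch.device up front and stores the raw GPU index instead (params['device'] = args.device), keeping the old line as a comment. The commit message does not say why, but a plain int serializes cleanly if params is handed to other processes (for example Ray Tune trial workers), leaving the consumer to decide between GPU and CPU. Under that assumption, a small sketch of how downstream code such as the Trainer could turn the index back into a device:

# Sketch under the assumption above; resolve_device is not code from this
# commit, just one way to map the stored integer index to a torch.device.
import torch

def resolve_device(device_index: int) -> torch.device:
    # Fall back to CPU when CUDA is unavailable or a negative index is given.
    if torch.cuda.is_available() and device_index >= 0:
        return torch.device(f"cuda:{device_index}")
    return torch.device("cpu")

# e.g. inside the Trainer:
#   self.device = resolve_device(params["device"])
#   model.to(self.device)

The remaining main.py changes are in the __main__ block below.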


  if __name__ == '__main__':
+     print(torch.cuda.is_available())
      params, args = get_params()

      if params['seed'] is not None:
          np.random.seed(SEED)
          random.seed(SEED)

+     print("===============", torch.cuda.device_count(), "=======")
      user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K)

-     sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3)
+     sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3, params=params)

      sampler_test = DataLoader(user_input_test, user_test, itemnum, params)

      sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

-     trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params)
-     print("===============", torch.cuda.device_count(), "=======")
+     trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params)
+     print("===============", torch.cuda.device_count(), "=======")
      trainer.train()

      sampler.close()
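The new --number_of_neg argument (default 5) sets how many negative items are drawn per positive, and params is now passed into WarpSampler, presumably so the sampler can read that value. WarpSampler's internals are not part of this commit, so the following is only an illustrative sketch of drawing several negatives per positive for a BPR-style ranking loss; the function name and shapes are assumptions, not the project's code.

# Illustrative sketch only, not WarpSampler's actual implementation.
import numpy as np

def sample_negatives(positive_items, itemnum, number_of_neg, rng=None):
    # Draw number_of_neg random item ids per positive, rejecting any id that
    # appears among the given positives (item ids assumed to run 1..itemnum).
    rng = rng or np.random.default_rng()
    forbidden = set(positive_items)
    negatives = []
    for _ in positive_items:
        row = []
        while len(row) < number_of_neg:
            candidate = int(rng.integers(1, itemnum + 1))
            if candidate not in forbidden:
                row.append(candidate)
        negatives.append(row)
    return np.asarray(negatives)  # shape: (num_positives, number_of_neg)

# With the default of 5 negatives per positive:
#   sample_negatives([3, 17, 42], itemnum=1000, number_of_neg=5).shape == (3, 5)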
