|
|
@@ -24,7 +24,9 @@ def get_params(): |
|
|
|
args.add_argument("-m", "--margin", default=1, type=float) |
|
|
|
args.add_argument("-p", "--dropout_p", default=0.5, type=float) |
|
|
|
|
|
|
|
args.add_argument("-gpu", "--device", default=1, type=int) |
|
|
|
args.add_argument("-gpu", "--device", default=0, type=int) |
|
|
|
args.add_argument("--number_of_neg",default=5,type=int) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
args = args.parse_args() |
|
|
@@ -32,12 +34,16 @@ def get_params(): |
|
|
|
for k, v in vars(args).items(): |
|
|
|
params[k] = v |
|
|
|
|
|
|
|
params['device'] = torch.device('cuda:'+str(args.device)) |
|
|
|
|
|
|
|
|
|
|
|
# params['device'] = torch.device('cuda:'+str(args.device)) |
|
|
|
params['device'] = args.device |
|
|
|
# params['device'] = torch.device('cpu') |
|
|
|
|
|
|
|
return params, args |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
print(torch.cuda.is_available()) |
|
|
|
params, args = get_params() |
|
|
|
|
|
|
|
if params['seed'] is not None: |
|
|
@@ -48,17 +54,19 @@ if __name__ == '__main__': |
|
|
|
np.random.seed(SEED) |
|
|
|
random.seed(SEED) |
|
|
|
|
|
|
|
print("===============", torch.cuda.device_count(), "=======") |
|
|
|
user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K) |
|
|
|
|
|
|
|
sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3) |
|
|
|
sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=args.batch_size, maxlen=args.K, n_workers=3,params=params) |
|
|
|
|
|
|
|
sampler_test = DataLoader(user_input_test, user_test, itemnum, params) |
|
|
|
|
|
|
|
sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params) |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params) |
|
|
|
print("===============", torch.cuda.device_count(), "=======") |
|
|
|
|
|
|
|
trainer = Trainer([sampler, sampler_valid, sampler_test], itemnum, params) |
|
|
|
print("===============", torch.cuda.device_count(), "=======") |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
sampler.close() |