Sequential Recommendation for Cold-start Users with Meta Transitional Learning

hyper_main.py 4.9KB

from ray.tune.schedulers import ASHAScheduler
from ray.tune import CLIReporter
from ray import tune
from functools import partial
from hyper_tunning import train_metatl
import argparse
import numpy as np
import torch
import random
from trainer import *
from utils import *
from sampler import *
import copy


def get_params():
    args = argparse.ArgumentParser()
    args.add_argument("-data", "--dataset", default="electronics", type=str)
    args.add_argument("-seed", "--seed", default=None, type=int)
    args.add_argument("-K", "--K", default=3, type=int)  # number of shots (K)
    # args.add_argument("-dim", "--embed_dim", default=100, type=int)
    args.add_argument("-bs", "--batch_size", default=1024, type=int)
    # args.add_argument("-lr", "--learning_rate", default=0.001, type=float)
    args.add_argument("-epo", "--epoch", default=1000, type=int)
    # args.add_argument("-prt_epo", "--print_epoch", default=100, type=int)
    # args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int)
    # args.add_argument("-b", "--beta", default=5, type=float)
    # args.add_argument("-m", "--margin", default=1, type=float)
    # args.add_argument("-p", "--dropout_p", default=0.5, type=float)
    # args.add_argument("-gpu", "--device", default=1, type=int)
    args = args.parse_args()

    params = {}
    for k, v in vars(args).items():
        params[k] = v
    params['device'] = torch.device('cuda:0')

    return params, args


def main(num_samples, gpus_per_trial=2):
    params, args = get_params()

    # Fix all random seeds for reproducibility when a seed is supplied.
    if params['seed'] is not None:
        SEED = params['seed']
        torch.manual_seed(SEED)
        torch.cuda.manual_seed(SEED)
        torch.backends.cudnn.deterministic = True
        np.random.seed(SEED)
        random.seed(SEED)

    user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K)

    batch_size = params['batch_size']
    # sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=batch_size, maxlen=args.K, n_workers=1)
    # sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
    # sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

    # Search space for Ray Tune; fixed objects (item count, parsed CLI params)
    # are passed through the same config dict to every trial.
    config = {
        # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        # "lr": tune.loguniform(1e-4, 1e-1),
        # "batch_size": tune.choice([2, 4, 8, 16]),
        "embed_dim": tune.choice([50, 75, 100, 125, 150, 200]),
        # "batch_size": tune.choice([128, 256, 512, 1024, 2048]),
        "learning_rate": tune.choice([0.1, 0.01, 0.005, 0.001, 0.0001]),
        "beta": tune.choice([0.1, 1, 5, 10]),
        "margin": tune.choice([1]),
        # "sampler": sampler,
        # "sampler_test": sampler_test,
        # "sampler_valid": sampler_valid,
        "itemnum": itemnum,
        "params": params,
    }

    # ASHA keeps every trial for at least `grace_period` training iterations,
    # then successively prunes underperformers, maximizing validation MRR.
    scheduler = ASHAScheduler(
        metric="MRR",
        mode="max",
        max_t=params['epoch'],
        grace_period=200,
        reduction_factor=2)

    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["MRR", "NDCG10", "NDCG5", "NDCG1",
                        "Hits10", "Hits5", "Hits1", "training_iteration"])

    result = tune.run(
        train_metatl,
        resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        log_to_file=True,
        # resume=True,
        local_dir="./ray_local_dir",
        name="metatl_rnn1",
    )

    best_trial = result.get_best_trial("MRR", "max", "last")
    print("Best trial config: {}".format(best_trial.config))
    # "loss" is not among the metric columns reported above, so it is left
    # commented out; the ranking metrics are printed instead.
    # print("Best trial final validation loss: {}".format(
    #     best_trial.last_result["loss"]))
    print("Best trial final validation MRR: {}".format(
        best_trial.last_result["MRR"]))
    print("Best trial final validation NDCG@1: {}".format(
        best_trial.last_result["NDCG1"]))

    print("=======================================================")
    print(result.results_df)
    print("=======================================================\n")

    # (Unused) example of reloading the best model from its checkpoint directory.
    # best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    # device = "cpu"
    # if torch.cuda.is_available():
    #     device = "cuda:0"
    #     if gpus_per_trial > 1:
    #         best_trained_model = nn.DataParallel(best_trained_model)
    # best_trained_model.to(device)
    #
    # best_checkpoint_dir = best_trial.checkpoint.value
    # model_state, optimizer_state = torch.load(os.path.join(
    #     best_checkpoint_dir, "checkpoint"))
    # best_trained_model.load_state_dict(model_state)
    #
    # test_acc = test_accuracy(best_trained_model, device)
    # print("Best trial test set accuracy: {}".format(test_acc))


if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=150, gpus_per_trial=1)
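
The trainable passed to tune.run, hyper_tunning.train_metatl, is not shown here. As a rough orientation only, the sketch below shows the interface this script assumes of it: a function that receives the config dict, trains MetaTL with the sampled hyperparameters, and reports the ranking metrics under the exact names used by the CLIReporter columns and the ASHA scheduler. Everything inside the loop is a placeholder, not the project's actual implementation.

# Hypothetical sketch of the train_metatl trainable expected by hyper_main.py.
# The real implementation lives in hyper_tunning.py; only the config keys and
# the reported metric names are taken from the script above.
from ray import tune

def train_metatl(config):
    params = config["params"]            # parsed CLI arguments forwarded by hyper_main.py
    itemnum = config["itemnum"]
    embed_dim = config["embed_dim"]
    learning_rate = config["learning_rate"]
    beta, margin = config["beta"], config["margin"]

    for epoch in range(params["epoch"]):
        # ... build the MetaTL model, run one meta-training epoch, and evaluate
        #     on the validation users here (omitted in this sketch) ...
        metrics = {"MRR": 0.0, "NDCG10": 0.0, "NDCG5": 0.0, "NDCG1": 0.0,
                   "Hits10": 0.0, "Hits5": 0.0, "Hits1": 0.0}  # placeholder values

        # The keyword names must match CLIReporter's metric_columns and the
        # scheduler's metric="MRR"; ASHA uses these per-epoch reports to prune.
        tune.report(**metrics)

Reporting once per training iteration is what allows the ASHA scheduler to terminate weak configurations after the 200-iteration grace period instead of running every trial for the full epoch budget.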
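
Assuming the preprocessed interaction files expected by data_load are in place, a search over the 150 sampled configurations set in __main__ is launched with the flags defined in get_params, for example:

python hyper_main.py --dataset electronics --K 3 --seed 42

Trial logs and results are written under ./ray_local_dir/metatl_rnn1, and the best configuration by final MRR is printed when the run finishes.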