Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hyper_main.py (5.0 KB)
  1. from ray.tune.schedulers import ASHAScheduler
  2. from ray.tune import CLIReporter
  3. from ray import tune
  4. from functools import partial
  5. from hyper_tunning import train_metatl
  6. import argparse
  7. import numpy as np
  8. import torch
  9. import random
  10. from trainer import *
  11. from utils import *
  12. from sampler import *
  13. import copy
  14. def get_params():
  15. args = argparse.ArgumentParser()
  16. args.add_argument("-data", "--dataset", default="electronics", type=str)
  17. args.add_argument("-seed", "--seed", default=None, type=int)
  18. args.add_argument("-K", "--K", default=3, type=int) #NUMBER OF SHOT
  19. # args.add_argument("-dim", "--embed_dim", default=100, type=int)
  20. args.add_argument("-bs", "--batch_size", default=1024, type=int)
  21. # args.add_argument("-lr", "--learning_rate", default=0.001, type=float)
  22. args.add_argument("-epo", "--epoch", default=100000, type=int)
  23. # args.add_argument("-prt_epo", "--print_epoch", default=100, type=int)
  24. # args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int)
  25. # args.add_argument("-b", "--beta", default=5, type=float)
  26. # args.add_argument("-m", "--margin", default=1, type=float)
  27. # args.add_argument("-p", "--dropout_p", default=0.5, type=float)
  28. # args.add_argument("-gpu", "--device", default=1, type=int)
  29. args = args.parse_args()
  30. params = {}
  31. for k, v in vars(args).items():
  32. params[k] = v
  33. params['device'] = torch.device('cuda:0')
  34. return params, args
  35. def main(num_samples, gpus_per_trial=2):
  36. params, args = get_params()
  37. if params['seed'] is not None:
  38. SEED = params['seed']
  39. torch.manual_seed(SEED)
  40. torch.cuda.manual_seed(SEED)
  41. torch.backends.cudnn.deterministic = True
  42. np.random.seed(SEED)
  43. random.seed(SEED)
  44. user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K)
  45. batch_size = params['batch_size']
  46. # sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=batch_size, maxlen=args.K, n_workers=1)
  47. # sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
  48. # sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)
  49. config = {
  50. # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  51. # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  52. # "lr": tune.loguniform(1e-4, 1e-1),
  53. "embed_dim" : tune.choice([50,75,100,125,150,200,300]),
  54. # "batch_size" : tune.choice([128,256,512,1024,2048]),
  55. "learning_rate" : tune.choice([0.1,0.01,0.004,0.005,0.007,0.001,0.0001]),
  56. "beta" : tune.choice([0.05,0.1,1,4,4.5,5,5.5,6,10]),
  57. "margin" : tune.choice([1,0.9,0.8,1.1,1.2]),
  58. # "sampler":sampler,
  59. # "sampler_test":sampler_test,
  60. # "sampler_valid":sampler_valid,
  61. "itemnum":itemnum,
  62. "params":params,
  63. }
  64. scheduler = ASHAScheduler(
  65. metric="MRR",
  66. mode="max",
  67. max_t=params['epoch'],
  68. grace_period=200,
  69. reduction_factor=2)
  70. reporter = CLIReporter(
  71. # parameter_columns=["l1", "l2", "lr", "batch_size"],
  72. metric_columns=["MRR","NDCG10","NDCG5","NDCG1","Hits10","Hits5","Hits1","training_iteration"])
  73. result = tune.run(
  74. train_metatl,
  75. resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
  76. config=config,
  77. num_samples=num_samples,
  78. scheduler=scheduler,
  79. progress_reporter=reporter,
  80. log_to_file=True,
  81. # resume=True,
  82. local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir",
  83. name="metatl_rnn1",
  84. )
  85. best_trial = result.get_best_trial("MRR", "max", "last")
  86. print("Best trial config: {}".format(best_trial.config))
  87. print("Best trial final validation loss: {}".format(
  88. best_trial.last_result["loss"]))
  89. print("Best trial final validation MRR: {}".format(
  90. best_trial.last_result["MRR"]))
  91. print("Best trial final validation NDCG@1: {}".format(
  92. best_trial.last_result["NDCG@1"]))
  93. #
  94. print("=======================================================")
  95. print(result.results_df)
  96. print("=======================================================\n")
  97. # best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
  98. # device = "cpu"
  99. # if torch.cuda.is_available():
  100. # device = "cuda:0"
  101. # if gpus_per_trial > 1:
  102. # best_trained_model = nn.DataParallel(best_trained_model)
  103. # best_trained_model.to(device)
  104. #
  105. # best_checkpoint_dir = best_trial.checkpoint.value
  106. # model_state, optimizer_state = torch.load(os.path.join(
  107. # best_checkpoint_dir, "checkpoint"))
  108. # best_trained_model.load_state_dict(model_state)
  109. #
  110. # test_acc = test_accuracy(best_trained_model, device)
  111. # print("Best trial test set accuracy: {}".format(test_acc))
  112. if __name__ == "__main__":
  113. # You can change the number of GPUs per trial here:
  114. main(num_samples=150, gpus_per_trial=1)