Extend the MeLU code to support different meta-learning algorithms and hyperparameters
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hyper_main.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from hyper_tunning import load_data
  2. import os
  3. from ray.tune.schedulers import ASHAScheduler
  4. from ray.tune import CLIReporter
  5. from ray import tune
  6. from functools import partial
  7. from hyper_tunning import train_melu
  8. import numpy as np
  9. def main(num_samples, max_num_epochs=20, gpus_per_trial=2):
  10. data_dir = os.path.abspath("/media/external_10TB/10TB/maheri/define_task_melu_data")
  11. load_data(data_dir)
  12. config = {
  13. # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  14. # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
  15. # "lr": tune.loguniform(1e-4, 1e-1),
  16. # "batch_size": tune.choice([2, 4, 8, 16])
  17. "transformer": tune.choice(['kronoker']),
  18. "meta_algo": tune.choice(['gbml']),
  19. "first_order": tune.choice([False]),
  20. "adapt_transform": tune.choice([True, False]),
  21. # "local_lr":tune.choice([5e-6,5e-4,5e-3]),
  22. # "lr":tune.choice([5e-5,5e-4]),
  23. "local_lr": tune.loguniform(5e-6, 5e-3),
  24. "lr": tune.loguniform(5e-5, 5e-3),
  25. "batch_size": tune.choice([16, 32, 64]),
  26. "inner": tune.choice([7, 5, 4, 3, 1]),
  27. "test_state": tune.choice(["user_and_item_cold_state"]),
  28. "embedding_dim": tune.choice([16, 32, 64]),
  29. "first_fc_hidden_dim": tune.choice([32, 64, 128]),
  30. "second_fc_hidden_dim": tune.choice([32, 64]),
  31. 'cluster_h1_dim': tune.choice([256, 128, 64]),
  32. 'cluster_h2_dim': tune.choice([128, 64, 32]),
  33. 'cluster_final_dim': tune.choice([64, 32]),
  34. 'cluster_dropout_rate': tune.choice([0, 0.01, 0.1]),
  35. 'cluster_k': tune.choice([3, 5, 7, 9, 11]),
  36. 'temperature': tune.choice([0.1, 0.5, 1.0, 2.0, 10.0]),
  37. 'trainer_dropout_rate': tune.choice([0, 0.01, 0.1]),
  38. }
  39. scheduler = ASHAScheduler(
  40. metric="loss",
  41. mode="min",
  42. max_t=30,
  43. grace_period=10,
  44. reduction_factor=2)
  45. reporter = CLIReporter(
  46. # parameter_columns=["l1", "l2", "lr", "batch_size"],
  47. metric_columns=["loss", "ndcg1", "ndcg3", "training_iteration"])
  48. result = tune.run(
  49. partial(train_melu, data_dir=data_dir),
  50. resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
  51. config=config,
  52. num_samples=num_samples,
  53. scheduler=scheduler,
  54. progress_reporter=reporter,
  55. log_to_file=True,
  56. # resume=True,
  57. local_dir="./hyper_tunning_all_cold",
  58. name="melu_all_cold_clustered",
  59. )
  60. best_trial = result.get_best_trial("loss", "min", "last")
  61. print("Best trial config: {}".format(best_trial.config))
  62. print("Best trial final validation loss: {}".format(
  63. best_trial.last_result["loss"]))
  64. print("Best trial final validation ndcg1: {}".format(
  65. best_trial.last_result["ndcg1"]))
  66. print("Best trial final validation ndcg3: {}".format(
  67. best_trial.last_result["ndcg3"]))
  68. #
  69. print("=======================================================")
  70. print(result.results_df)
  71. print("=======================================================\n")
  72. # best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
  73. # device = "cpu"
  74. # if torch.cuda.is_available():
  75. # device = "cuda:0"
  76. # if gpus_per_trial > 1:
  77. # best_trained_model = nn.DataParallel(best_trained_model)
  78. # best_trained_model.to(device)
  79. #
  80. # best_checkpoint_dir = best_trial.checkpoint.value
  81. # model_state, optimizer_state = torch.load(os.path.join(
  82. # best_checkpoint_dir, "checkpoint"))
  83. # best_trained_model.load_state_dict(model_state)
  84. #
  85. # test_acc = test_accuracy(best_trained_model, device)
  86. # print("Best trial test set accuracy: {}".format(test_acc))
  87. if __name__ == "__main__":
  88. # You can change the number of GPUs per trial here:
  89. main(num_samples=150, max_num_epochs=25, gpus_per_trial=1)