Sequential Recommendation for Cold-Start Users with Meta Transitional Learning

hyper_tunning.py (2.8 KB)

import os
import random
import pickle
import gc

import numpy as np
import torch
import torch.nn as nn
from ray import tune

from trainer import Trainer
from utils import *
from sampler import *


def train_metatl(conf, checkpoint_dir=None):
    # Fix all random seeds so that every trial is reproducible.
    SEED = conf["params"]['seed']
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    np.random.seed(SEED)
    random.seed(SEED)

    # Propagate the tuned hyperparameters into the base parameter dict.
    params = conf['params']
    params['batch_size'] = conf['batch_size']
    params['number_of_neg'] = conf['number_of_neg']

    # Load the dataset and build the train/valid/test samplers.
    user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(params['dataset'], params['K'])
    sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=params['batch_size'], maxlen=params['K'], n_workers=1, params=params)
    sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
    sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

    # Collect the hyperparameters actually consumed by the Trainer.
    ps = {
        "batch_size": conf['batch_size'],
        "learning_rate": conf['learning_rate'],
        "epoch": conf["params"]['epoch'],
        "beta": conf['beta'],
        "embed_dim": conf['embed_dim'],
        "margin": conf['margin'],
        "K": conf["params"]['K'],
        "number_of_neg": conf["number_of_neg"],
        "loss_function": conf["loss_function"],
        "eval_epoch": conf["eval_epoch"],
        "device": params['device']
    }

    trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps)

    # Restore model and optimizer state when Ray Tune resumes a trial.
    if checkpoint_dir:
        print("===================== using checkpoint =====================")
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        trainer.MetaTL.load_state_dict(model_state)
        trainer.optimizer.load_state_dict(optimizer_state)

    for epoch in range(int(ps['epoch'] / ps['eval_epoch'])):
        # Run eval_epoch training steps between two evaluations.
        for e in range(ps['eval_epoch']):
            # Sample one batch of transition tasks from the train data loader.
            train_task, curr_rel = trainer.train_data_loader.next_batch()
            loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

        # Evaluate on the validation split after each block of training steps.
        valid_data = trainer.eval(istest=False, epoch=(-1))

        # Save a checkpoint so Ray Tune can pause and resume the trial.
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((trainer.MetaTL.state_dict(), trainer.optimizer.state_dict()), path)

        # Report validation metrics back to Ray Tune.
        tune.report(
            MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"],
            Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"],
            training_iteration=epoch * ps['eval_epoch']
        )