Meta-learning approach for solving the cold-start problem.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 3.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import torch
  3. import pickle
  4. from MeLU import MeLU
  5. from options import config
  6. from model_training import training
  7. from model_test import test
  8. from data_generation import generate
  9. from evidence_candidate import selection
  10. if __name__ == "__main__":
  11. # master_path= "./ml"
  12. master_path = "/media/external_10TB/10TB/maheri/melu_data"
  13. if not os.path.exists("{}/".format(master_path)):
  14. print("generating data phase started")
  15. os.mkdir("{}/".format(master_path))
  16. # preparing dataset. It needs about 22GB of your hard disk space.
  17. generate(master_path)
  18. # training model.
  19. melu = MeLU(config)
  20. model_filename = "{}/models2.pkl".format(master_path)
  21. if not os.path.exists(model_filename):
  22. print("training phase started")
  23. # Load training dataset.
  24. training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
  25. supp_xs_s = []
  26. supp_ys_s = []
  27. query_xs_s = []
  28. query_ys_s = []
  29. for idx in range(training_set_size):
  30. supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
  31. supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
  32. query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb")))
  33. query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb")))
  34. total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  35. del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  36. training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename)
  37. else:
  38. trained_state_dict = torch.load(model_filename)
  39. melu.load_state_dict(trained_state_dict)
  40. print("training finished")
  41. # selecting evidence candidates.
  42. # evidence_candidate_list = selection(melu, master_path, config['num_candidate'])
  43. # for movie, score in evidence_candidate_list:
  44. # print(movie, score)
  45. print("start of test phase")
  46. test_state = 'user_and_item_cold_state'
  47. test_dataset = None
  48. test_set_size = int(len(os.listdir("{}/{}".format(master_path,test_state))) / 4)
  49. supp_xs_s = []
  50. supp_ys_s = []
  51. query_xs_s = []
  52. query_ys_s = []
  53. for idx in range(test_set_size):
  54. supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path,test_state, idx), "rb")))
  55. supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path,test_state, idx), "rb")))
  56. query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path,test_state, idx), "rb")))
  57. query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path,test_state, idx), "rb")))
  58. test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  59. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  60. test(melu, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'])