Meta-learning approach for solving the cold-start problem
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 2.1KB

(47 lines)
  1. import os
  2. import torch
  3. import pickle
  4. from MeLU import MeLU
  5. from options import config
  6. from model_training import training
  7. from data_generation import generate
  8. from evidence_candidate import selection
  9. if __name__ == "__main__":
  10. # master_path= "./ml"
  11. master_path = "/media/external_3TB/3TB/rafie/maheri/melr"
  12. # master_path = "/media/external_10TB/10TB/pourmand/ml"
  13. if not os.path.exists("{}/".format(master_path)):
  14. print("inajm")
  15. os.mkdir("{}/".format(master_path))
  16. # preparing dataset. It needs about 22GB of your hard disk space.
  17. generate(master_path)
  18. # # training model.
  19. # melu = MeLU(config)
  20. # model_filename = "{}/models.pkl".format(master_path)
  21. # if not os.path.exists(model_filename):
  22. # # Load training dataset.
  23. # training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
  24. # supp_xs_s = []
  25. # supp_ys_s = []
  26. # query_xs_s = []
  27. # query_ys_s = []
  28. # for idx in range(training_set_size):
  29. # supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
  30. # supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
  31. # query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb")))
  32. # query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb")))
  33. # total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  34. # del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  35. # training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename)
  36. # else:
  37. # trained_state_dict = torch.load(model_filename)
  38. # melu.load_state_dict(trained_state_dict)
  39. #
  40. # # selecting evidence candidates.
  41. # evidence_candidate_list = selection(melu, master_path, config['num_candidate'])
  42. # for movie, score in evidence_candidate_list:
  43. # print(movie, score)