Adapted to the MovieLens dataset.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

train_TaNP.py 9.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import os
  2. from datetime import datetime
  3. import time
  4. import numpy as np
  5. import random
  6. import argparse
  7. import pickle
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim as optim
  11. from torch.autograd import Variable
  12. import json
  13. from utils.loader import Preprocess
  14. from TaNP import Trainer
  15. from TaNP_training import training
  16. from utils import helper
  17. from eval import testing
  18. parser = argparse.ArgumentParser()
  19. # parser.add_argument('--data_dir', type=str, default='data/lastfm_20')#1
  20. # parser.add_argument('--model_save_dir', type=str, default='save_model_dir')#1
  21. parser.add_argument('--data_dir', type=str, default='/media/external_10TB/10TB/maheri/melu_data')#1
  22. parser.add_argument('--model_save_dir', type=str, default='/media/external_10TB/10TB/maheri/tanp_data/tanp_models')#1
  23. parser.add_argument('--id', type=str, default='1', help='used for save hyper-parameters.')#1
  24. parser.add_argument('--first_embedding_dim', type=int, default=32, help='Embedding dimension for item and user.')#1
  25. parser.add_argument('--second_embedding_dim', type=int, default=16, help='Embedding dimension for item and user.')#1
  26. parser.add_argument('--z1_dim', type=int, default=32, help='The dimension of z1 in latent path.')
  27. parser.add_argument('--z2_dim', type=int, default=32, help='The dimension of z2 in latent path.')
  28. parser.add_argument('--z_dim', type=int, default=32, help='The dimension of z in latent path.')
  29. parser.add_argument('--enc_h1_dim', type=int, default=64, help='The hidden first dimension of encoder.')
  30. parser.add_argument('--enc_h2_dim', type=int, default=64, help='The hidden second dimension of encoder.')
  31. parser.add_argument('--taskenc_h1_dim', type=int, default=128, help='The hidden first dimension of task encoder.')
  32. parser.add_argument('--taskenc_h2_dim', type=int, default=64, help='The hidden second dimension of task encoder.')
  33. parser.add_argument('--taskenc_final_dim', type=int, default=64, help='The hidden second dimension of task encoder.')
  34. parser.add_argument('--clusters_k', type=int, default=7, help='Cluster numbers of tasks.')
  35. parser.add_argument('--temperature', type=float, default=1.0, help='used for student-t distribution.')
  36. parser.add_argument('--lambda', type=float, default=0.1, help='used to balance the clustering loss and NP loss.')
  37. parser.add_argument('--dec_h1_dim', type=int, default=128, help='The hidden first dimension of encoder.')
  38. parser.add_argument('--dec_h2_dim', type=int, default=128, help='The hidden second dimension of encoder.')
  39. parser.add_argument('--dec_h3_dim', type=int, default=128, help='The hidden third dimension of encoder.')
  40. # used for movie datasets
  41. parser.add_argument('--num_gender', type=int, default=2, help='User information.')#1
  42. parser.add_argument('--num_age', type=int, default=7, help='User information.')#1
  43. parser.add_argument('--num_occupation', type=int, default=21, help='User information.')#1
  44. parser.add_argument('--num_zipcode', type=int, default=3402, help='User information.')#1
  45. parser.add_argument('--num_rate', type=int, default=6, help='Item information.')#1
  46. parser.add_argument('--num_genre', type=int, default=25, help='Item information.')#1
  47. parser.add_argument('--num_director', type=int, default=2186, help='Item information.')#1
  48. parser.add_argument('--num_actor', type=int, default=8030, help='Item information.')#1
  49. parser.add_argument('--dropout_rate', type=float, default=0, help='used in encoder and decoder.')
  50. parser.add_argument('--lr', type=float, default=1e-4, help='Applies to SGD and Adagrad.')#1
  51. parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
  52. parser.add_argument('--num_epoch', type=int, default=150)#1
  53. parser.add_argument('--batch_size', type=int, default=32)#1
  54. parser.add_argument('--train_ratio', type=float, default=0.7, help='Warm user ratio for training.')#1
  55. parser.add_argument('--valid_ratio', type=float, default=0.1, help='Cold user ratio for validation.')#1
  56. parser.add_argument('--seed', type=int, default=2020)#1
  57. parser.add_argument('--save', type=int, default=0)#1
  58. parser.add_argument('--use_cuda', type=bool, default=torch.cuda.is_available())#1
  59. parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.')#1
  60. parser.add_argument('--support_size', type=int, default=20)#1
  61. parser.add_argument('--query_size', type=int, default=10)#1
  62. parser.add_argument('--max_len', type=int, default=200, help='The max length of interactions for each user.')
  63. parser.add_argument('--context_min', type=int, default=20, help='Minimum size of context range.')
  64. # change for Movie lens
  65. parser.add_argument('--embedding_dim', type=int, default=32, help='embedding dimension for each item/user feature of Movie lens')
  66. parser.add_argument('--first_fc_hidden_dim', type=int, default=64, help='embedding dimension for each item/user feature of Movie lens')
  67. parser.add_argument('--second_fc_hidden_dim', type=int, default=64, help='embedding dimension for each item/user feature of Movie lens')
  68. args = parser.parse_args()
  69. def seed_everything(seed=1023):
  70. random.seed(seed)
  71. torch.manual_seed(seed)
  72. torch.cuda.manual_seed_all(seed)
  73. np.random.seed(seed)
  74. os.environ['PYTHONHASHSEED'] = str(seed)
  75. torch.backends.cudnn.deterministic = True
  76. torch.backends.cudnn.benchmark = False
  77. seed = args.seed
  78. seed_everything(seed)
  79. if args.cpu:
  80. args.use_cuda = False
  81. elif args.use_cuda:
  82. torch.cuda.manual_seed(args.seed)
  83. opt = vars(args)
  84. # print model info
  85. helper.print_config(opt)
  86. helper.ensure_dir(opt["model_save_dir"], verbose=True)
  87. # save model config
  88. helper.save_config(opt, opt["model_save_dir"] + "/" +opt["id"] + '.config', verbose=True)
  89. # record training log
  90. file_logger = helper.FileLogger(opt["model_save_dir"] + '/' + opt['id'] + ".log",
  91. header="# epoch\ttrain_loss\tprecision5\tNDCG5\tMAP5\tprecision7"
  92. "\tNDCG7\tMAP7\tprecision10\tNDCG10\tMAP10")
  93. # change for Movie Lens
  94. # preprocess = Preprocess(opt)
  95. print("Preprocess is done.")
  96. print("Create model TaNP...")
  97. # opt['uf_dim'] = preprocess.uf_dim
  98. # opt['if_dim'] = preprocess.if_dim
  99. trainer = Trainer(opt)
  100. if opt['use_cuda']:
  101. trainer.cuda()
  102. model_filename = "{}/{}.pt".format(opt['model_save_dir'], opt["id"])
  103. # /4 since sup_x, sup_y, query_x, query_y
  104. # change for Movie lens
  105. # training_set_size = int(len(os.listdir("{}/{}/{}".format(opt["data_dir"], "training", "log"))) / 4)
  106. training_set_size = int(len(os.listdir("{}/{}".format(opt["data_dir"], "warm_state"))) / 4)
  107. supp_xs_s = []
  108. supp_ys_s = []
  109. query_xs_s = []
  110. query_ys_s = []
  111. for idx in range(training_set_size):
  112. # supp_xs_s.append(pickle.load(open("{}/{}/{}/supp_x_{}.pkl".format(opt["data_dir"], "training", "log", idx), "rb")))
  113. # supp_ys_s.append(pickle.load(open("{}/{}/{}/supp_y_{}.pkl".format(opt["data_dir"], "training", "log", idx), "rb")))
  114. # query_xs_s.append(pickle.load(open("{}/{}/{}/query_x_{}.pkl".format(opt["data_dir"], "training", "log", idx), "rb")))
  115. # query_ys_s.append(pickle.load(open("{}/{}/{}/query_y_{}.pkl".format(opt["data_dir"], "training", "log", idx), "rb")))
  116. supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(opt["data_dir"], "warm_state", idx), "rb")))
  117. supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(opt["data_dir"], "warm_state", idx), "rb")))
  118. query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(opt["data_dir"], "warm_state", idx), "rb")))
  119. query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(opt["data_dir"], "warm_state", idx), "rb")))
  120. train_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  121. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  122. # change for Movie lens
  123. # testing_set_size = int(len(os.listdir("{}/{}/{}".format(opt["data_dir"], "testing", "log"))) / 4)
  124. testing_set_size = int(len(os.listdir("{}/{}".format(opt["data_dir"], "user_cold_state"))) / 4)
  125. supp_xs_s = []
  126. supp_ys_s = []
  127. query_xs_s = []
  128. query_ys_s = []
  129. for idx in range(testing_set_size):
  130. # change for Movie lens
  131. # supp_xs_s.append(
  132. # pickle.load(open("{}/{}/{}/supp_x_{}.pkl".format(opt["data_dir"], "testing", "log", idx), "rb")))
  133. # supp_ys_s.append(
  134. # pickle.load(open("{}/{}/{}/supp_y_{}.pkl".format(opt["data_dir"], "testing", "log", idx), "rb")))
  135. # query_xs_s.append(
  136. # pickle.load(open("{}/{}/{}/query_x_{}.pkl".format(opt["data_dir"], "testing", "log", idx), "rb")))
  137. # query_ys_s.append(
  138. # pickle.load(open("{}/{}/{}/query_y_{}.pkl".format(opt["data_dir"], "testing", "log", idx), "rb")))
  139. supp_xs_s.append(
  140. pickle.load(open("{}/{}/supp_x_{}.pkl".format(opt["data_dir"], "user_cold_state", idx), "rb")))
  141. supp_ys_s.append(
  142. pickle.load(open("{}/{}/supp_y_{}.pkl".format(opt["data_dir"], "user_cold_state", idx), "rb")))
  143. query_xs_s.append(
  144. pickle.load(open("{}/{}/query_x_{}.pkl".format(opt["data_dir"], "user_cold_state", idx), "rb")))
  145. query_ys_s.append(
  146. pickle.load(open("{}/{}/query_y_{}.pkl".format(opt["data_dir"], "user_cold_state", idx), "rb")))
  147. test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  148. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  149. print("# epoch\ttrain_loss\tprecision5\tNDCG5\tMAP5\tprecision7\tNDCG7\tMAP7\tprecision10\tNDCG10\tMAP10")
  150. if not os.path.exists(model_filename):
  151. print("Start training...")
  152. training(trainer, opt, train_dataset, test_dataset, batch_size=opt['batch_size'], num_epoch=opt['num_epoch'],
  153. model_save=opt["save"], model_filename=model_filename, logger=file_logger)
  154. else:
  155. print("Load pre-trained model...")
  156. opt = helper.load_config(model_filename[:-2]+"config")
  157. helper.print_config(opt)
  158. trained_state_dict = torch.load(model_filename)
  159. trainer.load_state_dict(trained_state_dict)