Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
import copy
import gc
import logging
import os
import random
import shutil
import sys
from operator import itemgetter

import numpy as np
import torch

from models import *
  11. class Trainer:
  12. def __init__(self, data_loaders, itemnum, parameter,user_train_num,user_train):
  13. # print(user_train)
  14. self.parameter = parameter
  15. # data loader
  16. self.train_data_loader = data_loaders[0]
  17. self.dev_data_loader = data_loaders[1]
  18. self.test_data_loader = data_loaders[2]
  19. # parameters
  20. self.batch_size = parameter['batch_size']
  21. self.learning_rate = parameter['learning_rate']
  22. self.epoch = parameter['epoch']
  23. self.device = parameter['device']
  24. self.MetaTL = MetaTL(itemnum, parameter)
  25. self.MetaTL.to(parameter['device'])
  26. self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)
  27. if parameter['eval_epoch']:
  28. self.eval_epoch = parameter['eval_epoch']
  29. else:
  30. self.eval_epoch = 1000
  31. self.varset_size = parameter['varset_size']
  32. self.user_train = user_train
  33. self.warmup = parameter['warmup']
  34. self.alpha = parameter['alpha']
  35. self.S1 = parameter['S1']
  36. self.S2_div_S1 = parameter['S2_div_S1']
  37. self.temperature = parameter['temperature']
  38. self.itemnum = itemnum
  39. self.user_train_num = user_train_num
  40. # init the two candidate sets for monitoring variance
  41. self.candidate_cur = np.random.choice(itemnum, [user_train_num + 1, self.varset_size])
  42. # for i in range(1,user_train_num+1):
  43. # for j in range(self.varset_size):
  44. # while self.candidate_cur[i, j] in user_train[i]:
  45. # self.candidate_cur[i, j] = random.randint(1, itemnum)
  46. # self.candidate_nxt = [np.random.choice(itemnum, [user_train_num+1, self.varset_size]) for _ in range(5)]
  47. # for c in range(5):
  48. # for i in range(1,user_train_num+1):
  49. # for j in range(self.varset_size):
  50. # while self.candidate_nxt[c][i, j] in user_train[i]:
  51. # self.candidate_nxt[c][i, j] = random.randint(1, itemnum)
  52. self.Mu_idx = {}
  53. for i in range(user_train_num + 1):
  54. Mu_idx_tmp = random.sample(list(range(self.varset_size)), self.S1)
  55. self.Mu_idx[i] = Mu_idx_tmp
  56. # todo : calculate score of positive items
  57. self.score_cand_cur = {}
  58. self.score_pos_cur = {}
  59. # final candidate after execution of change_mu (after one_step) (for later epochs)
  60. self.final_negative_items = {}
  61. def change_mu(self, p_score, n_score, epoch_cur, users, train_task):
  62. negitems = {}
  63. negitems_candidates_all = {}
  64. # for i in users:
  65. # negitems_candidates_all[i] = self.Mu_idx[i]
  66. negitems_candidates_all = self.Mu_idx.copy()
  67. ratings_positems = p_score.cpu().detach().numpy()
  68. ratings_positems = np.reshape(ratings_positems, [-1])
  69. # added
  70. cnt = 0
  71. for i in users:
  72. self.score_pos_cur[i] = ratings_positems[cnt]
  73. cnt += 1
  74. Mu_items_all = {index: value[negitems_candidates_all[i]] for index, value in enumerate(self.candidate_cur)}
  75. task = np.array(train_task[2])
  76. task = np.tile(task, reps=(1, self.S1, 1))
  77. task[:, :, 2] = np.array(itemgetter(*users)(Mu_items_all))
  78. ratings_candidates_all = self.MetaTL.fast_forward(task, users)
  79. hisscore_candidates_all = [self.score_cand_cur[i][:, negitems_candidates_all[i]] for user in users]
  80. hisscore_pos_all = ratings_positems.copy()
  81. hisscore_candidates_all = np.array(hisscore_candidates_all).transpose((1, 0, 2))
  82. hisscore_pos_all = np.array(hisscore_pos_all)
  83. hisscore_pos_all = hisscore_pos_all[:, np.newaxis]
  84. hisscore_pos_all = np.tile(hisscore_pos_all, (hisscore_candidates_all.shape[0], 1, 1))
  85. hislikelihood_candidates_all = 1 / (1 + np.exp(hisscore_pos_all - hisscore_candidates_all))
  86. mean_candidates_all = np.mean(hislikelihood_candidates_all[:, :], axis=0)
  87. variance_candidates_all = np.zeros(mean_candidates_all.shape)
  88. for i in range(hislikelihood_candidates_all.shape[0]):
  89. variance_candidates_all += (hislikelihood_candidates_all[i, :, :] - mean_candidates_all) ** 2
  90. variance_candidates_all = np.sqrt(variance_candidates_all / hislikelihood_candidates_all.shape[0])
  91. likelihood_candidates_all = \
  92. 1 / (1 + np.exp(np.expand_dims(ratings_positems, -1) - ratings_candidates_all))
  93. # Top sampling strategy by score + alpha * std
  94. item_arg_all = None
  95. if self.alpha >= 0:
  96. # item_arg_all = np.argmax(likelihood_candidates_all +
  97. # self.alpha * min(1, epoch_cur / self.warmup)
  98. # * variance_candidates_all, axis=1)
  99. a = likelihood_candidates_all + self.alpha * min(1, epoch_cur / self.warmup) * variance_candidates_all
  100. item_arg_all = np.argpartition(a, kth=(-2), axis=1)
  101. item_arg_all = np.array(item_arg_all)[:, -2:]
  102. else:
  103. item_arg_all = np.argmax(variance_candidates_all, axis=1)
  104. # negitems = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index]]] for index,user in enumerate(users)}
  105. negitems0 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][0]]] for index,user in enumerate(users)}
  106. negitems1 = { user : self.candidate_cur[user][negitems_candidates_all[user][item_arg_all[index][1]]] for index,user in enumerate(users)}
  107. ###############################
  108. for i in users:
  109. self.final_negative_items[i] = [negitems0[i],negitems1[i]]
  110. ###############################
  111. # update Mu
  112. negitems_mu_candidates = {}
  113. for i in users:
  114. Mu_set = set(self.Mu_idx[i])
  115. while len(self.Mu_idx[i]) < self.S1 * (1 + self.S2_div_S1):
  116. random_item = random.randint(0, self.candidate_cur.shape[1] - 1)
  117. while random_item in Mu_set:
  118. random_item = random.randint(0, self.candidate_cur.shape[1] - 1)
  119. self.Mu_idx[i].append(random_item)
  120. negitems_mu_candidates[i] = self.Mu_idx[i]
  121. negitems_mu = {}
  122. negitems_mu = {user:self.candidate_cur[user][negitems_mu_candidates[user]] for user in users}
  123. task = np.array(train_task[2])
  124. task = np.tile(task, reps=(1, self.S1 * (1 + self.S2_div_S1), 1))
  125. task[:, :, 2] = np.array(itemgetter(*users)(negitems_mu))
  126. ratings_mu_candidates = self.MetaTL.fast_forward(task, users)
  127. ratings_mu_candidates = ratings_mu_candidates / self.temperature
  128. if np.any(np.isnan(ratings_mu_candidates)):
  129. print("nan happend in ratings_mu_candidates")
  130. ratings_mu_candidates = np.nan_to_num(ratings_mu_candidates)
  131. ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape(
  132. np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1])
  133. if np.any(np.isnan(ratings_mu_candidates)):
  134. print("nan happend__2 in ratings_mu_candidates")
  135. ratings_mu_candidates = self.MetaTL.fast_forward(task, users)
  136. ratings_mu_candidates = ratings_mu_candidates / self.temperature
  137. ratings_mu_candidates = ratings_mu_candidates + 10
  138. ratings_mu_candidates = np.exp(ratings_mu_candidates) / np.reshape(
  139. np.sum(np.exp(ratings_mu_candidates), axis=1), [-1, 1])
  140. user_set = set()
  141. cnt = 0
  142. for i in users:
  143. if i in user_set:
  144. continue
  145. else:
  146. user_set.add(i)
  147. cache_arg = np.random.choice(self.S1 * (1 + self.S2_div_S1), self.S1,
  148. p=ratings_mu_candidates[cnt], replace=False)
  149. self.Mu_idx[i] = np.array(self.Mu_idx[i])[cache_arg].tolist()
  150. cnt += 1
  151. second_cand = 0
  152. del negitems, ratings_positems, Mu_items_all, task, ratings_candidates_all, hisscore_candidates_all, hisscore_pos_all
  153. del hislikelihood_candidates_all, mean_candidates_all, variance_candidates_all, likelihood_candidates_all, second_cand
  154. del negitems_mu, ratings_mu_candidates, user_set
  155. gc.collect()
  156. def change_candidate(self, epoch_count):
  157. score_1epoch_nxt = []
  158. for c in range(5):
  159. # todo: implement proper funciton
  160. pred = self.MetaTL(self.MetaTL.rel_q_sharing.keys(), self.candidate_nxt[c])
  161. score_1epoch_nxt.append(np.array(pred))
  162. # score_1epoch_nxt.append(np.array(/
  163. # [EvalUser.predict_fast(model, sess, num_user, num_item, parallel_users=100,
  164. # predict_data=candidate_nxt[c])]))
  165. # score_1epoch_pos = np.array(
  166. # [EvalUser.predict_pos(model, sess, num_user, max_posid, parallel_users=100, predict_data=train_pos)])
  167. # todo: implement proper function
  168. score_1epoch_pos = self.MetaTL(user_train, train_data)
  169. # delete the score_cand_cur[0,:,:] at the earlist timestamp
  170. if epoch_count >= 5 or epoch_count == 0:
  171. self.score_pos_cur = np.delete(self.score_pos_cur, 0, 0)
  172. for c in range(5):
  173. self.score_cand_nxt[c] = np.concatenate([self.score_cand_nxt[c], score_1epoch_nxt[c]], axis=0)
  174. self.score_pos_cur = np.concatenate([self.score_pos_cur, score_1epoch_pos], axis=0)
  175. score_cand_cur = np.copy(self.score_cand_nxt[0])
  176. candidate_cur = np.copy(self.candidate_nxt[0])
  177. for c in range(4):
  178. self.candidate_nxt[c] = np.copy(self.candidate_nxt[c + 1])
  179. self.score_cand_nxt[c] = np.copy(self.score_cand_nxt[c + 1])
  180. self.candidate_nxt[4] = np.random.choice(list(range(1, self.itemnum)), [self.user_train_num, self.varset_size])
  181. for i in range(self.user_train_num):
  182. for j in range(self.varset_size):
  183. while self.candidate_nxt[4][i, j] in self.user_train[i]:
  184. self.candidate_nxt[4][i, j] = random.randint(0, self.itemnum - 1)
  185. self.score_cand_nxt[4] = np.delete(self.score_cand_nxt[4], list(range(5)), 0)
  186. def rank_predict(self, data, x, ranks):
  187. # query_idx is the idx of positive score
  188. query_idx = x.shape[0] - 1
  189. # sort all scores with descending, because more plausible triple has higher score
  190. _, idx = torch.sort(x, descending=True)
  191. rank = list(idx.cpu().numpy()).index(query_idx) + 1
  192. ranks.append(rank)
  193. # update data
  194. if rank <= 10:
  195. data['Hits@10'] += 1
  196. data['NDCG@10'] += 1 / np.log2(rank + 1)
  197. if rank <= 5:
  198. data['Hits@5'] += 1
  199. data['NDCG@5'] += 1 / np.log2(rank + 1)
  200. if rank == 1:
  201. data['Hits@1'] += 1
  202. data['NDCG@1'] += 1 / np.log2(rank + 1)
  203. data['MRR'] += 1.0 / rank
  204. def do_one_step(self, task, iseval=False, curr_rel='', epoch=None, train_task=None, epoch_count=None):
  205. loss, p_score, n_score = 0, 0, 0
  206. if not iseval:
  207. task_new = copy.deepcopy(np.array(task[2]))
  208. cnt = 0
  209. for user in curr_rel:
  210. if user in self.final_negative_items:
  211. for index, t in enumerate(task[1][cnt]):
  212. if index % 2 == 0:
  213. t[2] = self.final_negative_items[user][0]
  214. else:
  215. t[2] = self.final_negative_items[user][1]
  216. cnt += 1
  217. self.optimizer.zero_grad()
  218. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  219. y = torch.Tensor([1]).to(self.device)
  220. loss = self.MetaTL.loss_func(p_score, n_score, y)
  221. loss.backward()
  222. self.optimizer.step()
  223. # task_new = np.array(task[2])
  224. task_new = np.tile(task_new, reps=(1, self.varset_size, 1))
  225. task_new[:, :, 2] = np.array(itemgetter(*curr_rel)(self.candidate_cur))
  226. data = self.MetaTL.fast_forward(task_new, curr_rel)
  227. # prepare score_cand_cur (make all users to have the same number of history scores)
  228. temp = min(epoch_count, 4)
  229. for index, user in enumerate(curr_rel):
  230. if (not user in self.score_cand_cur):
  231. self.score_cand_cur[user] = np.array([data[index]])
  232. elif len(self.score_cand_cur[user]) <= temp:
  233. self.score_cand_cur[user] = np.concatenate(
  234. [self.score_cand_cur[user], np.array([data[index]])], axis=0)
  235. self.change_mu(p_score, n_score, epoch_count, curr_rel, task)
  236. elif curr_rel != '':
  237. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  238. y = torch.Tensor([1]).to(self.device)
  239. loss = self.MetaTL.loss_func(p_score, n_score, y)
  240. return loss, p_score, n_score
  241. def train(self):
  242. # initialization
  243. best_epoch = 0
  244. best_value = 0
  245. bad_counts = 0
  246. epoch_count = 0
  247. # training by epoch
  248. for e in range(self.epoch):
  249. if e % 10 == 0: print("epoch:", e)
  250. # sample one batch from data_loader
  251. train_task, curr_rel = self.train_data_loader.next_batch()
  252. # change task negative samples using mu_idx
  253. loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel, epoch=e, train_task=train_task,
  254. epoch_count=epoch_count)
  255. # after ten epoch epoch
  256. if (e % 2500 == 0) and e != 0:
  257. # init the two candidate sets for monitoring variance
  258. self.candidate_cur = np.random.choice(self.itemnum, [self.user_train_num + 1, self.varset_size])
  259. for i in range(1, self.user_train_num + 1):
  260. for j in range(self.varset_size):
  261. while self.candidate_cur[i, j] in self.user_train[i]:
  262. self.candidate_cur[i, j] = random.randint(1, self.itemnum)
  263. self.Mu_idx = {}
  264. for i in range(self.user_train_num + 1):
  265. Mu_idx_tmp = random.sample(list(range(self.varset_size)), self.S1)
  266. self.Mu_idx[i] = Mu_idx_tmp
  267. self.score_cand_cur = {}
  268. self.score_pos_cur = {}
  269. self.final_negative_items = {}
  270. # reset epoch_count has many effects on the chnage_mu and one_step and train function
  271. epoch_count = 0
  272. # after one epoch
  273. elif e % 25 == 0 and e != 0:
  274. self.check_complenteness(epoch_count)
  275. print("epoch_count:", epoch_count)
  276. print("=========================\n\n")
  277. epoch_count += 1
  278. # do evaluation on specific epoch
  279. if e % self.eval_epoch == 0 and e != 0:
  280. loss_num = loss.detach().item()
  281. print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))
  282. print('Epoch {} Validating...'.format(e))
  283. valid_data = self.eval(istest=False, epoch=e)
  284. print('Epoch {} Testing...'.format(e))
  285. test_data = self.eval(istest=True, epoch=e)
  286. # original = r'/content/results.txt'
  287. # target = r'/content/drive/MyDrive/MetaTL/MetaTL_v3/results.txt'
  288. # shutil.copyfile(original, target)
  289. # print(self.candidate_cur[curr_rel[0]],self.score_cand_cur[curr_rel[0]])
  290. print('Finish')
  291. def eval(self, istest=False, epoch=None):
  292. torch.backends.cudnn.enabled = False
  293. self.MetaTL.eval()
  294. self.MetaTL.rel_q_sharing = dict()
  295. if istest:
  296. data_loader = self.test_data_loader
  297. else:
  298. data_loader = self.dev_data_loader
  299. data_loader.curr_tri_idx = 0
  300. # initial return data of validation
  301. data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0, 'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
  302. ranks = []
  303. t = 0
  304. temp = dict()
  305. total_loss = 0
  306. while True:
  307. # sample all the eval tasks
  308. eval_task, curr_rel = data_loader.next_one_on_eval()
  309. # at the end of sample tasks, a symbol 'EOT' will return
  310. if eval_task == 'EOT':
  311. break
  312. t += 1
  313. loss, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
  314. total_loss += loss
  315. x = torch.cat([n_score, p_score], 1).squeeze()
  316. self.rank_predict(data, x, ranks)
  317. # print current temp data dynamically
  318. for k in data.keys():
  319. temp[k] = data[k] / t
  320. # print overall evaluation result and return it
  321. for k in data.keys():
  322. data[k] = round(data[k] / t, 3)
  323. print("\n")
  324. if istest:
  325. print("TEST: \t test_loss: ", total_loss.detach().item())
  326. print(
  327. "TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  328. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
  329. temp['Hits@1']))
  330. with open('results2.txt', 'a') as f:
  331. f.writelines(
  332. "TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
  333. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
  334. temp['Hits@1']))
  335. else:
  336. print("VALID: \t validation_loss: ", total_loss.detach().item())
  337. print(
  338. "VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  339. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
  340. temp['Hits@1']))
  341. with open("results2.txt", 'a') as f:
  342. f.writelines(
  343. "VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  344. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'],
  345. temp['Hits@1']))
  346. print("\n")
  347. del total_loss, p_score, n_score
  348. gc.collect()
  349. self.MetaTL.train()
  350. torch.backends.cudnn.enabled = True
  351. return data
  352. def check_complenteness(self, epoch_count):
  353. # un_users = set()
  354. for user in list(self.user_train.keys()):
  355. if not user in self.score_cand_cur:
  356. self.score_cand_cur[user] = np.array([np.zeros(self.varset_size)])
  357. num = epoch_count - len(self.score_cand_cur[user]) + 1
  358. if num > 0 and len(self.score_cand_cur[user]) < 5:
  359. # if num!=1 : print("bug happend1")
  360. # un_users.add(user)
  361. self.score_cand_cur[user] = np.concatenate(
  362. [self.score_cand_cur[user], np.array([self.score_cand_cur[user][-1]])], axis=0)
  363. if epoch_count >= 4:
  364. t = 0
  365. for user in list(self.score_cand_cur.keys()):
  366. t = user
  367. # self.score_cand_cur[user] = np.delete(self.score_cand_cur[user], 0, 0)
  368. self.score_cand_cur[user] = self.score_cand_cur[user][-4:]