Sequential Recommendation for cold-start users with meta transitional learning
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

trainer.py 6.0KB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from models import *
  2. import os
  3. import sys
  4. import torch
  5. import shutil
  6. import logging
  7. import numpy as np
  8. class Trainer:
  9. def __init__(self, data_loaders, itemnum, parameter):
  10. self.parameter = parameter
  11. # data loader
  12. self.train_data_loader = data_loaders[0]
  13. self.dev_data_loader = data_loaders[1]
  14. self.test_data_loader = data_loaders[2]
  15. # parameters
  16. self.batch_size = parameter['batch_size']
  17. self.learning_rate = parameter['learning_rate']
  18. self.epoch = parameter['epoch']
  19. # self.print_epoch = parameter['print_epoch']
  20. # self.eval_epoch = parameter['eval_epoch']
  21. self.eval_epoch = 1000
  22. self.device = torch.device(parameter['device'])
  23. self.MetaTL = MetaTL(itemnum, parameter)
  24. self.MetaTL.to(self.device)
  25. self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)
  26. def rank_predict(self, data, x, ranks):
  27. # query_idx is the idx of positive score
  28. query_idx = x.shape[0] - 1
  29. # sort all scores with descending, because more plausible triple has higher score
  30. _, idx = torch.sort(x, descending=True)
  31. rank = list(idx.cpu().numpy()).index(query_idx) + 1
  32. ranks.append(rank)
  33. # update data
  34. if rank <= 10:
  35. data['Hits@10'] += 1
  36. data['NDCG@10'] += 1 / np.log2(rank + 1)
  37. if rank <= 5:
  38. data['Hits@5'] += 1
  39. data['NDCG@5'] += 1 / np.log2(rank + 1)
  40. if rank == 1:
  41. data['Hits@1'] += 1
  42. data['NDCG@1'] += 1 / np.log2(rank + 1)
  43. data['MRR'] += 1.0 / rank
  44. def do_one_step(self, task, iseval=False, curr_rel=''):
  45. loss, p_score, n_score = 0, 0, 0
  46. if not iseval:
  47. self.optimizer.zero_grad()
  48. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  49. y = torch.Tensor([1]).to(self.device)
  50. loss = self.MetaTL.loss_func(p_score, n_score, y)
  51. loss.backward()
  52. self.optimizer.step()
  53. elif curr_rel != '':
  54. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  55. y = torch.Tensor([1]).to(self.device)
  56. loss = self.MetaTL.loss_func(p_score, n_score, y)
  57. return loss, p_score, n_score
  58. def train(self):
  59. # initialization
  60. best_epoch = 0
  61. best_value = 0
  62. bad_counts = 0
  63. # training by epoch
  64. for e in range(self.epoch):
  65. # sample one batch from data_loader
  66. train_task, curr_rel = self.train_data_loader.next_batch()
  67. loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)
  68. # print the loss on specific epoch
  69. # if e % self.print_epoch == 0:
  70. # loss_num = loss.item()
  71. # print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))
  72. # do evaluation on specific epoch
  73. if e % self.eval_epoch == 0 and e != 0:
  74. print('Epoch {} Validating...'.format(e))
  75. valid_data = self.eval(istest=False, epoch=e)
  76. print('Epoch {} Testing...'.format(e))
  77. test_data = self.eval(istest=True, epoch=e)
  78. print('Finish')
  79. def eval(self, istest=False, epoch=None):
  80. # self.MetaTL.eval()
  81. self.MetaTL.rel_q_sharing = dict()
  82. if istest:
  83. data_loader = self.test_data_loader
  84. else:
  85. data_loader = self.dev_data_loader
  86. data_loader.curr_tri_idx = 0
  87. # initial return data of validation
  88. data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0, 'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
  89. ranks = []
  90. t = 0
  91. temp = dict()
  92. while True:
  93. # sample all the eval tasks
  94. eval_task, curr_rel = data_loader.next_one_on_eval()
  95. # at the end of sample tasks, a symbol 'EOT' will return
  96. if eval_task == 'EOT':
  97. break
  98. t += 1
  99. _, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
  100. x = torch.cat([n_score, p_score], 1).squeeze()
  101. self.rank_predict(data, x, ranks)
  102. # print current temp data dynamically
  103. for k in data.keys():
  104. temp[k] = data[k] / t
  105. # sys.stdout.write("{}\tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  106. # t, temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  107. # sys.stdout.flush()
  108. # print overall evaluation result and return it
  109. for k in data.keys():
  110. data[k] = round(data[k] / t, 3)
  111. if istest:
  112. print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  113. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']),"\n")
  114. with open('results.txt', 'a') as f:
  115. f.writelines("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
  116. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  117. else:
  118. print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  119. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  120. with open("results.txt",'a') as f:
  121. f.writelines("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  122. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  123. return data