Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 6.3KB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from models import *
  2. import os
  3. import sys
  4. import torch
  5. import shutil
  6. import logging
  7. import numpy as np
  8. class Trainer:
  9. def __init__(self, data_loaders, itemnum, parameter):
  10. self.parameter = parameter
  11. # data loader
  12. self.train_data_loader = data_loaders[0]
  13. self.dev_data_loader = data_loaders[1]
  14. self.test_data_loader = data_loaders[2]
  15. # parameters
  16. self.batch_size = parameter['batch_size']
  17. self.learning_rate = parameter['learning_rate']
  18. self.epoch = parameter['epoch']
  19. # self.print_epoch = parameter['print_epoch']
  20. # self.eval_epoch = parameter['eval_epoch']
  21. # self.device = torch.device(parameter['device'])
  22. self.device = parameter['device']
  23. self.MetaTL = MetaTL(itemnum, parameter)
  24. self.MetaTL.to(parameter['device'])
  25. self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)
  26. if parameter['eval_epoch']:
  27. self.eval_epoch = parameter['eval_epoch']
  28. else:
  29. self.eval_epoch = 1000
  30. def rank_predict(self, data, x, ranks):
  31. # query_idx is the idx of positive score
  32. query_idx = x.shape[0] - 1
  33. # sort all scores with descending, because more plausible triple has higher score
  34. _, idx = torch.sort(x, descending=True)
  35. rank = list(idx.cpu().numpy()).index(query_idx) + 1
  36. ranks.append(rank)
  37. # update data
  38. if rank <= 10:
  39. data['Hits@10'] += 1
  40. data['NDCG@10'] += 1 / np.log2(rank + 1)
  41. if rank <= 5:
  42. data['Hits@5'] += 1
  43. data['NDCG@5'] += 1 / np.log2(rank + 1)
  44. if rank == 1:
  45. data['Hits@1'] += 1
  46. data['NDCG@1'] += 1 / np.log2(rank + 1)
  47. data['MRR'] += 1.0 / rank
  48. def do_one_step(self, task, iseval=False, curr_rel=''):
  49. loss, p_score, n_score = 0, 0, 0
  50. if not iseval:
  51. self.optimizer.zero_grad()
  52. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  53. y = torch.Tensor([1]).to(self.device)
  54. loss = self.MetaTL.loss_func(p_score, n_score,self.device)
  55. loss.backward()
  56. self.optimizer.step()
  57. elif curr_rel != '':
  58. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  59. y = torch.Tensor([1]).to(self.device)
  60. loss = self.MetaTL.loss_func(p_score, n_score,self.device)
  61. return loss, p_score, n_score
  62. def train(self):
  63. # initialization
  64. best_epoch = 0
  65. best_value = 0
  66. bad_counts = 0
  67. # training by epoch
  68. for e in range(self.epoch):
  69. # sample one batch from data_loader
  70. train_task, curr_rel = self.train_data_loader.next_batch()
  71. loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)
  72. # do evaluation on specific epoch
  73. if e % self.eval_epoch == 0 and e != 0:
  74. loss_num = loss.detach().item()
  75. print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))
  76. print('Epoch {} Validating...'.format(e))
  77. valid_data = self.eval(istest=False, epoch=e)
  78. print('Epoch {} Testing...'.format(e))
  79. test_data = self.eval(istest=True, epoch=e)
  80. print('Finish')
  81. def eval(self, istest=False, epoch=None):
  82. # self.MetaTL.eval()
  83. self.MetaTL.rel_q_sharing = dict()
  84. if istest:
  85. data_loader = self.test_data_loader
  86. else:
  87. data_loader = self.dev_data_loader
  88. data_loader.curr_tri_idx = 0
  89. # initial return data of validation
  90. data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0, 'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
  91. ranks = []
  92. t = 0
  93. temp = dict()
  94. total_loss = 0
  95. while True:
  96. # sample all the eval tasks
  97. eval_task, curr_rel = data_loader.next_one_on_eval()
  98. # at the end of sample tasks, a symbol 'EOT' will return
  99. if eval_task == 'EOT':
  100. break
  101. t += 1
  102. loss, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
  103. total_loss += loss
  104. x = torch.cat([n_score, p_score], 1).squeeze()
  105. self.rank_predict(data, x, ranks)
  106. # print current temp data dynamically
  107. for k in data.keys():
  108. temp[k] = data[k] / t
  109. # sys.stdout.write("{}\tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  110. # t, temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  111. # sys.stdout.flush()
  112. # print overall evaluation result and return it
  113. for k in data.keys():
  114. data[k] = round(data[k] / t, 3)
  115. if istest:
  116. print("TEST: \t test_loss: ",total_loss.detach().item())
  117. print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  118. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']),"\n")
  119. with open('results.txt', 'a') as f:
  120. f.writelines("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r\n\n".format(
  121. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  122. else:
  123. print("VALID: \t validation_loss: ", total_loss.detach().item() )
  124. print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  125. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  126. with open("results.txt",'a') as f:
  127. f.writelines("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  128. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  129. return data