Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 5.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from models import *
  2. import os
  3. import sys
  4. import torch
  5. import shutil
  6. import logging
  7. import numpy as np
  8. class Trainer:
  9. def __init__(self, data_loaders, itemnum, parameter):
  10. self.parameter = parameter
  11. # data loader
  12. self.train_data_loader = data_loaders[0]
  13. self.dev_data_loader = data_loaders[1]
  14. self.test_data_loader = data_loaders[2]
  15. # parameters
  16. self.batch_size = parameter['batch_size']
  17. self.learning_rate = parameter['learning_rate']
  18. self.epoch = parameter['epoch']
  19. self.print_epoch = parameter['print_epoch']
  20. self.eval_epoch = parameter['eval_epoch']
  21. self.device = parameter['device']
  22. self.MetaTL = MetaTL(itemnum, parameter)
  23. self.MetaTL.to(self.device)
  24. self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)
  25. def rank_predict(self, data, x, ranks):
  26. # query_idx is the idx of positive score
  27. query_idx = x.shape[0] - 1
  28. # sort all scores with descending, because more plausible triple has higher score
  29. _, idx = torch.sort(x, descending=True)
  30. rank = list(idx.cpu().numpy()).index(query_idx) + 1
  31. ranks.append(rank)
  32. # update data
  33. if rank <= 10:
  34. data['Hits@10'] += 1
  35. data['NDCG@10'] += 1 / np.log2(rank + 1)
  36. if rank <= 5:
  37. data['Hits@5'] += 1
  38. data['NDCG@5'] += 1 / np.log2(rank + 1)
  39. if rank == 1:
  40. data['Hits@1'] += 1
  41. data['NDCG@1'] += 1 / np.log2(rank + 1)
  42. data['MRR'] += 1.0 / rank
  43. def do_one_step(self, task, iseval=False, curr_rel=''):
  44. loss, p_score, n_score = 0, 0, 0
  45. if not iseval:
  46. self.optimizer.zero_grad()
  47. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  48. y = torch.Tensor([1]).to(self.device)
  49. loss = self.MetaTL.loss_func(p_score, n_score, y)
  50. loss.backward()
  51. self.optimizer.step()
  52. elif curr_rel != '':
  53. p_score, n_score = self.MetaTL(task, iseval, curr_rel)
  54. y = torch.Tensor([1]).to(self.device)
  55. loss = self.MetaTL.loss_func(p_score, n_score, y)
  56. return loss, p_score, n_score
  57. def train(self):
  58. # initialization
  59. best_epoch = 0
  60. best_value = 0
  61. bad_counts = 0
  62. # training by epoch
  63. for e in range(self.epoch):
  64. # sample one batch from data_loader
  65. train_task, curr_rel = self.train_data_loader.next_batch()
  66. loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)
  67. # print the loss on specific epoch
  68. if e % self.print_epoch == 0:
  69. loss_num = loss.item()
  70. print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num))
  71. # do evaluation on specific epoch
  72. if e % self.eval_epoch == 0 and e != 0:
  73. print('Epoch {} Validating...'.format(e))
  74. valid_data = self.eval(istest=False, epoch=e)
  75. print('Epoch {} Testing...'.format(e))
  76. test_data = self.eval(istest=True, epoch=e)
  77. print('Finish')
  78. def eval(self, istest=False, epoch=None):
  79. self.MetaTL.eval()
  80. self.MetaTL.rel_q_sharing = dict()
  81. if istest:
  82. data_loader = self.test_data_loader
  83. else:
  84. data_loader = self.dev_data_loader
  85. data_loader.curr_tri_idx = 0
  86. # initial return data of validation
  87. data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0, 'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
  88. ranks = []
  89. t = 0
  90. temp = dict()
  91. while True:
  92. # sample all the eval tasks
  93. eval_task, curr_rel = data_loader.next_one_on_eval()
  94. # at the end of sample tasks, a symbol 'EOT' will return
  95. if eval_task == 'EOT':
  96. break
  97. t += 1
  98. _, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
  99. x = torch.cat([n_score, p_score], 1).squeeze()
  100. self.rank_predict(data, x, ranks)
  101. # print current temp data dynamically
  102. for k in data.keys():
  103. temp[k] = data[k] / t
  104. sys.stdout.write("{}\tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  105. t, temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  106. sys.stdout.flush()
  107. # print overall evaluation result and return it
  108. for k in data.keys():
  109. data[k] = round(data[k] / t, 3)
  110. if istest:
  111. print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  112. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  113. else:
  114. print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format(
  115. temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1']))
  116. return data