You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

Translator_new.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. ''' This module will handle the text generation with beam search. '''
  2. import torch
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. from collections import OrderedDict
  6. import numpy as np
  7. from transformer.MyModel import EncoderDecoder, Encoder, Decoder, Generator, BahdanauAttention
  8. class Translator(object):
  9. ''' Load with trained model and handle the beam search '''
  10. def __init__(self, opt):
  11. self.opt = opt
  12. self.tt = torch.cuda if opt.cuda else torch
  13. checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage)
  14. model_opt = checkpoint['settings']
  15. self.model_opt = model_opt
  16. attention = BahdanauAttention(model_opt.d_inner_hid)
  17. model = EncoderDecoder(
  18. Encoder(model_opt.d_word_vec, model_opt.d_inner_hid, num_layers=1, dropout=model_opt.dropout),
  19. Decoder(model_opt.d_word_vec, model_opt.d_inner_hid, attention, num_layers=1, dropout=model_opt.dropout),
  20. nn.Embedding(model_opt.user_size, model_opt.d_word_vec),
  21. nn.Embedding(model_opt.user_size, model_opt.d_word_vec),
  22. Generator(model_opt.d_inner_hid, model_opt.user_size))
  23. prob_projection = nn.Softmax()
  24. model_dict = checkpoint['model']
  25. new_state_dict = OrderedDict()
  26. # for k, v in state_dict.items():
  27. # name = k[7:]
  28. # new_state_dict[name] = v
  29. print("\n\n\n")
  30. print(model_dict)
  31. print("\n\n\n")
  32. model.load_state_dict(model_dict)
  33. print('[Info] Trained model state loaded.')
  34. if opt.cuda:
  35. model.cuda(0)
  36. prob_projection.cuda(0)
  37. else:
  38. print('no cuda')
  39. model.cpu()
  40. prob_projection.cpu()
  41. model.prob_projection = prob_projection
  42. self.model = model
  43. self.model.eval()
  44. def translate_batch(self, src, src_mask, src_lengths, trg, trg_mask, trg_lengths):
  45. ''' Translation work in one batch '''
  46. # self.opt.beam_size=100
  47. # for i in range(35):
  48. # with torch.no_grad():
  49. # self.encoder_hidden, self.encoder_final = self.model.encode(src, src_mask, src_lengths)
  50. # prev_y = torch.ones(1, 1).fill_(0).type_as(trg)
  51. # trg_mask = torch.ones_like(prev_y)
  52. #
  53. # output = []
  54. # attention_scores = []
  55. # hidden = None
  56. # for j in range(self.opt.beam_size):
  57. #
  58. #
  59. output_f = []
  60. for j in range(self.opt.beam_size):
  61. with torch.no_grad():
  62. encoder_hidden, encoder_final = self.model.encode(src, src_mask, src_lengths)
  63. prev_y = torch.ones(1, 1).fill_(0).type_as(trg)
  64. trg_mask = torch.ones_like(prev_y)
  65. output = []
  66. attention_scores = []
  67. hidden = None
  68. for i in range(35):
  69. with torch.no_grad():
  70. out, hidden, pre_output = self.model.decode(encoder_hidden, encoder_final, src_mask,
  71. prev_y, trg_mask, hidden)
  72. # print(hidden,prev_y)
  73. prob = self.model.generator(pre_output[:, -1])
  74. # print(prob)
  75. # print(prob)
  76. _, next_word = torch.max(prob, dim=1)
  77. next_word = next_word.data.item()
  78. output.append(next_word)
  79. prev_y = torch.ones(1, 1).type_as(trg).fill_(next_word)
  80. attention_scores.append(self.model.decoder.attention.alphas.cpu().numpy())
  81. output = np.array(output)
  82. output_f.append(output)
  83. output_f = np.array(output_f)
  84. output_f2 = []
  85. output_f2.append(output_f)
  86. output_f2 = np.array(output_f2)
  87. print(output_f2)
  88. return output_f2
  89. def predict_next_user(self, num, output, hidden, attention_scores, src_mask, prev_y):
  90. with torch.no_grad():
  91. out, hidden, pre_output = self.model.decode(encoder_hidden, encoder_final, src_mask,
  92. prev_y, trg_mask, hidden)
  93. prob = self.model.generator(pre_output[:, -1])
  94. _, next_word = torch.max(prob, dim=1)
  95. next_word = next_word.data.item()
  96. output.append(next_word)
  97. prev_y = torch.ones(1, 1).type_as(trg).fill_(next_word)
  98. attention_scores.append(self.model.decoder.attention.alphas.cpu().numpy())
  99. return next_word, prev_y, attention_scores