
Translator.py

''' This module handles text generation with beam search. '''
import torch
import torch.nn as nn
from torch.autograd import Variable

from transformer.Models import Decoder

CUDA = 1  # GPU device index used throughout


class Translator(object):
    ''' Load a trained model and handle the beam search. '''

    def __init__(self, opt):
        self.opt = opt
        self.tt = torch.cuda if opt.cuda else torch

        # Load the checkpoint onto the CPU first; move to GPU explicitly below.
        checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage)
        model_opt = checkpoint['settings']
        self.model_opt = model_opt
        print(model_opt)

        model = Decoder(
            model_opt.user_size,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            # kernel_size=model_opt.window_size,
            # finit=model_opt.finit,
            d_inner_hid=model_opt.d_inner_hid,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout)

        # Projects the decoder's last-step scores to a distribution over users.
        prob_projection = nn.Softmax(dim=1)

        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        # Initial hidden state, one slot per beam.
        h0 = torch.zeros(1, opt.beam_size, 4 * model_opt.d_word_vec)

        if opt.cuda:
            model.cuda(CUDA)
            prob_projection.cuda(CUDA)
            h0 = h0.cuda(CUDA)
        else:
            print('no cuda')
            model.cpu()
            prob_projection.cpu()
            h0 = h0.cpu()

        model.prob_projection = prob_projection
        self.h0 = h0
        self.model = model
        self.model.eval()
    def translate_batch(self, batch, pred_len):
        ''' Translation work in one batch. '''
        # Batch size is in a different location depending on the data.
        tgt_seq = batch
        batch_size = tgt_seq.size(0)
        max_len = min(tgt_seq.size(1), 100)
        max_len2 = max_len + pred_len
        beam_size = self.opt.beam_size

        # - Decode: replicate each observed prefix beam_size times,
        #   giving a (batch * beam) x seq tensor of partial sequences.
        dec_partial_seq = torch.LongTensor(
            [[[tgt_seq.data[j, k] for k in range(max_len)]
              for i in range(beam_size)] for j in range(batch_size)])

        # Wrap into a Variable; volatile disables gradient tracking at inference.
        dec_partial_seq = Variable(dec_partial_seq, volatile=True)
        if self.opt.cuda:
            dec_partial_seq = dec_partial_seq.cuda(CUDA)

        for i in range(pred_len + 1):
            len_dec_seq = max_len + i
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)

            # -- Decoding -- #
            dec_output, *_ = self.model(dec_partial_seq, self.h0, generate=True)
            dec_output = dec_output.view(
                dec_partial_seq.size(0), -1, self.model_opt.user_size)
            # Keep only the last position's scores: (batch * beam) x user_size.
            dec_output = dec_output[:, -1, :]

            # Sample the next user id for every beam (stochastic decoding
            # rather than argmax pruning).
            out = self.model.prob_projection(dec_output)
            sample = torch.multinomial(out, 1, replacement=True)
            sample = sample.long()

            # Append the sampled step and restore the batch x beam x seq layout.
            dec_partial_seq = torch.cat([dec_partial_seq, sample], dim=1)
            dec_partial_seq = dec_partial_seq.view(batch_size, beam_size, -1)

        # - Return useful information
        return dec_partial_seq
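
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it wires the class
# above together end to end. The `opt` attribute names (`model`, `cuda`,
# `beam_size`) are the ones read in __init__; the checkpoint path, batch
# contents, and pred_len value are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(
        model='trained.chkpt',  # hypothetical checkpoint path
        cuda=False,             # set True to run on GPU device 1 (see CUDA above)
        beam_size=5)            # parallel samples kept per input sequence

    translator = Translator(opt)

    # A toy batch of 2 cascades with 10 observed user ids each (values assumed).
    batch = torch.LongTensor(2, 10).random_(0, 100)
    seqs = translator.translate_batch(batch, pred_len=5)
    # Shape: batch x beam x (max_len + pred_len + 1) after sampling.
    print(seqs.size())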