You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

Translator_beam.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. ''' This module will handle the text generation with beam search. '''
  2. import torch
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. from transformer.Models import Decoder
# GPU device index handed to every `.cuda(...)` call in this module.
CUDA =1
  7. class Translator(object):
  8. ''' Load with trained model and handle the beam search '''
  9. def __init__(self, opt):
  10. self.opt = opt
  11. self.tt = torch.cuda if opt.cuda else torch
  12. checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage)
  13. model_opt = checkpoint['settings']
  14. self.model_opt = model_opt
  15. model = Decoder(
  16. model_opt.user_size,
  17. d_k=model_opt.d_k,
  18. d_v=model_opt.d_v,
  19. d_model=model_opt.d_model,
  20. d_word_vec=model_opt.d_word_vec,
  21. # kernel_size=model_opt.window_size,
  22. # finit=model_opt.finit,
  23. d_inner_hid=model_opt.d_inner_hid,
  24. n_head=model_opt.n_head,
  25. dropout=model_opt.dropout)
  26. prob_projection = nn.Softmax()
  27. model.load_state_dict(checkpoint['model'])
  28. print('[Info] Trained model state loaded.')
  29. h0 = torch.zeros(1, opt.beam_size, 8 * model_opt.d_word_vec)
  30. if opt.cuda:
  31. model.cuda(CUDA)
  32. prob_projection.cuda(CUDA)
  33. h0 = h0.cuda(CUDA)
  34. else:
  35. print('no cuda')
  36. model.cpu()
  37. prob_projection.cpu()
  38. h0 = h0.cpu()
  39. model.prob_projection = prob_projection
  40. self.h0 = h0
  41. self.model = model
  42. self.model.eval()
  43. def translate_batch(self, batch, pred_len):
  44. ''' Translation work in one batch '''
  45. # Batch size is in different location depending on data.
  46. tgt_seq = batch
  47. # print(batch.size())
  48. # print(tgt_seq.shape)
  49. batch_size = tgt_seq.size(0)
  50. max_len = min(tgt_seq.size(1), 100)
  51. max_len2 = max_len + pred_len
  52. beam_size = self.opt.beam_size
  53. user_size = self.model_opt.user_size
  54. # - Decode
  55. # print(tgt_seq.data[0,:])
  56. dec_partial_seq = torch.LongTensor(
  57. [[[tgt_seq.data[j, k] for k in range(max_len)] for i in range(beam_size)] for j in range(batch_size)])
  58. # size: (batch * beam) x seq
  59. # wrap into a Variable
  60. dec_partial_seq = Variable(dec_partial_seq, volatile=True)
  61. if self.opt.cuda:
  62. dec_partial_seq = dec_partial_seq.cuda(CUDA)
  63. for i in range(pred_len + 1):
  64. # print('here')
  65. len_dec_seq = max_len + i
  66. dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
  67. # print(dec_partial_seq.shape)
  68. # -- Decoding -- #
  69. dec_output, *_ = self.model(dec_partial_seq, self.h0, generate=True)
  70. # print(dec_output.shape)
  71. dec_output = dec_output.view(dec_partial_seq.size(0), -1, self.model_opt.user_size)
  72. # print(dec_output.shape)
  73. dec_output = dec_output[:, -1, :] # (batch * beam) * user_size
  74. # print(dec_output.shape)
  75. out = self.model.prob_projection(dec_output)
  76. # print(out.shape)
  77. sample = torch.multinomial(out.view(-1),beam_size,replacement=True)
  78. sample = [[x/user_size , x%user_size] for x in sample]
  79. # print(sample)
  80. new_sample = []
  81. for row,id in sample:
  82. new_smaple_row = torch.cat([dec_partial_seq[row],torch.LongTensor([id]).cuda(CUDA)],dim=0)
  83. new_sample.append(new_smaple_row)
  84. new_sample = torch.stack(new_sample)
  85. # sample = torch.multinomial(out, 1, replacement=True)
  86. # print(sample)
  87. # batch x beam x 1
  88. # sample = sample.view(batch_size, beam_size, 1).contiguous()
  89. # sample = sample.long()
  90. # dec_partial_seq = torch.cat([dec_partial_seq, sample], dim=1)
  91. # print(dec_partial_seq.shape)
  92. dec_partial_seq = new_sample
  93. dec_partial_seq = dec_partial_seq.view(batch_size, beam_size, -1)
  94. # print(dec_partial_seq)
  95. # print(dec_partial_seq.size())
  96. # - Return useful information
  97. print(dec_partial_seq.shape)
  98. return dec_partial_seq