|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- ''' This module will handle the text generation with beam search. '''
-
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
-
- from transformer.Models import Decoder
-
- CUDA =1
-
- class Translator(object):
- ''' Load with trained model and handle the beam search '''
-
- def __init__(self, opt):
- self.opt = opt
- self.tt = torch.cuda if opt.cuda else torch
- checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage)
- model_opt = checkpoint['settings']
- self.model_opt = model_opt
- print(model_opt)
-
- model = Decoder(
- model_opt.user_size,
- d_k=model_opt.d_k,
- d_v=model_opt.d_v,
- d_model=model_opt.d_model,
- d_word_vec=model_opt.d_word_vec,
- # kernel_size=model_opt.window_size,
- # finit=model_opt.finit,
- d_inner_hid=model_opt.d_inner_hid,
- n_head=model_opt.n_head,
- dropout=model_opt.dropout)
-
- prob_projection = nn.Softmax()
-
- model.load_state_dict(checkpoint['model'])
- print('[Info] Trained model state loaded.')
- h0 = torch.zeros(1, opt.beam_size, 4 * model_opt.d_word_vec)
-
- if opt.cuda:
- model.cuda(CUDA)
- prob_projection.cuda(CUDA)
- h0 = h0.cuda(CUDA)
- else:
- print('no cuda')
- model.cpu()
- prob_projection.cpu()
- h0 = h0.cpu()
-
- model.prob_projection = prob_projection
-
- self.h0 = h0
- self.model = model
- self.model.eval()
-
- def translate_batch(self, batch, pred_len):
- ''' Translation work in one batch '''
-
- # Batch size is in different location depending on data.
- tgt_seq = batch
- # print(batch.size())
- # print(tgt_seq.shape)
-
- batch_size = tgt_seq.size(0)
- max_len = min(tgt_seq.size(1), 100)
- max_len2 = max_len + pred_len
- beam_size = self.opt.beam_size
-
- # - Decode
- # print(tgt_seq.data[0,:])
-
- dec_partial_seq = torch.LongTensor(
- [[[tgt_seq.data[j, k] for k in range(max_len)] for i in range(beam_size)] for j in range(batch_size)])
- # size: (batch * beam) x seq
-
- # wrap into a Variable
- dec_partial_seq = Variable(dec_partial_seq, volatile=True)
- if self.opt.cuda:
- dec_partial_seq = dec_partial_seq.cuda(CUDA)
- for i in range(pred_len + 1):
- # print('here')
- len_dec_seq = max_len + i
- dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
- # print(dec_partial_seq.shape)
- # -- Decoding -- #
- dec_output, *_ = self.model(dec_partial_seq, self.h0, generate=True)
- # print(dec_output.shape)
- dec_output = dec_output.view(dec_partial_seq.size(0), -1, self.model_opt.user_size)
- # print(dec_output.shape)
- dec_output = dec_output[:, -1, :] # (batch * beam) * user_size
- # print(dec_output.shape)
-
- out = self.model.prob_projection(dec_output)
- # print(out.shape)
- sample = torch.multinomial(out, 1, replacement=True)
- # print(sample.shape)
-
- # batch x beam x 1
- # sample = sample.view(batch_size, beam_size, 1).contiguous()
- sample = sample.long()
- dec_partial_seq = torch.cat([dec_partial_seq, sample], dim=1)
- dec_partial_seq = dec_partial_seq.view(batch_size, beam_size, -1)
- # print(dec_partial_seq.size())
-
- # - Return useful information
- # print(dec_partial_seq.shape)
- return dec_partial_seq
|