''' Define the Transformer model '''
import torch
import torch.nn as nn
import numpy as np
import transformer.Constants as Constants
from transformer.Modules import BottleLinear as Linear
from transformer.Layers import EncoderLayer, DecoderLayer
import json
import pickle

# Index of the GPU device targeted by the .cuda() calls in this module.
CUDA = 0

# Pickled {user index -> pre-trained embedding vector} lookup used by the Decoder.
idx2vec_addr = '/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2vec.pickle'


def get_attn_padding_mask(seq_q, seq_k):
    ''' Indicate the padding-related part to mask '''
    assert seq_q.dim() == 2 and seq_k.dim() == 2
    mb_size, len_q = seq_q.size()
    mb_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(Constants.PAD).unsqueeze(1)   # b x 1 x sk
    pad_attn_mask = pad_attn_mask.expand(mb_size, len_q, len_k)  # b x sq x sk
    if seq_q.is_cuda:
        pad_attn_mask = pad_attn_mask.cuda(CUDA)
    return pad_attn_mask

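# Worked example for get_attn_padding_mask (illustrative sketch; assumes
# Constants.PAD == 0): with seq_q = seq_k = [[3, 5, 0]] the mask has shape
# (1, 3, 3) and every query row equals [0, 0, 1], i.e. only the PAD key
# position is masked out.
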
def get_attn_subsequent_mask(seq):
    ''' Get an attention mask to avoid using the subsequent info.'''
    assert seq.dim() == 2
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    subsequent_mask = torch.from_numpy(subsequent_mask)
    if seq.is_cuda:
        subsequent_mask = subsequent_mask.cuda(CUDA)
    return subsequent_mask

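# Worked example for get_attn_subsequent_mask (illustrative): for any length-3
# sequence the mask per batch element is
#   [[0, 1, 1],
#    [0, 0, 1],
#    [0, 0, 0]]
# so position t may only attend to positions <= t.
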
def get_previous_user_mask(seq, user_size):
    ''' Mask previously activated users so they cannot be predicted again. '''
    assert seq.dim() == 2
    prev_shape = (seq.size(0), seq.size(1), seq.size(1))
    # seqs[b, j, :] is a copy of the full sequence seq[b, :] for every position j.
    seqs = seq.repeat(1, 1, seq.size(1)).view(seq.size(0), seq.size(1), seq.size(1))
    # Lower-triangular mask keeps only the users seen up to (and including) position j.
    previous_mask = np.tril(np.ones(prev_shape)).astype('float32')
    previous_mask = torch.from_numpy(previous_mask)
    if seq.is_cuda:
        previous_mask = previous_mask.cuda(CUDA)
    masked_seq = previous_mask * seqs.data.float()

    # Append a column of zeros so that index 0 (PAD) is always masked as well.
    PAD_tmp = torch.zeros(seq.size(0), seq.size(1), 1)
    if seq.is_cuda:
        PAD_tmp = PAD_tmp.cuda(CUDA)
    masked_seq = torch.cat([masked_seq, PAD_tmp], dim=2)
    # Scatter -inf into the logit slots of every user index collected above.
    ans_tmp = torch.zeros(seq.size(0), seq.size(1), user_size)
    if seq.is_cuda:
        ans_tmp = ans_tmp.cuda(CUDA)
    masked_seq = ans_tmp.scatter_(2, masked_seq.long(), float('-inf'))

    return masked_seq


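# Worked example for get_previous_user_mask (illustrative; the user indices are
# hypothetical): with seq = [[4, 2]] and user_size = 6, the returned (1, 2, 6)
# tensor holds -inf at indices {0, 4} for position 0 and {0, 2, 4} for position 1,
# so a user can no longer be predicted once it has appeared in the cascade.
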
class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self, user_size, kernel_size=3, n_layers=1, n_head=1, d_k=32, d_v=32,
            d_word_vec=32, d_model=32, d_inner_hid=32, dropout=0.1, finit=0):

        super(Decoder, self).__init__()
        self.d_model = d_model
        self.user_size = user_size

        self.user_emb = nn.Embedding(
            user_size, d_word_vec, padding_idx=Constants.PAD)
        self.tgt_user_proj = Linear(d_model, user_size, bias=False)

        # Pre-trained user vectors; forward() feeds these into the network in
        # place of the learned embedding above.
        with open(idx2vec_addr, 'rb') as handle:
            self.idx2vec = pickle.load(handle)

        # Defined but currently bypassed (the GRU call in forward() is commented out).
        self.gru = nn.GRU(input_size=d_word_vec, hidden_size=8 * d_word_vec,
                          num_layers=1, batch_first=True)

        self.dropout = nn.Dropout(dropout)
        # Causal 1-D convolution over the sequence; the extra right padding is
        # trimmed in forward(). Its in_channels imply that the idx2vec vectors
        # are expected to have dimension 8 * d_model.
        self.conv = nn.Conv1d(8 * d_model, user_size, kernel_size,
                              padding=kernel_size - 1, bias=True)
        # self.conv = nn.Linear(4 * d_model, user_size)
        self.padding = kernel_size - 1
        self.finit = finit

        # Self-attention layers; currently unused because the corresponding loop
        # in forward() is commented out.
        self.layer_stack = nn.ModuleList([
            DecoderLayer(8 * d_model, d_inner_hid, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

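    # Illustrative construction (a sketch, not part of the original code):
    #   decoder = Decoder(user_size=5000, kernel_size=3, d_model=32, finit=0)
    # user_size must cover the cascade vocabulary, idx2vec_addr must exist on
    # disk, and the pickled vectors are assumed to be 8 * d_model wide so that
    # they match the in_channels of self.conv.
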
    def forward(self, tgt_seq, h0, return_attns=False, generate=False):
        # During training the last element is dropped so that position t predicts
        # the user at position t + 1; during generation the full prefix is kept.
        if not generate:
            tgt_seq = tgt_seq[:, :-1]

        # Learned embedding lookup (immediately replaced by the pre-trained
        # idx2vec vectors below).
        dec_input = self.user_emb(tgt_seq)

        # Replace the learned embeddings with the pre-trained idx2vec vectors.
        dec_new_input = torch.FloatTensor(
            [[self.idx2vec[int(idx.data.cpu().numpy())] for idx in tgt_seq[i]]
             for i in range(tgt_seq.size(0))])
        if tgt_seq.is_cuda:
            dec_new_input = dec_new_input.cuda(CUDA)
        dec_input = dec_new_input

        # dec_input, h_n = self.gru(dec_input, h0)

        # Decode: build the self-attention mask by combining the padding mask and
        # the subsequent-position mask (only needed when the self-attention layer
        # stack below is enabled).
        dec_slf_attn_pad_mask = get_attn_padding_mask(tgt_seq, tgt_seq)
        dec_slf_attn_sub_mask = get_attn_subsequent_mask(tgt_seq)
        # 1 means masked
        dec_slf_attn_mask = torch.gt(dec_slf_attn_pad_mask + dec_slf_attn_sub_mask, 0)

        if return_attns:
            dec_slf_attns = [[] for _ in range(tgt_seq.size(0))]

        dec_output = dec_input
        # for dec_layer in self.layer_stack:
        #     dec_output, dec_slf_attn = dec_layer(
        #         dec_output, slf_attn_mask=dec_slf_attn_mask)
        #     if return_attns:
        #         dec_slf_attns += [dec_slf_attn]

        # Causal convolution over the sequence dimension: Conv1d expects
        # (batch, channels, length), so transpose, convolve, trim the extra
        # right padding, and transpose back.
        dec_output = dec_output.transpose(1, 2)
        dec_output = self.conv(dec_output)
        dec_output = dec_output[:, :, 0:-self.padding]
        dec_output = dec_output.transpose(1, 2).contiguous()

        # Optionally add a bias derived from the cascade's first user.
        if self.finit > 0:
            dec_output += self.tgt_user_proj(self.user_emb(tgt_seq[:, 0])) \
                .repeat(dec_input.size(1), 1, 1).transpose(0, 1).contiguous()

        # Prevent already-activated users (and PAD) from being predicted again.
        seq_logit = dec_output + torch.autograd.Variable(
            get_previous_user_mask(tgt_seq, self.user_size), requires_grad=False)

        # seq_logit has shape (batch, seq_len, user_size); flatten the first two
        # dimensions. Note the trailing comma in the else branch: it returns a
        # one-element tuple.
        if return_attns:
            return seq_logit.view(-1, seq_logit.size(2)), dec_slf_attns
        else:
            return seq_logit.view(-1, seq_logit.size(2)),
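

if __name__ == '__main__':
    # Minimal CPU smoke test for the mask helpers above (an illustrative sketch;
    # the Decoder itself is not instantiated here because its __init__ loads the
    # idx2vec pickle from disk). The user indices below are arbitrary examples.
    toy_seq = torch.LongTensor([[3, 5, Constants.PAD]])
    print(get_attn_padding_mask(toy_seq, toy_seq).size())        # (1, 3, 3)
    print(get_attn_subsequent_mask(toy_seq).size())              # (1, 3, 3)
    print(get_previous_user_mask(toy_seq, user_size=10).size())  # (1, 3, 10)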