
Models.py 6.5KB

''' Define the Transformer model '''
import torch
import torch.nn as nn
import numpy as np
import transformer.Constants as Constants
from transformer.Modules import BottleLinear as Linear
from transformer.Layers import EncoderLayer, DecoderLayer
import json
import pickle

CUDA = 0
idx2vec_addr = '/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2vec.pickle'


def get_attn_padding_mask(seq_q, seq_k):
    ''' Indicate the padding-related part to mask '''
    assert seq_q.dim() == 2 and seq_k.dim() == 2
    mb_size, len_q = seq_q.size()
    mb_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(Constants.PAD).unsqueeze(1)    # b x 1 x sk
    pad_attn_mask = pad_attn_mask.expand(mb_size, len_q, len_k)  # b x sq x sk
    if seq_q.is_cuda:
        pad_attn_mask = pad_attn_mask.cuda(CUDA)
    return pad_attn_mask


def get_attn_subsequent_mask(seq):
    ''' Get an attention mask to avoid using the subsequent info. '''
    assert seq.dim() == 2
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    subsequent_mask = torch.from_numpy(subsequent_mask)
    if seq.is_cuda:
        subsequent_mask = subsequent_mask.cuda(CUDA)
    return subsequent_mask


def get_previous_user_mask(seq, user_size):
    ''' Mask previously activated users. '''
    assert seq.dim() == 2
    prev_shape = (seq.size(0), seq.size(1), seq.size(1))
    seqs = seq.repeat(1, 1, seq.size(1)).view(seq.size(0), seq.size(1), seq.size(1))
    previous_mask = np.tril(np.ones(prev_shape)).astype('float32')
    previous_mask = torch.from_numpy(previous_mask)
    if seq.is_cuda:
        previous_mask = previous_mask.cuda(CUDA)
    # Keep, for each position, the user ids seen so far (lower triangle)
    masked_seq = previous_mask * seqs.data.float()

    # force the 0th dimension (PAD) to be masked
    PAD_tmp = torch.zeros(seq.size(0), seq.size(1), 1)
    if seq.is_cuda:
        PAD_tmp = PAD_tmp.cuda(CUDA)
    masked_seq = torch.cat([masked_seq, PAD_tmp], dim=2)

    # Scatter -inf onto the already-seen user ids (and PAD) over the output vocabulary
    ans_tmp = torch.zeros(seq.size(0), seq.size(1), user_size)
    if seq.is_cuda:
        ans_tmp = ans_tmp.cuda(CUDA)
    masked_seq = ans_tmp.scatter_(2, masked_seq.long(), float('-inf'))
    return masked_seq


class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self, user_size, kernel_size=3, n_layers=1, n_head=1, d_k=32, d_v=32,
            d_word_vec=32, d_model=32, d_inner_hid=32, dropout=0.1, finit=0):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.user_size = user_size
        self.user_emb = nn.Embedding(
            user_size, d_word_vec, padding_idx=Constants.PAD)
        self.tgt_user_proj = Linear(d_model, user_size, bias=False)
        # User-id -> embedding vectors loaded from disk (see idx2vec_addr)
        with open(idx2vec_addr, 'rb') as handle:
            self.idx2vec = pickle.load(handle)
        self.gru = nn.GRU(input_size=d_word_vec, hidden_size=8 * d_word_vec, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv1d(8 * d_model, user_size, kernel_size, padding=kernel_size - 1, bias=True)
        # self.conv = nn.Linear(4*d_model,user_size)
        self.padding = kernel_size - 1
        self.finit = finit
        self.layer_stack = nn.ModuleList([
            DecoderLayer(8 * d_model, d_inner_hid, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, h0, return_attns=False, generate=False):
        if not generate:
            tgt_seq = tgt_seq[:, :-1]

        # Word embedding look up (replaced below by the idx2vec vectors)
        dec_input = self.user_emb(tgt_seq)
        dec_new_input = torch.FloatTensor(
            [[self.idx2vec[int(idx.data.cpu().numpy())] for idx in tgt_seq[i]]
             for i in range(tgt_seq.size(0))]).cuda(CUDA)
        dec_input = dec_new_input
        # dec_input, h_n = self.gru(dec_input, h0)

        # Decode: build the self-attention masks (1 means masked)
        dec_slf_attn_pad_mask = get_attn_padding_mask(tgt_seq, tgt_seq)
        dec_slf_attn_sub_mask = get_attn_subsequent_mask(tgt_seq)
        dec_slf_attn_mask = torch.gt(dec_slf_attn_pad_mask + dec_slf_attn_sub_mask, 0)

        if return_attns:
            dec_slf_attns = [[] for _ in range(tgt_seq.size(0))]

        dec_output = dec_input
        # for dec_layer in self.layer_stack:
        #     dec_output, dec_slf_attn = dec_layer(
        #         dec_output, slf_attn_mask=dec_slf_attn_mask)
        #     if return_attns:
        #         dec_slf_attns += [dec_slf_attn]

        # Convolve over the time dimension, then drop the trailing causal padding
        dec_output = dec_output.transpose(1, 2)
        dec_output = self.conv(dec_output)
        dec_output = dec_output[:, :, 0:-self.padding]
        dec_output = dec_output.transpose(1, 2).contiguous()

        if self.finit > 0:
            dec_output += self.tgt_user_proj(self.user_emb(tgt_seq[:, 0])).repeat(
                dec_input.size(1), 1, 1).transpose(0, 1).contiguous()

        # Penalize users that already appeared earlier in the sequence
        seq_logit = dec_output + torch.autograd.Variable(
            get_previous_user_mask(tgt_seq, self.user_size), requires_grad=False)

        # seq_logit: batch * seq_len * n_word
        if return_attns:
            return seq_logit.view(-1, seq_logit.size(2)), dec_slf_attns
        else:
            return seq_logit.view(-1, seq_logit.size(2)),
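
A minimal usage sketch for the mask helpers above, run on a small CPU batch. This is illustrative only: it assumes this module is importable as transformer.Models and that Constants.PAD equals 0. The Decoder itself is not instantiated here, since its __init__ reads the idx2vec pickle from the hard-coded idx2vec_addr path.

# Illustrative sketch -- not part of Models.py. Assumes the module is importable
# as transformer.Models and that Constants.PAD == 0.
import torch
from transformer.Models import (get_attn_padding_mask,
                                get_attn_subsequent_mask,
                                get_previous_user_mask)

# Two sequences of length 4, right-padded with 0 (= Constants.PAD).
seq = torch.LongTensor([[3, 5, 2, 0],
                        [7, 2, 0, 0]])

pad_mask = get_attn_padding_mask(seq, seq)  # (2, 4, 4): 1 where the key position is PAD
sub_mask = get_attn_subsequent_mask(seq)    # (2, 4, 4): 1 above the diagonal (future steps)
slf_attn_mask = torch.gt(pad_mask + sub_mask, 0)  # combined, as in Decoder.forward

# (2, 4, user_size): -inf at user ids already seen earlier in each sequence (and at PAD)
prev_mask = get_previous_user_mask(seq, user_size=10)

print(slf_attn_mask.size(), prev_mask.size())  # torch.Size([2, 4, 4]) torch.Size([2, 4, 10])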