
Layers.py 1.3KB

''' Define the Layers '''
import torch.nn as nn

from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward

__author__ = "Yu-Hsiang Huang"


class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn
class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, attn_mask=slf_attn_mask)
        # encoder-decoder attention: queries from the decoder, keys/values from the encoder output
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn
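
A minimal usage sketch (not part of Layers.py): it assumes the companion transformer.SubLayers module from this repository is importable, that the attention sub-layers return (output, attention) pairs as in the calls above, and that inputs are shaped (batch, sequence length, d_model). The hyperparameter values are illustrative, not necessarily the repository's defaults.

    # usage sketch -- illustrative values, assumes transformer.SubLayers is on the import path
    import torch

    from transformer.Layers import EncoderLayer, DecoderLayer

    d_model, d_inner_hid, n_head, d_k, d_v = 512, 1024, 8, 64, 64
    enc_layer = EncoderLayer(d_model, d_inner_hid, n_head, d_k, d_v)
    dec_layer = DecoderLayer(d_model, d_inner_hid, n_head, d_k, d_v)

    src = torch.rand(2, 10, d_model)   # (batch, source length, d_model)
    tgt = torch.rand(2, 7, d_model)    # (batch, target length, d_model)

    # no masks passed here; in training, slf_attn_mask / dec_enc_attn_mask would be supplied
    enc_output, enc_slf_attn = enc_layer(src)
    dec_output, dec_slf_attn, dec_enc_attn = dec_layer(tgt, enc_output)
    print(dec_output.shape)            # expected: torch.Size([2, 7, 512])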