```python
''' Define the Layers '''
import torch.nn as nn
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward

__author__ = "Yu-Hsiang Huang"

class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        # Masked self-attention over the decoder input.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, attn_mask=slf_attn_mask)
        # Encoder-decoder attention: decoder states attend to the encoder output.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask)
        # Position-wise feed-forward sub-layer.
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn
```
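For orientation, below is a minimal smoke-test sketch of how one encoder layer and one decoder layer compose into a single encoder-decoder step. It assumes batch-first inputs of shape `(batch, seq_len, d_model)` and boolean attention masks of shape `(batch, query_len, key_len)` where `True` marks a blocked position; these conventions, and the hyperparameter values below, are assumptions about what `MultiHeadAttention` in `transformer.SubLayers` expects, not facts stated in this file.

```python
import torch

# Illustrative hyperparameters (assumed, not taken from the original file).
d_model, d_inner_hid, n_head, d_k, d_v = 512, 2048, 8, 64, 64
batch, src_len, tgt_len = 2, 10, 7

enc_layer = EncoderLayer(d_model, d_inner_hid, n_head, d_k, d_v)
dec_layer = DecoderLayer(d_model, d_inner_hid, n_head, d_k, d_v)

# Already-embedded source and target sequences (random stand-ins).
src = torch.randn(batch, src_len, d_model)
tgt = torch.randn(batch, tgt_len, d_model)

# Assumed mask convention: boolean, True = position is masked out.
src_mask = torch.zeros(batch, src_len, src_len, dtype=torch.bool)
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
tgt_mask = tgt_mask.unsqueeze(0).repeat(batch, 1, 1)   # causal (subsequent-position) mask
dec_enc_mask = torch.zeros(batch, tgt_len, src_len, dtype=torch.bool)

enc_output, enc_slf_attn = enc_layer(src, slf_attn_mask=src_mask)
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
    tgt, enc_output, slf_attn_mask=tgt_mask, dec_enc_attn_mask=dec_enc_mask)

print(enc_output.shape)  # expected: (batch, src_len, d_model)
print(dec_output.shape)  # expected: (batch, tgt_len, d_model)
```

In the full model, the decoder's `slf_attn_mask` would typically combine this kind of subsequent-position mask with a padding mask, while `dec_enc_attn_mask` would mask padded source positions; the all-false masks above simply let every position attend everywhere.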