```python
''' Define the sublayers in encoder/decoder layer '''

import torch
import torch.nn as nn
import torch.nn.init as init
from transformer.Modules import BottleLinear as Linear
from transformer.Modules import ScaledDotProductAttention
#from transformer.Modules import BottleLayerNormalization as LayerNormalization
from transformer.Modules import LayerNormalization

__author__ = "Yu-Hsiang Huang"

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # one d_model -> d_k (or d_v) projection matrix per head
        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)
        self.proj = Linear(n_head*d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        # in-place initializers (init.xavier_normal is deprecated)
        init.xavier_normal_(self.w_qs)
        init.xavier_normal_(self.w_ks)
        init.xavier_normal_(self.w_vs)

    def forward(self, q, k, v, attn_mask=None):

        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q

        mb_size, len_q, d_model = q.size()
        mb_size, len_k, d_model = k.size()
        mb_size, len_v, d_model = v.size()

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_k) x d_model
        v_s = v.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_v) x d_model

        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)   # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)   # (n_head*mb_size) x len_k x d_k
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)   # (n_head*mb_size) x len_v x d_v

        # perform attention, result size = (n_head * mb_size) x len_q x d_v
        # tile the mask across heads only when one is given (attn_mask defaults to None)
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(n_head, 1, 1)
        outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)

        # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

        # project back to residual size
        outputs = self.proj(outputs)
        outputs = self.dropout(outputs)

        # residual connection followed by layer normalization
        return self.layer_norm(outputs + residual), attns

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1) # position-wise
        self.w_2 = nn.Conv1d(d_inner_hid, d_hid, 1) # position-wise
        self.layer_norm = LayerNormalization(d_hid)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        # Conv1d expects (batch, channels, length), hence the transposes
        output = self.relu(self.w_1(x.transpose(1, 2)))
        output = self.w_2(output).transpose(2, 1)
        output = self.dropout(output)
        return self.layer_norm(output + residual)
```
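
The repeat/view/bmm sequence in `MultiHeadAttention.forward` is just shape bookkeeping: it applies each head's projection matrix to every (batch, position) vector at once. Below is a minimal sketch of an equivalent einsum formulation for the query projection, shown only as a cross-check; the sizes are made up and this is not how the module itself computes the result.

```python
import torch

n_head, mb_size, len_q, d_model, d_k = 8, 2, 10, 512, 64  # illustrative sizes
q = torch.randn(mb_size, len_q, d_model)
w_qs = torch.randn(n_head, d_model, d_k)

# the module's formulation: tile the batch across heads, then one bmm per head
q_rep = q.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_q) x d_model
q_s = torch.bmm(q_rep, w_qs).view(-1, len_q, d_k)         # (n_head*mb_size) x len_q x d_k

# the same projection written as a single einsum over batch, position and head
q_e = torch.einsum('bld,hdk->hblk', q, w_qs).reshape(-1, len_q, d_k)

print(torch.allclose(q_s, q_e, atol=1e-5))  # True: both produce head-major (n_head*mb_size) batches
```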
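A quick end-to-end smoke test of the two sublayers might look as follows. This is a sketch under assumptions: the hyperparameters are the usual base-model sizes and purely illustrative, the two classes defined above (and their `transformer.Modules` dependencies) are assumed to be importable, and only the output shapes are checked.

```python
import torch

n_head, d_model, d_k, d_v, d_inner_hid = 8, 512, 64, 64, 2048  # illustrative sizes
mb_size, seq_len = 2, 10

slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=0.1)
pos_ffn = PositionwiseFeedForward(d_model, d_inner_hid, dropout=0.1)

x = torch.randn(mb_size, seq_len, d_model)

# self-attention without a mask; an attn_mask of shape mb_size x seq_len x seq_len
# could be passed instead to hide padding or future positions
out, attns = slf_attn(x, x, x)
out = pos_ffn(out)

print(out.size())    # expected: mb_size x seq_len x d_model
print(attns.size())  # expected: (n_head*mb_size) x seq_len x seq_len
```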