
SubLayers.py

''' Define the sublayers in encoder/decoder layer '''
import torch
import torch.nn as nn
import torch.nn.init as init

from transformer.Modules import BottleLinear as Linear
from transformer.Modules import ScaledDotProductAttention
#from transformer.Modules import BottleLayerNormalization as LayerNormalization
from transformer.Modules import LayerNormalization

__author__ = "Yu-Hsiang Huang"


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # one d_model -> d_k (or d_v) projection matrix per head
        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)
        self.proj = Linear(n_head*d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        init.xavier_normal(self.w_qs)
        init.xavier_normal(self.w_ks)
        init.xavier_normal(self.w_vs)

    def forward(self, q, k, v, attn_mask=None):

        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q

        mb_size, len_q, d_model = q.size()
        mb_size, len_k, d_model = k.size()
        mb_size, len_v, d_model = v.size()

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_k) x d_model
        v_s = v.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_v) x d_model

        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)  # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)  # (n_head*mb_size) x len_k x d_k
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)  # (n_head*mb_size) x len_v x d_v

        # perform attention, result size = (n_head * mb_size) x len_q x d_v;
        # repeat the mask once per head, and skip it when no mask is given
        outputs, attns = self.attention(
            q_s, k_s, v_s,
            attn_mask=attn_mask.repeat(n_head, 1, 1) if attn_mask is not None else None)

        # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

        # project back to residual size
        outputs = self.proj(outputs)
        outputs = self.dropout(outputs)

        # residual connection followed by layer normalization
        return self.layer_norm(outputs + residual), attns
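
# A minimal usage sketch (illustrative only, assuming PyTorch >= 0.4 tensor
# semantics and the base Transformer setting d_model=512, n_head=8, d_k=d_v=64);
# the shapes follow the comments in forward() above:
#
#     mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
#     q = torch.rand(16, 20, 512)       # mb_size x len_q x d_model
#     enc = torch.rand(16, 30, 512)     # mb_size x len_k x d_model (keys = values)
#     out, attns = mha(q, enc, enc)     # out: 16 x 20 x 512, attns: (8*16) x 20 x 30
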
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_inner_hid, d_hid, 1)  # position-wise
        self.layer_norm = LayerNormalization(d_hid)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        # Conv1d expects (mb_size, channels, seq_len), so transpose in and out
        output = self.relu(self.w_1(x.transpose(1, 2)))
        output = self.w_2(output).transpose(2, 1)
        output = self.dropout(output)
        # residual connection followed by layer normalization
        return self.layer_norm(output + residual)
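
# A minimal usage sketch (illustrative only, assuming the base Transformer sizes
# d_hid=512, d_inner_hid=2048); the block is applied position-wise, so the
# output shape matches the input:
#
#     ffn = PositionwiseFeedForward(d_hid=512, d_inner_hid=2048)
#     x = torch.rand(16, 20, 512)       # mb_size x seq_len x d_hid
#     y = ffn(x)                        # 16 x 20 x 512, same shape as x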