
Modules.py (3.1 KB)

import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np

__author__ = "Yu-Hsiang Huang"


class Linear(nn.Module):
    ''' Simple Linear layer with xavier init '''

    def __init__(self, d_in, d_out, bias=True):
        super(Linear, self).__init__()
        self.linear = nn.Linear(d_in, d_out, bias=bias)
        init.xavier_normal(self.linear.weight)

    def forward(self, x):
        return self.linear(x)


class Bottle(nn.Module):
    ''' Perform the reshape routine before and after an operation '''

    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        # Flatten (batch, length) into one dimension, apply the wrapped
        # operation, then restore the original leading dimensions.
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0]*size[1], -1))
        return out.view(size[0], size[1], -1)


class BottleLinear(Bottle, Linear):
    ''' Perform the reshape routine before and after a linear projection '''
    pass


class BottleSoftmax(Bottle, nn.Softmax):
    ''' Perform the reshape routine before and after a softmax operation '''
    pass


class LayerNormalization(nn.Module):
    ''' Layer normalization module '''

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        if z.size(1) == 1:
            return z

        # Normalize over the last (hidden) dimension, then rescale and shift
        # with the learned gain a_2 and bias b_2.
        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)
        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)

        return ln_out


class BatchBottle(nn.Module):
    ''' Perform the reshape routine before and after an operation '''

    def forward(self, input):
        if len(input.size()) <= 2:
            return super(BatchBottle, self).forward(input)
        size = input.size()[1:]
        out = super(BatchBottle, self).forward(input.view(-1, size[0]*size[1]))
        return out.view(-1, size[0], size[1])


class BottleLayerNormalization(BatchBottle, LayerNormalization):
    ''' Perform the reshape routine before and after a layer normalization '''
    pass


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, d_model, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = np.power(d_model, 0.5)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = BottleSoftmax()

    def forward(self, q, k, v, attn_mask=None):

        # Scaled dot-product scores: (batch, len_q, len_k)
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper

        if attn_mask is not None:

            assert attn_mask.size() == attn.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), attn.size())

            # Masked positions get -inf so they vanish after the softmax.
            attn.data.masked_fill_(attn_mask, -float('inf'))

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
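
The Bottle mixin works through Python's method resolution order: BottleLinear(Bottle, Linear) flattens the first two dimensions of a 3-D input, runs the wrapped Linear, and restores the original shape, so a per-position projection can be applied to (batch, length, features) tensors. A minimal usage sketch follows; the shapes, the `from Modules import ...` path, and a PyTorch version with merged Tensor/Variable semantics (>= 0.4) are assumptions, not part of the original file.

import torch
from Modules import BottleLinear, LayerNormalization  # assumed local import of this file

batch, seq_len, d_in, d_out = 2, 5, 16, 32
x = torch.randn(batch, seq_len, d_in)

proj = BottleLinear(d_in, d_out)   # internally: view(batch*seq_len, d_in) -> Linear -> view back
norm = LayerNormalization(d_out)

y = norm(proj(x))
print(y.size())                    # torch.Size([2, 5, 32])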
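
ScaledDotProductAttention computes softmax(Q K^T / sqrt(d_model)) V, optionally filling masked positions with -inf so they receive zero attention weight. Below is a sketch of calling it with a mask; the shapes and values are illustrative only, and the boolean mask assumes PyTorch >= 1.2 (code from this file's era would pass a ByteTensor instead).

import torch
from Modules import ScaledDotProductAttention  # assumed local import of this file

batch, len_q, len_k, d_model = 2, 4, 6, 16
q = torch.randn(batch, len_q, d_model)
k = torch.randn(batch, len_k, d_model)
v = torch.randn(batch, len_k, d_model)

# Positions where the mask is True are filled with -inf before the softmax,
# so the corresponding keys receive zero attention weight.
attn_mask = torch.zeros(batch, len_q, len_k, dtype=torch.bool)
attn_mask[:, :, -1] = True         # e.g. hide the last key position from every query

attention = ScaledDotProductAttention(d_model)
output, attn = attention(q, k, v, attn_mask=attn_mask)
print(output.size(), attn.size())  # torch.Size([2, 4, 16]) torch.Size([2, 4, 6])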