
MyModel.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from IPython.core.debugger import set_trace

# we will use CUDA if it is available
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cpu')  # or torch.device('cuda:0') to run on the GPU
print("CUDA:", USE_CUDA)
print(DEVICE)

#seed = 42
#np.random.seed(seed)
#torch.manual_seed(seed)
#torch.cuda.manual_seed(seed)

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator

    def forward(self, src, trg, src_mask, trg_mask, src_lengths, trg_lengths):
        """Take in and process masked src and target sequences."""
        encoder_hidden, encoder_final = self.encode(src, src_mask, src_lengths)
        return self.decode(encoder_hidden, encoder_final, src_mask, trg, trg_mask)

    def encode(self, src, src_mask, src_lengths):
        return self.encoder(self.src_embed(src), src_mask, src_lengths)
        # return self.encoder(src, src_mask, src_lengths)

    def decode(self, encoder_hidden, encoder_final, src_mask, trg, trg_mask,
               decoder_hidden=None):
        return self.decoder(self.trg_embed(trg), encoder_hidden, encoder_final,
                            src_mask, trg_mask, hidden=decoder_hidden)

class Generator(nn.Module):
    """Define standard linear + softmax generation step."""
    def __init__(self, hidden_size, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

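# Illustrative usage sketch (not part of the original file): the Generator maps
# decoder pre-output vectors [batch, time, hidden] to per-token log-probabilities
# [batch, time, vocab]. The sizes below are made-up example values.
def _demo_generator():
    gen = Generator(hidden_size=8, vocab_size=11)
    x = torch.zeros(2, 5, 8)                       # [batch=2, time=5, hidden=8]
    log_probs = gen(x)                             # [2, 5, 11]
    assert torch.allclose(log_probs.exp().sum(-1), torch.ones(2, 5))
    return log_probs
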
class Encoder(nn.Module):
    """Encodes a sequence of word embeddings"""
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size, hidden_size, num_layers,
                          batch_first=True, bidirectional=True, dropout=dropout)

    def forward(self, x, mask, lengths):
        """
        Applies a bidirectional GRU to sequence of embeddings x.
        The input mini-batch x needs to be sorted by length.
        x should have dimensions [batch, time, dim].
        """
        packed = pack_padded_sequence(x, lengths, batch_first=True)
        output, final = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)

        # we need to manually concatenate the final states for both directions
        fwd_final = final[0:final.size(0):2]
        bwd_final = final[1:final.size(0):2]
        final = torch.cat([fwd_final, bwd_final], dim=2)  # [num_layers, batch, 2*dim]

        return output, final

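# Illustrative usage sketch (not part of the original file): running the Encoder
# on a toy batch of padded embeddings, sorted by length as the docstring requires.
# All sizes here are assumptions chosen for clarity.
def _demo_encoder():
    enc = Encoder(input_size=6, hidden_size=4, num_layers=1)
    x = torch.randn(2, 5, 6)          # [batch, time, emb], longest sequence first
    lengths = [5, 3]                  # descending, as pack_padded_sequence expects
    mask = (torch.arange(5)[None, :] < torch.tensor(lengths)[:, None]).unsqueeze(1)
    output, final = enc(x, mask, lengths)
    # output: [2, 5, 8] (forward and backward states concatenated)
    # final:  [1, 2, 8] (one entry per layer)
    return output, final
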
class Decoder(nn.Module):
    """A conditional RNN decoder with attention."""

    def __init__(self, emb_size, hidden_size, attention, num_layers=1, dropout=0.5,
                 bridge=True):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention
        self.dropout = dropout

        self.rnn = nn.GRU(emb_size + 2*hidden_size, hidden_size, num_layers,
                          batch_first=True, dropout=dropout)

        # to initialize from the final encoder state
        self.bridge = nn.Linear(2*hidden_size, hidden_size, bias=True) if bridge else None

        self.dropout_layer = nn.Dropout(p=dropout)
        self.pre_output_layer = nn.Linear(hidden_size + 2*hidden_size + emb_size,
                                          hidden_size, bias=False)

    def forward_step(self, prev_embed, encoder_hidden, src_mask, proj_key, hidden):
        """Perform a single decoder step (1 word)"""

        # compute context vector using attention mechanism
        query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]
        context, attn_probs = self.attention(
            query=query, proj_key=proj_key,
            value=encoder_hidden, mask=src_mask)

        # update rnn hidden state
        rnn_input = torch.cat([prev_embed, context], dim=2)
        output, hidden = self.rnn(rnn_input, hidden)

        pre_output = torch.cat([prev_embed, output, context], dim=2)
        pre_output = self.dropout_layer(pre_output)
        pre_output = self.pre_output_layer(pre_output)

        return output, hidden, pre_output

    def forward(self, trg_embed, encoder_hidden, encoder_final,
                src_mask, trg_mask, hidden=None, max_len=None):
        """Unroll the decoder one step at a time."""

        # the maximum number of steps to unroll the RNN
        if max_len is None:
            max_len = trg_mask.size(-1)

        # initialize decoder hidden state
        if hidden is None:
            hidden = self.init_hidden(encoder_final)

        # pre-compute projected encoder hidden states
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        proj_key = self.attention.key_layer(encoder_hidden)

        # here we store all intermediate hidden states and pre-output vectors
        decoder_states = []
        pre_output_vectors = []

        # unroll the decoder RNN for max_len steps
        for i in range(max_len):
            prev_embed = trg_embed[:, i].unsqueeze(1)
            output, hidden, pre_output = self.forward_step(
                prev_embed, encoder_hidden, src_mask, proj_key, hidden)
            decoder_states.append(output)
            pre_output_vectors.append(pre_output)

        decoder_states = torch.cat(decoder_states, dim=1)
        pre_output_vectors = torch.cat(pre_output_vectors, dim=1)
        return decoder_states, hidden, pre_output_vectors  # [B, N, D]

    def init_hidden(self, encoder_final):
        """Returns the initial decoder state,
        conditioned on the final encoder state."""

        if encoder_final is None:
            return None  # start with zeros

        return torch.tanh(self.bridge(encoder_final))

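# Illustrative sketch (not part of the original file): one full decoder pass over
# a toy target sequence, wired to matching toy "encoder" tensors. It relies on the
# BahdanauAttention class defined further below (the name is only resolved when
# the function is called), and all sizes are assumptions chosen for clarity.
def _demo_decoder():
    hidden_size, emb_size = 4, 6
    attention = BahdanauAttention(hidden_size)
    decoder = Decoder(emb_size, hidden_size, attention, num_layers=1, dropout=0.0)
    encoder_hidden = torch.randn(2, 5, 2 * hidden_size)   # [B, src_len, 2*hidden]
    encoder_final = torch.randn(1, 2, 2 * hidden_size)    # [num_layers, B, 2*hidden]
    src_mask = torch.ones(2, 1, 5, dtype=torch.long)      # no padding in this toy batch
    trg_embed = torch.randn(2, 3, emb_size)               # [B, trg_len, emb]
    trg_mask = torch.ones(2, 1, 3, dtype=torch.long)
    states, hidden, pre_output = decoder(trg_embed, encoder_hidden, encoder_final,
                                         src_mask, trg_mask)
    # states: [2, 3, 4], hidden: [1, 2, 4], pre_output: [2, 3, 4]
    return states, hidden, pre_output
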
class BahdanauAttention(nn.Module):
    """Implements Bahdanau (MLP) attention"""

    def __init__(self, hidden_size, key_size=None, query_size=None):
        super(BahdanauAttention, self).__init__()

        # We assume a bi-directional encoder so key_size is 2*hidden_size
        key_size = 2 * hidden_size if key_size is None else key_size
        query_size = hidden_size if query_size is None else query_size

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)

        # to store attention scores
        self.alphas = None
    def forward(self, query=None, proj_key=None, value=None, mask=None):
        assert mask is not None, "mask is required"

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        query = self.query_layer(query)

        # Calculate scores.
        scores = self.energy_layer(torch.tanh(query + proj_key))
        scores = scores.squeeze(2).unsqueeze(1)

        # Mask out invalid positions.
        # The mask marks valid positions, so we fill the positions where mask == 0.
        scores.data.masked_fill_(mask == 0, -float('inf'))

        # Turn scores to probabilities.
        alphas = F.softmax(scores, dim=-1)
        self.alphas = alphas

        # The context vector is the weighted sum of the values.
        context = torch.bmm(alphas, value)

        # context shape: [B, 1, 2D], alphas shape: [B, 1, M]
        return context, alphas
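
# Illustrative sketch (not part of the original file): a single attention step on
# toy tensors. Padded source positions are masked out, and the attention weights
# over each source sentence sum to 1. All sizes are assumptions.
def _demo_attention():
    hidden_size = 4
    attn = BahdanauAttention(hidden_size)
    value = torch.randn(2, 5, 2 * hidden_size)        # encoder states [B, M, 2D]
    proj_key = attn.key_layer(value)                  # pre-computed keys [B, M, D]
    query = torch.randn(2, 1, hidden_size)            # current decoder state [B, 1, D]
    mask = torch.tensor([[[1, 1, 1, 1, 1]],
                         [[1, 1, 1, 0, 0]]])          # second sentence has 2 pad tokens
    context, alphas = attn(query=query, proj_key=proj_key, value=value, mask=mask)
    assert torch.allclose(alphas.sum(-1), torch.ones(2, 1))
    assert alphas[1, 0, 3:].sum() == 0                # no weight on padded positions
    return context, alphas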