Browse Source

using GRU instead of LSTM

RNN
mohamad maheri 2 years ago
parent
commit
52f474670f
1 changed files with 5 additions and 2 deletions
  1. 5
    2
      models.py

+ 5
- 2
models.py View File

@@ -27,7 +27,9 @@ class MetaLearner(nn.Module):
# self.hidden_size = out_size
self.out_size = embed_size
self.hidden_size = embed_size
self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
# self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
self.rnn = nn.GRU(input_size=embed_size,hidden_size=self.hidden_size, num_layers=1)

# nn.init.xavier_normal_(self.rnn.all_weights)

def forward(self, inputs):
@@ -35,7 +37,8 @@ class MetaLearner(nn.Module):
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1)
x = x.transpose(0,1)

_,(x,c) = self.rnn(x)
# _,(x,c) = self.rnn(x)
x,c = self.rnn(x)
x = x[-1]
x = x.squeeze(0)


Loading…
Cancel
Save