Browse Source

using GRU instead of LSTM

RNN
mohamad maheri 2 years ago
parent
commit
52f474670f
1 changed files with 5 additions and 2 deletions
  1. 5
    2
      models.py

+ 5
- 2
models.py View File

# self.hidden_size = out_size # self.hidden_size = out_size
self.out_size = embed_size self.out_size = embed_size
self.hidden_size = embed_size self.hidden_size = embed_size
self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
# self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
self.rnn = nn.GRU(input_size=embed_size,hidden_size=self.hidden_size, num_layers=1)

# nn.init.xavier_normal_(self.rnn.all_weights) # nn.init.xavier_normal_(self.rnn.all_weights)


def forward(self, inputs): def forward(self, inputs):
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1)
x = x.transpose(0,1) x = x.transpose(0,1)


_,(x,c) = self.rnn(x)
# _,(x,c) = self.rnn(x)
x,c = self.rnn(x)
x = x[-1] x = x[-1]
x = x.squeeze(0) x = x.squeeze(0)



Loading…
Cancel
Save