         # self.hidden_size = out_size
         self.out_size = embed_size
         self.hidden_size = embed_size
-        self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
+        # self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2)
+        self.rnn = nn.GRU(input_size=embed_size,hidden_size=self.hidden_size, num_layers=1)
         # nn.init.xavier_normal_(self.rnn.all_weights)

     def forward(self, inputs):
         x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1)
         x = x.transpose(0,1)
-        _,(x,c) = self.rnn(x)
+        # _,(x,c) = self.rnn(x)
+        x,c = self.rnn(x)
         x = x[-1]
         x = x.squeeze(0)
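
The unpacking change follows from the two modules' return types: nn.LSTM returns (output, (h_n, c_n)), while nn.GRU has no cell state and returns (output, h_n). A minimal sketch of the difference, with illustrative sizes (the dimensions below are assumptions, not values from this model):

import torch
import torch.nn as nn

embed_size, hidden_size = 8, 8     # illustrative sizes, not the model's real ones
x = torch.randn(3, 4, embed_size)  # (seq_len=3, batch=4, features); batch_first=False

# LSTM: the second return value is a (hidden state, cell state) tuple.
lstm = nn.LSTM(embed_size, hidden_size, num_layers=2, dropout=0.2)
out, (h_n, c_n) = lstm(x)          # hence `_,(x,c) = self.rnn(x)` before the change

# GRU: no cell state, so the second return value is just the final hidden state.
gru = nn.GRU(input_size=embed_size, hidden_size=hidden_size, num_layers=1)
out, h_n = gru(x)                  # hence `x,c = self.rnn(x)` after the change

# For a single-layer, unidirectional GRU the last output step equals the
# final hidden state, so indexing the output with [-1] picks the same vector.
assert torch.allclose(out[-1], h_n[-1])

Two consequences of the new unpacking are worth noting: the variable named c now receives the final hidden state (a GRU has no cell state to assign to it), and after `x = x[-1]` the tensor already has shape (batch, hidden_size), so the following `x.squeeze(0)` only has an effect when the batch size is 1.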