| @@ -27,7 +27,9 @@ class MetaLearner(nn.Module): | |||
| # self.hidden_size = out_size | |||
| self.out_size = embed_size | |||
| self.hidden_size = embed_size | |||
| self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2) | |||
| # self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2) | |||
| self.rnn = nn.GRU(input_size=embed_size,hidden_size=self.hidden_size, num_layers=1) | |||
| # nn.init.xavier_normal_(self.rnn.all_weights) | |||
| def forward(self, inputs): | |||
| @@ -35,7 +37,8 @@ class MetaLearner(nn.Module): | |||
| x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) | |||
| x = x.transpose(0,1) | |||
| _,(x,c) = self.rnn(x) | |||
| # _,(x,c) = self.rnn(x) | |||
| x,c = self.rnn(x) | |||
| x = x[-1] | |||
| x = x.squeeze(0) | |||