|
|
@@ -27,7 +27,9 @@ class MetaLearner(nn.Module): |
|
|
|
# self.hidden_size = out_size |
|
|
|
self.out_size = embed_size |
|
|
|
self.hidden_size = embed_size |
|
|
|
self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2) |
|
|
|
# self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2) |
|
|
|
self.rnn = nn.GRU(input_size=embed_size,hidden_size=self.hidden_size, num_layers=1) |
|
|
|
|
|
|
|
# nn.init.xavier_normal_(self.rnn.all_weights) |
|
|
|
|
|
|
|
def forward(self, inputs): |
|
|
@@ -35,7 +37,8 @@ class MetaLearner(nn.Module): |
|
|
|
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) |
|
|
|
x = x.transpose(0,1) |
|
|
|
|
|
|
|
_,(x,c) = self.rnn(x) |
|
|
|
# _,(x,c) = self.rnn(x) |
|
|
|
x,c = self.rnn(x) |
|
|
|
x = x[-1] |
|
|
|
x = x.squeeze(0) |
|
|
|
|