@@ -29,15 +29,13 @@ class MetaLearner(nn.Module): | |||
def forward(self, inputs): | |||
size = inputs.shape | |||
x = inputs.contiguous().view(size[0]*2, size[2], size[3]) | |||
x = x.transpose(1,0) | |||
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) | |||
x = x.transpose(0,1) | |||
_,(x,c) = self.rnn(x) | |||
x = x[-1] | |||
x = x.squeeze(0) | |||
x = x.reshape(size[0],2,100) | |||
x = torch.mean(x, 1) | |||
return x.view(size[0], 1, 1, self.out_size) | |||