Browse Source

input all sequence of items to a LSTM (instead of averaging pairs)

RNN
mohamad maheri 2 years ago
parent
commit
9cbfd2aff6
1 changed files with 3 additions and 5 deletions
  1. 3
    5
      models.py

+ 3
- 5
models.py View File

@@ -29,15 +29,13 @@ class MetaLearner(nn.Module):

def forward(self, inputs):
size = inputs.shape
x = inputs.contiguous().view(size[0]*2, size[2], size[3])
x = x.transpose(1,0)
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1)
x = x.transpose(0,1)

_,(x,c) = self.rnn(x)
x = x[-1]

x = x.squeeze(0)
x = x.reshape(size[0],2,100)
x = torch.mean(x, 1)

return x.view(size[0], 1, 1, self.out_size)



Loading…
Cancel
Save