|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs): |
|
|
def forward(self, inputs): |
|
|
size = inputs.shape |
|
|
size = inputs.shape |
|
|
x = inputs.contiguous().view(size[0]*2, size[2], size[3]) |
|
|
|
|
|
x = x.transpose(1,0) |
|
|
|
|
|
|
|
|
x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) |
|
|
|
|
|
x = x.transpose(0,1) |
|
|
|
|
|
|
|
|
_,(x,c) = self.rnn(x) |
|
|
_,(x,c) = self.rnn(x) |
|
|
x = x[-1] |
|
|
x = x[-1] |
|
|
|
|
|
|
|
|
x = x.squeeze(0) |
|
|
x = x.squeeze(0) |
|
|
x = x.reshape(size[0],2,100) |
|
|
|
|
|
x = torch.mean(x, 1) |
|
|
|
|
|
|
|
|
|
|
|
return x.view(size[0], 1, 1, self.out_size) |
|
|
return x.view(size[0], 1, 1, self.out_size) |
|
|
|
|
|
|
|
|
|
|
|
|