| def forward(self, inputs): | def forward(self, inputs): | ||||
| size = inputs.shape | size = inputs.shape | ||||
| x = inputs.contiguous().view(size[0]*2, size[2], size[3]) | |||||
| x = x.transpose(1,0) | |||||
| x = torch.stack([inputs[:,0,0,:],inputs[:,0,1,:],inputs[:,1,1,:]],dim=1) | |||||
| x = x.transpose(0,1) | |||||
| _,(x,c) = self.rnn(x) | _,(x,c) = self.rnn(x) | ||||
| x = x[-1] | x = x[-1] | ||||
| x = x.squeeze(0) | x = x.squeeze(0) | ||||
| x = x.reshape(size[0],2,100) | |||||
| x = torch.mean(x, 1) | |||||
| return x.view(size[0], 1, 1, self.out_size) | return x.view(size[0], 1, 1, self.out_size) | ||||