123456789101112131415161718192021 |
- from torch import nn
- import torch
-
-
class Add(nn.Module):
    """Element-wise sum of a sequence of operands.

    ``forward`` takes a single argument holding the operands. With two
    operands the result is exactly ``torch.add(a, b)``. Longer sequences are
    folded left to right with ``torch.add``, so the usual broadcasting rules
    apply at every step.
    """

    def forward(self, inputs):
        """Return ``inputs[0] + inputs[1] + ... + inputs[-1]``.

        Args:
            inputs: sequence of at least two tensors (or numbers accepted by
                ``torch.add``).

        Returns:
            The element-wise sum of all operands.

        Raises:
            TypeError: if fewer than two operands are given. This mirrors the
                error the original two-argument ``torch.add(*inputs)`` raised.
        """
        # list() iterates exactly like the original ``*inputs`` unpacking did.
        operands = list(inputs)
        if len(operands) < 2:
            raise TypeError(
                f"Add expects at least 2 operands, got {len(operands)}"
            )
        result = torch.add(operands[0], operands[1])
        for operand in operands[2:]:
            result = torch.add(result, operand)
        return result
-
-
class Multiply(nn.Module):
    """Element-wise product of a sequence of operands.

    ``forward`` takes a single argument holding the operands. With two
    operands the result is exactly ``torch.mul(a, b)``. Longer sequences are
    folded left to right with ``torch.mul``, so the usual broadcasting rules
    apply at every step.
    """

    def forward(self, inputs):
        """Return ``inputs[0] * inputs[1] * ... * inputs[-1]``.

        Args:
            inputs: sequence of at least two tensors (or numbers accepted by
                ``torch.mul``).

        Returns:
            The element-wise product of all operands.

        Raises:
            TypeError: if fewer than two operands are given. This mirrors the
                error the original two-argument ``torch.mul(*inputs)`` raised.
        """
        # list() iterates exactly like the original ``*inputs`` unpacking did.
        operands = list(inputs)
        if len(operands) < 2:
            raise TypeError(
                f"Multiply expects at least 2 operands, got {len(operands)}"
            )
        result = torch.mul(operands[0], operands[1])
        for operand in operands[2:]:
            result = torch.mul(result, operand)
        return result
-
-
class Cat(nn.Module):
    """Concatenate a sequence of tensors along one dimension via ``torch.cat``.

    The dimension can be fixed at construction (``Cat(dim=1)``) or passed per
    call (``cat(inputs, 1)``). A dimension passed to ``forward`` is stored on
    ``self.dim``, as the original implementation did. It therefore becomes
    the default for later calls that omit ``dim``.
    """

    def __init__(self, dim=0):
        """Create the module.

        Args:
            dim: default concatenation dimension. 0 matches ``torch.cat``.
        """
        super().__init__()
        self.dim = dim

    def forward(self, inputs, dim=None):
        """Concatenate ``inputs`` along ``dim`` (or ``self.dim`` if omitted).

        Args:
            inputs: sequence of tensors accepted by ``torch.cat``.
            dim: dimension to concatenate along. ``None`` uses ``self.dim``.

        Returns:
            The concatenated tensor.
        """
        if dim is None:
            dim = self.dim
        # Keep the original side effect: remember the dimension used last.
        self.dim = dim
        return torch.cat(inputs, dim)
|