import torch
from torch import nn


class Add(nn.Module):
    """Module wrapper around :func:`torch.add`."""

    def forward(self, inputs) -> torch.Tensor:
        """Add the operands packed in ``inputs``.

        Args:
            inputs: A sequence unpacked positionally into ``torch.add``,
                normally ``(input, other)``.

        Returns:
            The result of ``torch.add(*inputs)``.
        """
        return torch.add(*inputs)


class Multiply(nn.Module):
    """Module wrapper around :func:`torch.mul`."""

    def forward(self, inputs) -> torch.Tensor:
        """Multiply the operands packed in ``inputs`` element-wise.

        Args:
            inputs: A sequence unpacked positionally into ``torch.mul``,
                normally ``(input, other)``.

        Returns:
            The result of ``torch.mul(*inputs)``.
        """
        return torch.mul(*inputs)


class Cat(nn.Module):
    """Module wrapper around :func:`torch.cat`.

    The ``dim`` passed to the most recent ``forward`` call is stored on the
    module as ``self.dim``.
    """

    def forward(self, inputs, dim: int = 0) -> torch.Tensor:
        """Concatenate ``inputs`` along ``dim``.

        Args:
            inputs: A sequence of tensors passed to ``torch.cat``.
            dim: Dimension to concatenate along. Defaults to 0, which is
                ``torch.cat``'s own default.

        Returns:
            The result of ``torch.cat(inputs, dim)``.
        """
        # Record the dim used for this call. Plain assignment routes through
        # nn.Module.__setattr__, exactly like the explicit call it replaces.
        self.dim = dim
        return torch.cat(inputs, dim)