You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

modules.py 355B

123456789101112131415161718192021
  1. from torch import nn
  2. import torch
  3. class Add(nn.Module):
  4. def forward(self, inputs):
  5. return torch.add(*inputs)
  6. class Multiply(nn.Module):
  7. def forward(self, inputs):
  8. return torch.mul(*inputs)
  9. class Cat(nn.Module):
  10. def forward(self, inputs, dim):
  11. self.__setattr__('dim', dim)
  12. return torch.cat(inputs, dim)