You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

activated_conv.py 548B

1 year ago
123456789101112131415
from __future__ import annotations

from torch import nn

from ._common_types import size_2_t
  3. class ActivatedConv2d(nn.Sequential):
  4. def __init__(self, in_channels: int, out_channels: int, kernel_size: size_2_t,
  5. stride: size_2_t = 1, padding: size_2_t = 0, bn: bool = False,
  6. activation: nn.Module = nn.ReLU):
  7. super().__init__(
  8. nn.BatchNorm2d(in_channels) if bn else nn.Identity(),
  9. nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
  10. activation(),
  11. )