|
123456789101112131415 |
from typing import Callable

from torch import nn

from ._common_types import size_2_t
-
-
class ActivatedConv2d(nn.Sequential):
    """Conv2d followed by an activation, optionally preceded by BatchNorm2d.

    The resulting ``nn.Sequential`` always holds exactly three children:

    0. ``nn.BatchNorm2d(in_channels)`` when ``bn`` is true, otherwise
       ``nn.Identity()``. Normalization is applied to the *input* of the
       convolution, not to its output.
    1. ``nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)``.
    2. ``activation()``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride. Defaults to 1.
        padding: Convolution padding. Defaults to 0.
        bn: If true, apply ``nn.BatchNorm2d`` over the input channels before
            the convolution.
        activation: Zero-argument factory called once to build the activation
            layer. Typically this is an ``nn.Module`` subclass such as
            ``nn.ReLU``. Pass the class itself, not an instance.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: size_2_t,
                 stride: size_2_t = 1, padding: size_2_t = 0, bn: bool = False,
                 activation: Callable[[], nn.Module] = nn.ReLU):
        # nn.Identity is a placeholder so child indices stay the same
        # whether or not batch norm is enabled.
        norm = nn.BatchNorm2d(in_channels) if bn else nn.Identity()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding)
        super().__init__(norm, conv, activation())
|