from typing import Callable, Union

from torch import nn

from ._common_types import size_2_t


class ActivatedConv2d(nn.Sequential):
    """A 2-D convolution followed by an activation, with optional input batch norm.

    The resulting ``nn.Sequential`` holds exactly three children, in order:

    0. ``nn.BatchNorm2d(in_channels)`` when ``bn`` is true, otherwise
       ``nn.Identity()``.  Note the normalisation is applied to the *input*
       of the convolution (it is sized by ``in_channels``), not its output.
    1. ``nn.Conv2d(in_channels, out_channels, ...)``.
    2. The activation module.

    Keeping the ``nn.Identity`` placeholder means the child indices (and thus
    ``state_dict`` keys of the convolution) are the same whether or not batch
    norm is enabled.

    Args:
        in_channels: Number of channels of the input; passed to both the
            optional ``BatchNorm2d`` and the ``Conv2d``.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Forwarded to ``nn.Conv2d`` as ``kernel_size``.
        stride: Forwarded to ``nn.Conv2d`` as ``stride``.
        padding: Forwarded to ``nn.Conv2d`` as ``padding``.
        bn: If true, normalise the input with ``nn.BatchNorm2d`` before the
            convolution.
        activation: Either a zero-argument callable returning a module
            (typically a module class such as ``nn.ReLU``, or a
            ``functools.partial``/lambda for parameterised activations), or
            an already-constructed ``nn.Module`` instance, which is used
            as-is.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: size_2_t,
        stride: size_2_t = 1,
        padding: size_2_t = 0,
        bn: bool = False,
        activation: Union[nn.Module, Callable[[], nn.Module]] = nn.ReLU,
    ) -> None:
        # Sub-modules are built in the same order as they are registered so
        # that parameter-initialisation RNG consumption is unchanged.
        norm = nn.BatchNorm2d(in_channels) if bn else nn.Identity()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Calling a Module instance would invoke forward() with no input and
        # fail, so instances are used directly; classes/factories are called.
        if isinstance(activation, nn.Module):
            act = activation
        else:
            act = activation()
        super().__init__(norm, conv, act)