from typing import Dict, Optional

import torch

from ..tv_resnet import BasicBlock, ResNet


class RSNAORGResNet18(ResNet):
    """ResNet-18-layout model that replicates input channels before the backbone.

    The backbone is the parent ``ResNet`` built from ``BasicBlock`` with the
    ``[2, 2, 2, 2]`` stage layout (the standard ResNet-18 depth) and
    ``binary=True``; see ``tv_resnet.ResNet`` for what ``binary`` controls.
    """

    def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Triple the channel dimension of ``x`` and delegate to ``ResNet.forward``.

        Args:
            x: Input batch with channels on dim 1. Each channel is repeated
                3 times consecutively (``repeat_interleave``), so the output
                has ``3 * C`` channels. NOTE(review): presumably ``x`` is
                single-channel (B, 1, H, W), yielding the 3-channel input the
                backbone expects; a 3-channel ``x`` would become 9 channels.
                TODO confirm against the data pipeline.
            y: Optional target tensor, forwarded unchanged to the parent.

        Returns:
            The dict produced by the parent ``ResNet.forward(x, y)``.
        """
        # Replicate grayscale input across the channel dim, e.g. B 1 224 224 -> B 3 224 224.
        x = x.repeat_interleave(3, dim=1)
        return super().forward(x, y)