123456789101112131415 |
from typing import Dict, Optional

import torch

from ..tv_resnet import BasicBlock, ResNet
-
-
class RSNAORGResNet18(ResNet):
    """ResNet-18 (``BasicBlock`` with layer counts ``[2, 2, 2, 2]``).

    Input channels are tripled in :meth:`forward` before the parent
    ``ResNet`` forward pass runs. The parent is constructed with
    ``binary=True``; NOTE(review): presumably this selects a
    binary-classification head — confirm against ``tv_resnet.ResNet``.
    """

    def __init__(self) -> None:
        super().__init__(BasicBlock, [2, 2, 2, 2], binary=True)

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Triple the channel axis of ``x`` and delegate to ``ResNet.forward``.

        Args:
            x: Input batch. Assumed to be single-channel ``(B, 1, H, W)``
                (TODO confirm with the data pipeline). Each channel is
                repeated 3 times along dim 1, so a C-channel input becomes
                3*C channels.
            y: Optional tensor passed through unchanged to
                ``ResNet.forward``; may be ``None``.

        Returns:
            The result of ``ResNet.forward(x, y)``, annotated as a mapping
            from names to tensors.
        """
        # Presumably the ResNet stem expects 3 input channels, so the
        # grayscale channel is replicated: B 1 H W -> B 3 H W.
        x = x.repeat_interleave(3, dim=1)  # B 3 224 224
        return super().forward(x, y)
|