from typing import Tuple

from torchlap.configs.imagenet_configs import ImagenetConfigs

from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.imagenet.lap_resnet import ImagenetLAPResNet50


class EntryPoint(BaseEntrypoint):
    """Entrypoint pairing an ``ImagenetConfigs`` with an ``ImagenetLAPResNet50``."""

    def _get_conf_model(self) -> Tuple[ImagenetConfigs, Model]:
        """Create the configuration and the model used by this entrypoint.

        Returns:
            A ``(config, model)`` tuple: an ``ImagenetConfigs`` built with the
            name ``'ImagenetFT'`` and this entrypoint's ``phase_type``, and an
            ``ImagenetLAPResNet50`` built with ``sigmoid_scale=0.1``.
        """
        # NOTE(review): the positional values 2 and 224 are passed straight
        # through; their meaning is defined by ImagenetConfigs -- confirm there.
        conf = ImagenetConfigs('ImagenetFT', 2, 224, self.phase_type)
        net = ImagenetLAPResNet50(sigmoid_scale=0.1)

        # Two negative lookaheads followed by a catch-all: the pattern matches
        # every name EXCEPT those starting with "layer4." or "fc.".
        # Presumably names matching a freezing regex are frozen, leaving only
        # layer4 and fc trainable -- verify against the config consumer.
        everything_but_layer4_and_fc = r'(?!^layer4\..*$)(?!^fc\..*$)(^.*$)'
        conf.freezing_regexes = [everything_but_layer4_and_fc]

        return conf, net