from .base_config import BaseConfig, PhaseType
from typing import Union

from torch import nn
from torchvision import transforms

from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag
from ..model_evaluation.binary_evaluator import BinaryEvaluator


class CelebAConfigs(BaseConfig):
    """Experiment configuration for the CelebA dataset.

    Forwards its arguments to ``BaseConfig`` together with ``CelebALoader``
    and ``BinaryEvaluator``, then sets CelebA-specific values for batching,
    tags, data locations, best-epoch selection and input augmentation.
    """

    def __init__(
        self,
        try_name: str,
        try_num: int,
        input_size,
        phase_type: PhaseType,
    ) -> None:
        """Build the configuration.

        Args:
            try_name: Forwarded unchanged to ``BaseConfig.__init__``.
            try_num: Forwarded unchanged to ``BaseConfig.__init__``.
            input_size: Forwarded unchanged to ``BaseConfig.__init__``.
            phase_type: Forwarded unchanged to ``BaseConfig.__init__``.
        """
        super().__init__(
            try_name,
            try_num,
            'default',
            input_size,
            phase_type,
            CelebALoader,
            BinaryEvaluator,
        )

        # Configuration values replaced for CelebA.
        self.batch_size = 64
        self.iters_per_epoch = 200

        # Collect the names of the CelebATag members whose name ends in 'Tag'.
        tag_names = []
        for member in CelebATag:
            if member.name.endswith('Tag'):
                tag_names.append(member.name)
        self.tags = tag_names
        self.main_tag: Union[CelebATag, None] = None

        # Dataset locations.
        self.data_root: str = 'data/celeba'
        self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

        # Best-epoch selection: reference metric title, comparison operator,
        # and whether only the best and last epochs are kept.
        self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # Augmentation: random transforms chained in one nn.Sequential and
        # registered under the 'x' key.
        image_augmentations = nn.Sequential(
            transforms.RandomRotation(45),
            transforms.RandomAffine(0, shear=0.2, scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            transforms.RandomPerspective(),
        )
        self.augmentations_dict = {'x': image_augmentations}