1234567891011121314151617181920212223242526272829303132333435 |
- from .base_config import BaseConfig, PhaseType
- from typing import Union
- from torch import nn
- from torchvision import transforms
- from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag
- from ..model_evaluation.binary_evaluator import BinaryEvaluator
-
class CelebAConfigs(BaseConfig):
    """Experiment configuration for the CelebA dataset.

    Passes ``CelebALoader`` (content loader) and ``BinaryEvaluator`` (model
    evaluator) to ``BaseConfig``. It then sets CelebA-specific options for
    batching, tags, data paths, best-epoch selection and augmentation.

    Args:
        try_name: Forwarded to ``BaseConfig`` unchanged.
        try_num: Forwarded to ``BaseConfig`` unchanged.
        input_size: Forwarded to ``BaseConfig`` unchanged.
        phase_type: Forwarded to ``BaseConfig`` unchanged.
    """

    def __init__(self,
                 try_name: str, try_num: int, input_size, phase_type: PhaseType) -> None:
        # NOTE(review): the meaning of the literal 'default' positional argument
        # is defined by BaseConfig, which is not visible here -- confirm there.
        super().__init__(
            try_name, try_num, 'default', input_size, phase_type,
            CelebALoader, BinaryEvaluator,
        )

        # --- Values that replace the ones BaseConfig set up ---

        # Training-loop sizing.
        self.batch_size = 64
        self.iters_per_epoch = 200

        # Names (strings) of every CelebATag member whose name ends in 'Tag'.
        member_names = (member.name for member in CelebATag)
        self.tags = [name for name in member_names if name.endswith('Tag')]
        # No main tag is selected by default.
        self.main_tag: Union[CelebATag, None] = None

        # Dataset locations, given as relative paths.
        self.data_root: str = 'data/celeba'
        self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

        # Best-epoch selection. The validation metric titled 'AvgSS' is the
        # reference, and '>=' is the comparison operator. Presumably this means
        # a value equal to the best so far also counts as an improvement --
        # confirm against BaseConfig / the trainer.
        self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
        self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
        self.keep_best_and_last_epochs_only = True

        # Augmentation pipeline, applied in list order. The dict key 'x' is
        # presumably the model-input entry -- verify against the data loader.
        image_augmentations = [
            transforms.RandomRotation(45),
            # NOTE(review): torchvision's RandomAffine `shear` is in degrees,
            # so 0.2 gives only a very slight shear -- confirm this is intended.
            transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5),
            transforms.RandomPerspective(),
        ]
        self.augmentations_dict = {
            'x': nn.Sequential(*image_augmentations),
        }
|