You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

celeba_configs.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435
  1. from .base_config import BaseConfig, PhaseType
  2. from typing import Union
  3. from torch import nn
  4. from torchvision import transforms
  5. from ..data.content_loaders.celeba_loader import CelebALoader, CelebATag
  6. from ..model_evaluation.binary_evaluator import BinaryEvaluator
  class CelebAConfigs(BaseConfig):
      """Experiment configuration for training on the CelebA dataset.

      Plugs ``CelebALoader`` and ``BinaryEvaluator`` into ``BaseConfig`` and
      then overrides a set of training, data, model-selection and
      augmentation attributes.
      """

      def __init__(
          self, try_name: str, try_num: int, input_size, phase_type: PhaseType,
      ) -> None:
          """Build the CelebA configuration.

          Args:
              try_name: Run name, forwarded to ``BaseConfig``.
              try_num: Run number, forwarded to ``BaseConfig``.
              input_size: Forwarded unchanged to ``BaseConfig``; its expected
                  type is defined there (not visible here).
              phase_type: Phase of the run, forwarded to ``BaseConfig``.
          """
          # The third positional argument 'default' is forwarded as-is; its
          # meaning is defined by BaseConfig (not visible here) -- confirm.
          super().__init__(
              try_name, try_num, 'default', input_size, phase_type,
              CelebALoader, BinaryEvaluator,
          )

          # Values assigned after BaseConfig.__init__, presumably replacing
          # defaults it sets.
          self.batch_size = 64
          self.iters_per_epoch = 200

          # Names of every CelebATag member whose name ends in 'Tag',
          # in enum declaration order.
          member_names = (member.name for member in CelebATag)
          self.tags = [name for name in member_names if name.endswith('Tag')]
          self.main_tag: Union[CelebATag, None] = None

          # Dataset locations.
          self.data_root: str = 'data/celeba'
          self.dataset_metadata: str = 'dataset_metadata/celeba.tsv'

          # Model selection, per the attribute names: epochs are compared on
          # the 'AvgSS' metric using '>='. Only the best and last epochs are
          # kept. The consuming logic lives outside this file.
          self.title_of_reference_metric_to_choose_best_epoch = 'AvgSS'
          self.operator_to_decide_on_improvement_of_val_reference_metric = '>='
          self.keep_best_and_last_epochs_only = True

          # Augmentation pipeline stored under the 'x' key. It is composed as
          # an nn.Sequential of torchvision transforms, so it is presumably
          # applied to image tensors -- confirm against the caller.
          # NOTE(review): RandomAffine's `shear` is given in degrees, so 0.2
          # yields an almost imperceptible shear -- confirm this is intended.
          x_augmenter = nn.Sequential(
              transforms.RandomRotation(45),
              transforms.RandomAffine(0, shear=0.2, scale=(.8, 1.2)),
              transforms.RandomHorizontalFlip(),
              transforms.ColorJitter(brightness=.5, contrast=.5, saturation=.5),
              transforms.RandomPerspective(),
          )
          self.augmentations_dict = {'x': x_augmenter}