123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- from __future__ import annotations
-
- from dataclasses import dataclass
-
-
@dataclass()
class CommonParams:
    """Run-level settings shared across the application.

    Bundles run identification (model / project / run names), output and
    checkpoint locations, and logging cadence. The first four fields have
    no default and must be supplied; every other field is optional.
    """

    # --- identification (required) ---
    model_name: str
    project_name: str
    run_name: str
    # Output location; the "dpath" suffix suggests a directory path — confirm.
    save_dpath: str

    # --- checkpointing ---
    # Save interval; whether it counts steps or epochs is not visible here — confirm.
    save_every: int = 1_000
    # NOTE(review): annotated as optional but defaults to "" rather than None,
    # so code reading this field presumably has to treat both "" and None as
    # "nothing to load" — confirm against callers.
    load_path: str | None = ""

    # --- logging ---
    log_dir: str = "./runs"
    log_every: int = 10
    # Separate (coarser, by default) cadence for console output.
    log_console_every: int = 100

    # --- mode ---
    # Presumably skips training and only evaluates — confirm against the runner.
    evaluate_only: bool = False
-
-
@dataclass()
class TrainingParams:
    """Hyper-parameters controlling a training run.

    Every field carries a default, so ``TrainingParams()`` on its own is a
    complete configuration; override individual fields by keyword.
    """

    # --- optimiser / learning-rate schedule ---
    lr: float = 1e-3
    # Presumably the lower bound the LR schedule decays toward — confirm.
    min_lr: float = 1e-4
    weight_decay: float = 0.0
    # NOTE(review): t0 / lr_step / gamma look like scheduler parameters
    # (restart period, step size, decay factor). With gamma=1.0 a
    # multiplicative decay would be a no-op — confirm against the scheduler.
    t0: int = 50
    lr_step: int = 10
    gamma: float = 1.0

    # --- data loading & reproducibility ---
    batch_size: int = 128
    num_workers: int = 0
    seed: int = 42

    # --- training length & evaluation cadence ---
    # NOTE(review): both a step budget and an epoch budget are defined; which
    # one actually bounds training is not visible here — confirm.
    num_train_steps: int = 100
    num_epochs: int = 100
    eval_every: int = 5

    # --- validation ---
    validation_sample_size: int = 1_000
    validation_batch_size: int = 128
|