Repository structure:

- src/
- utils/
- visualizations/
- .gitignore
- README.md
- requirements-conda.txt
- requirements.txt
Install the requirements using conda from requirements-conda.txt (or using pip from requirements.txt).

Tested with conda 22.11.1 and Python 3.8.

If you have trouble installing torch, check this link to find the correct version for your device.
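A typical setup might look like the following sketch (the environment name is arbitrary; the Python version matches the tested setup above):

```shell
# Create and activate a fresh environment (the name "da-env" is illustrative)
conda create -n da-env python=3.8
conda activate da-env

# Option A: install pinned packages with conda
conda install --file requirements-conda.txt

# Option B: install with pip instead
pip install -r requirements.txt
```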
The main code is located in the src directory. Run main.py with the appropriate arguments. Examples:
```
python main.py --model_names MMD CORAL --batch_size 32
python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True
```
See “Available arguments” in the “Code Structure” section for the full list of arguments.
The main code for the project is located in the src/ directory.
main.py: The entry point of the program
Available arguments:
max_epochs: Maximum number of epochs to run the training
patience: Maximum number of epochs to continue if no improvement is seen (Early stopping parameter)
batch_size: Batch size used for training
num_workers: Number of workers used for data loading
trials_count: Number of trials to run each of the tasks
initial_trial: The number to start indexing the trials from
download: Whether to download the dataset or not
root: Path to the project root
data_root: Path to the data root
results_root: Path to the directory to store the results
model_names: Names of models to run separated by space - available options: DANN, CDAN, MMD, MCD, CORAL, SOURCE
lr: Learning rate**
gamma: Gamma value for ExponentialLR**
hp_tune: Set to true if you want to run for different hyperparameters; used for hyperparameter tuning
source: The source domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam
target: The target domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam
vishook_frequency: Number of epochs to wait before saving a visualization
source_checkpoint_base_dir: Path to source-only trained model directory to use as base, set None to not use source-trained model***
source_checkpoint_trial_number: Trial number of the source-only trained model to use
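Since gamma feeds ExponentialLR, the effective learning rate after each epoch is simply lr · gamma^epoch. A quick sketch of the decay (the values below are illustrative, not the project's defaults):

```python
def decayed_lr(initial_lr: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` steps of ExponentialLR-style decay."""
    return initial_lr * gamma ** epoch

# Illustrative values only
print(decayed_lr(1e-3, 0.9, 0))   # 0.001
print(decayed_lr(1e-3, 0.9, 10))  # ~0.000349
```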
models.py: Contains models for adaptation
train.py: Contains base training iteration
classifier_adapter.py: Contains ClassifierAdapter class which is used for training a source-only model without adaptation
load_source.py: Load source-only trained model to use as base model for adaptation
source.py: Contains source model
train_source.py: Contains base source-only training iteration
utils.py: Contains utility classes
vis_hook.py: Contains VizHook class which is used for visualization
** You can also set different lr and gamma values for different models and tasks by changing hp_map in main.py directly.
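The exact shape of hp_map is defined in src/main.py; as a rough illustration, it could map (model, source, target) tasks to per-task values along these lines (the keys and values below are hypothetical):

```python
# Hypothetical sketch of an hp_map keyed by (model_name, source, target);
# the actual structure in src/main.py may differ.
hp_map = {
    ("DANN", "amazon", "webcam"): {"lr": 1e-4, "gamma": 0.99},
    ("MMD", "dslr", "amazon"): {"lr": 5e-5, "gamma": 0.95},
}

def get_hparams(model_name, source, target,
                default_lr=1e-4, default_gamma=0.99):
    """Look up per-task hyperparameters, falling back to defaults."""
    hp = hp_map.get((model_name, source, target), {})
    return hp.get("lr", default_lr), hp.get("gamma", default_gamma)
```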
*** To perform domain adaptation on a source-trained model, you must first train the model on the source domain using the option --model_names SOURCE.
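Concretely, the two-step workflow might look like the following sketch (the checkpoint path and trial number are placeholders; use the values matching your results_root):

```shell
# Step 1: source-only training on the amazon domain
python main.py --model_names SOURCE --source amazon --target webcam

# Step 2: domain adaptation starting from that source-only checkpoint
python main.py --model_names DANN --source amazon --target webcam \
  --source_checkpoint_base_dir <path-to-source-checkpoints> \
  --source_checkpoint_trial_number 0
```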