Saba Hashemi 0f008a29aa Add project | 1 year ago
Install the requirements using conda from `requirements-conda.txt` (or using pip from `requirements.txt`).
My setting: conda 22.11.1 with Python 3.8
If you have trouble installing torch, check this link to find the correct version for your device.
From the `src` directory, run `main.py` with appropriate arguments. Examples:
python main.py --model_names MMD CORAL --batch_size 32
python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True
See “Available arguments” in the “Code Structure” section for the full list of arguments.
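As a rough illustration of how command lines like the examples above could be parsed, here is a minimal `argparse` sketch. The argument names and the model and domain options are taken from this README. The types, the default values, the `str2bool` helper, and the `expand_tasks` function (including the rule that a domain is never paired with itself) are assumptions made for illustration, and the actual `main.py` may differ.

```python
import argparse

DOMAINS = ["amazon", "dslr", "webcam"]
MODELS = ["DANN", "CDAN", "MMD", "MCD", "CORAL", "SOURCE"]


def str2bool(value: str) -> bool:
    """Parse 'True'/'False' style strings (bool('False') would be True)."""
    return value.lower() in ("true", "1", "yes")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: names come from the README, defaults are assumed.
    parser = argparse.ArgumentParser(description="Sketch of a main.py-style CLI")
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--patience", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--trials_count", type=int, default=1)
    parser.add_argument("--initial_trial", type=int, default=0)
    # nargs="+" accepts several space-separated model names.
    parser.add_argument("--model_names", nargs="+", choices=MODELS, default=["DANN"])
    parser.add_argument("--source", choices=DOMAINS, default=None)
    parser.add_argument("--target", choices=DOMAINS, default=None)
    parser.add_argument("--hp_tune", type=str2bool, default=False)
    return parser


def expand_tasks(source, target):
    """List (source, target) pairs; an unspecified side ranges over all domains."""
    sources = [source] if source else DOMAINS
    targets = [target] if target else DOMAINS
    # Assumption: a domain is never adapted to itself.
    return [(s, t) for s in sources for t in targets if s != t]


if __name__ == "__main__":
    args = build_parser().parse_args(
        "--model_names MMD CORAL --batch_size 32".split()
    )
    print(args.model_names, args.batch_size)
    print(expand_tasks(args.source, args.target))
```

Under these assumptions, leaving out both `--source` and `--target` expands to every ordered pair of distinct domains, while giving both restricts the run to a single task.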
The main code for the project is located in the `src/` directory.
- `main.py`: The entry point of the program. Available arguments:
  - `max_epochs`: Maximum number of epochs to run the training
  - `patience`: Maximum number of epochs to continue if no improvement is seen (early stopping parameter)
  - `batch_size`
  - `num_workers`
  - `trials_count`: Number of trials to run for each task
  - `initial_trial`: The number to start indexing the trials from
  - `download`: Whether to download the dataset or not
  - `root`: Path to the root of the project
  - `data_root`: Path to the data root
  - `results_root`: Path to the directory to store the results
  - `model_names`: Names of the models to run, separated by spaces (available options: DANN, CDAN, MMD, MCD, CORAL, SOURCE)
  - `lr`: Learning rate**
  - `gamma`: Gamma value for ExponentialLR**
  - `hp_tune`: Set to true if you want to run with different hyperparameters; used for hyperparameter tuning
  - `source`: The source domain to run the training for; if not specified, training runs for all the available domains (available options: amazon, dslr, webcam)
  - `target`: The target domain to run the training for; if not specified, training runs for all the available domains (available options: amazon, dslr, webcam)
  - `vishook_frequency`: Number of epochs to wait before saving a visualization
  - `source_checkpoint_base_dir`: Path to the source-only trained model directory to use as a base; set to `None` to not use a source-trained model***
  - `source_checkpoint_trial_number`: Trial number of the source-only trained model to use
- `models.py`: Contains models for adaptation
- `train.py`: Contains base training iteration
- `classifier_adapter.py`: Contains ClassifierAdapter class which is used for training a source-only model without adaptation
- `load_source.py`: Load source-only trained model to use as base model for adaptation
- `source.py`: Contains source model
- `train_source.py`: Contains base source-only training iteration
- `utils.py`: Contains utility classes
- `vis_hook.py`: Contains VizHook class which is used for visualization
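The `max_epochs` and `patience` arguments listed above together describe early stopping. The following is a minimal sketch of that logic, not the project's actual training loop: the function name `train_with_early_stopping` is hypothetical, and the list `val_scores` stands in for the validation metric a real loop would compute after each epoch.

```python
def train_with_early_stopping(val_scores, max_epochs, patience):
    """Return (epochs_run, best_score) for a sequence of per-epoch scores.

    val_scores[i] is a stand-in for the validation score after epoch i;
    a real training loop would compute it instead of reading a list.
    """
    best_score = float("-inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch in range(min(max_epochs, len(val_scores))):
        epochs_run = epoch + 1
        score = val_scores[epoch]
        if score > best_score:
            # Improvement: remember it and reset the patience counter.
            best_score = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # No improvement for `patience` consecutive epochs: stop.
                break
    return epochs_run, best_score
```

With `patience=3`, for example, training stops after the third consecutive epoch that fails to beat the best score seen so far, or after `max_epochs` epochs, whichever comes first.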
** You can also set different lr and gamma values for different models and tasks by changing `hp_map` in `main.py` directly.
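To make the idea of per-model and per-task hyperparameters concrete, here is a small sketch. The layout of `hp_map` below is hypothetical (the real structure in `main.py` may be organized differently), the values are placeholders rather than tuned settings, and `get_hparams` and `lr_at_epoch` are illustrative helpers. The decay formula itself, `lr * gamma ** epoch`, is the schedule that an exponential learning-rate scheduler such as PyTorch's `ExponentialLR` applies when stepped once per epoch.

```python
# Hypothetical layout: model name -> task (source, target) or "default".
# All numbers are placeholders, not recommended settings.
hp_map = {
    "DANN": {
        "default": {"lr": 1e-3, "gamma": 0.99},
        ("amazon", "webcam"): {"lr": 5e-4, "gamma": 0.95},
    },
    "CORAL": {
        "default": {"lr": 1e-3, "gamma": 0.99},
    },
}


def get_hparams(model_name, source, target):
    """Look up task-specific hyperparameters, falling back to the model default."""
    per_model = hp_map[model_name]
    return per_model.get((source, target), per_model["default"])


def lr_at_epoch(lr, gamma, epoch):
    """Learning rate after `epoch` exponential decay steps: lr * gamma ** epoch."""
    return lr * gamma ** epoch


if __name__ == "__main__":
    hp = get_hparams("DANN", "amazon", "webcam")
    for epoch in range(3):
        print(epoch, lr_at_epoch(hp["lr"], hp["gamma"], epoch))
```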
*** To perform domain adaptation on a source-trained model, you must first train the model on the source domain using the option `--model_names SOURCE`.
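The two-step workflow described in this note can be sketched as a pair of command lines, built here in Python. The flag names come from the argument list in this README. The helper name `source_then_adapt_commands`, the choice to pass `--source` and `--target` in both steps, and the checkpoint path `path/to/source_results` are assumptions for illustration only, since the layout of the results directory is not described here.

```python
import shlex


def source_then_adapt_commands(source, target, checkpoint_dir, trial=0, model="DANN"):
    """Build two main.py invocations: source-only training, then adaptation."""
    # Step 1: train the source-only model.
    train_source = [
        "python", "main.py",
        "--model_names", "SOURCE",
        "--source", source,
        "--target", target,
    ]
    # Step 2: adapt, starting from the source-only checkpoint.
    adapt = [
        "python", "main.py",
        "--model_names", model,
        "--source", source,
        "--target", target,
        "--source_checkpoint_base_dir", checkpoint_dir,
        "--source_checkpoint_trial_number", str(trial),
    ]
    return shlex.join(train_source), shlex.join(adapt)


if __name__ == "__main__":
    # "path/to/source_results" is a placeholder, not a real project path.
    for command in source_then_adapt_commands("amazon", "webcam", "path/to/source_results"):
        print(command)
```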