Saba Hashemi 301d418879 Update 'README.md' | 1 year ago | |
This project contains implementations of several domain adaptation methods on the Office31 dataset.
Available methods: DANN, CDAN, MCD, CORAL, MMD.
The project can easily be extended to other datasets or to other adaptation methods. (See the Code Structure section to find out what you need to change.)
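The README does not show how the individual methods are implemented. As an illustration of what one of these adaptation objectives computes, here is a minimal, dependency-free sketch of the CORAL loss as formulated in Deep CORAL: the squared Frobenius distance between the source and target feature covariances, scaled by 1/(4d²). The function names are hypothetical, and the project's own version presumably works on PyTorch tensors inside the training loop rather than on Python lists.

```python
from typing import List

Matrix = List[List[float]]  # n x d: one row per sample, one column per feature


def covariance(features: Matrix) -> Matrix:
    """Unbiased d x d covariance of an n x d feature matrix (needs n >= 2)."""
    n, d = len(features), len(features[0])
    means = [sum(row[j] for row in features) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in features]
    return [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]


def coral_loss(source: Matrix, target: Matrix) -> float:
    """Hypothetical helper: squared Frobenius distance between the source and
    target covariances, divided by 4 * d^2 (the Deep CORAL scaling)."""
    d = len(source[0])
    cs, ct = covariance(source), covariance(target)
    frob_sq = sum((cs[i][j] - ct[i][j]) ** 2 for i in range(d) for j in range(d))
    return frob_sq / (4 * d * d)


# Identical feature distributions give zero loss; a scaled target does not.
print(coral_loss([[0.0, 0.0], [2.0, 2.0]], [[0.0, 0.0], [1.0, 1.0]]))
```

During adaptation a loss like this would be added, with some weight, to the source classification loss so that the feature extractor learns covariance-aligned features for both domains.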
Install the requirements using conda from requirements-conda.txt (or using pip from requirements.txt).
My setting: conda 22.11.1 with Python 3.8
If you have trouble installing torch, check this link to find the correct version for your device.
To run the code, go to the src directory and run main.py with the appropriate arguments. Examples:
python main.py --model_names MMD CORAL --batch_size 32
python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True
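A sketch of how arguments like these could be parsed with argparse, and how leaving source or target unset could expand into every source/target pair. This is an assumption about main.py, not its actual code: only a subset of the documented arguments is declared, the defaults are guesses, and build_parser, str2bool, and tasks are hypothetical names.

```python
import argparse
from typing import List, Optional, Tuple

DOMAINS = ["amazon", "dslr", "webcam"]  # Office31 domains listed in this README


def str2bool(value: str) -> bool:
    # argparse's type=bool would turn any non-empty string (even "False") into True,
    # so the text is parsed explicitly.
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of the "Available arguments"; all defaults are guesses.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_epochs", type=int, default=60)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--trials_count", type=int, default=3)
    parser.add_argument("--model_names", nargs="+", default=["DANN"],
                        choices=["DANN", "CDAN", "MMD", "MCD", "CORAL", "SOURCE"])
    parser.add_argument("--source", choices=DOMAINS, default=None)
    parser.add_argument("--target", choices=DOMAINS, default=None)
    parser.add_argument("--hp_tune", type=str2bool, default=False)
    return parser


def tasks(source: Optional[str], target: Optional[str]) -> List[Tuple[str, str]]:
    # Assumption: an unset domain means "all domains", and a domain is never
    # adapted to itself.
    sources = [source] if source else DOMAINS
    targets = [target] if target else DOMAINS
    return [(s, t) for s in sources for t in targets if s != t]


args = build_parser().parse_args(
    "--max_epochs 10 --patience 3 --trials_count 1 --model_names DANN "
    "--num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True".split()
)
print(args.model_names, args.hp_tune, tasks(args.source, args.target))
```

Under this assumption, running with neither source nor target set covers the 6 ordered pairs of the three Office31 domains.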
See “Available arguments” in the “Code Structure” section for the full list of arguments.
The main code for the project is located in the src/ directory.
main.py: The entry point of the program
Available arguments:
max_epochs: Maximum number of epochs to run the training
patience: Maximum number of epochs to continue training when no improvement is seen (early stopping parameter)
batch_size
num_workers
trials_count: Number of trials to run for each task
initial_trial: The number to start indexing the trials from
download: Whether to download the dataset
root: Path to the root of the project
data_root: Path to the data root
results_root: Path to the directory where the results are stored
model_names: Names of the models to run, separated by spaces; available options: DANN, CDAN, MMD, MCD, CORAL, SOURCE
lr: Learning rate**
gamma: Gamma value for ExponentialLR**
hp_tune: Set to True to run with different hyperparameter values (used for hyperparameter tuning)
source: The source domain to run the training for; if not specified, training runs for all available domains. Available options: amazon, dslr, webcam
target: The target domain to run the training for; if not specified, training runs for all available domains. Available options: amazon, dslr, webcam
vishook_frequency: Number of epochs to wait before saving a visualization
source_checkpoint_base_dir: Path to the directory of a source-only trained model to use as the base model; set to None to not use a source-trained model***
source_checkpoint_trial_number: Trial number of the source-only trained model to use
models.py: Contains the models used for adaptation
train.py: Contains the base training iteration; the dataset is also loaded here
classifier_adapter.py: Contains the ClassifierAdapter class, which is used to train a source-only model without adaptation
load_source.py: Loads a source-only trained model to use as the base model for adaptation
source.py: Contains the source model
train_source.py: Contains the base source-only training iteration
utils.py: Contains utility classes
vis_hook.py: Contains the VizHook class, which is used for visualization
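The max_epochs, patience, lr, and gamma arguments suggest a training loop with early stopping and an exponentially decaying learning rate. A minimal sketch under that assumption: train_with_early_stopping and evaluate are hypothetical names, where evaluate stands in for one epoch of training plus validation, and the decay lr * gamma**epoch is the schedule that torch.optim.lr_scheduler.ExponentialLR produces when stepped once per epoch.

```python
from typing import Callable, List, Tuple


def train_with_early_stopping(
    evaluate: Callable[[int, float], float],
    max_epochs: int,
    patience: int,
    lr: float,
    gamma: float,
) -> Tuple[int, float, List[float]]:
    """Hypothetical skeleton. evaluate(epoch, lr) runs one epoch and returns a
    validation score (higher is better). Returns the best epoch, the best score,
    and the learning rate used at each epoch that actually ran."""
    best_epoch, best_score = -1, float("-inf")
    epochs_without_improvement = 0
    lrs: List[float] = []
    for epoch in range(max_epochs):
        current_lr = lr * gamma ** epoch  # exponential decay, one step per epoch
        lrs.append(current_lr)
        score = evaluate(epoch, current_lr)
        if score > best_score:
            best_epoch, best_score = epoch, score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping: no improvement for `patience` epochs
    return best_epoch, best_score, lrs


# Made-up validation scores, only to exercise the control flow.
scores = [0.1, 0.5, 0.4, 0.45, 0.3, 0.9]
print(train_with_early_stopping(lambda e, _lr: scores[e], 10, 3, 0.01, 0.5))
```

With patience 3, the loop stops after epoch 4 (three epochs without beating epoch 1), so the later 0.9 is never reached; in practice the checkpoint from the best epoch would be the one kept.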
** You can also set different lr and gamma values for different models and tasks by editing hp_map in main.py directly.
*** To perform domain adaptation on a source-trained model, you must first train a source-only model using the option --model_names SOURCE.
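The README does not document how hp_map is laid out. A minimal sketch assuming a nested mapping from model name to task to lr/gamma overrides, falling back to defaults, with a fixed grid swept when hp_tune is enabled. HP_MAP, DEFAULT_HP, lookup_hp, hp_candidates, the "source->target" task key format, and every numeric value are hypothetical placeholders, not tuned results from this project.

```python
import itertools
from typing import Dict, List, Sequence

# Hypothetical defaults and per-(model, task) overrides; values are placeholders.
DEFAULT_HP: Dict[str, float] = {"lr": 1e-3, "gamma": 0.99}
HP_MAP: Dict[str, Dict[str, Dict[str, float]]] = {
    "DANN": {"amazon->webcam": {"lr": 5e-4, "gamma": 0.95}},
    "CORAL": {},
}


def lookup_hp(model: str, source: str, target: str) -> Dict[str, float]:
    # Task-specific overrides win; anything missing falls back to DEFAULT_HP.
    overrides = HP_MAP.get(model, {}).get(f"{source}->{target}", {})
    return {**DEFAULT_HP, **overrides}


def hp_candidates(
    model: str,
    source: str,
    target: str,
    hp_tune: bool,
    lr_grid: Sequence[float] = (1e-2, 1e-3, 1e-4),
    gamma_grid: Sequence[float] = (0.9, 0.99),
) -> List[Dict[str, float]]:
    # With hp_tune, try every (lr, gamma) combination; otherwise use the mapped values.
    if hp_tune:
        return [{"lr": lr, "gamma": g} for lr, g in itertools.product(lr_grid, gamma_grid)]
    return [lookup_hp(model, source, target)]


print(hp_candidates("DANN", "amazon", "webcam", hp_tune=False))
```

Keeping the defaults separate from the overrides means only the tasks that need different values have to appear in the map.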