# Domain-Adaptation

This project contains code for several domain adaptation methods on the Office31 dataset.

Available methods: DANN, CDAN, MCD, CORAL, MMD.

The project can easily be extended to other datasets or other adaptation methods (see Code Structure for where to make changes).
## Prepare Environment

Install the requirements using conda from `requirements-conda.txt` (or using pip from `requirements.txt`).

*My setting:* conda 22.11.1 with Python 3.8

If you have trouble installing PyTorch, check [this link](https://pytorch.org/get-started/previous-versions/) to find the correct version for your device.
## Run

1. Go to the `src` directory.
2. Run `main.py` with the appropriate arguments.

Examples:

- Perform MMD and CORAL adaptation on all 6 domain adaptation tasks of the Office31 dataset:
```bash
python main.py --model_names MMD CORAL --batch_size 32
```
- Tune the hyperparameters of the DANN model on the amazon to webcam task:
```bash
python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True
```

See **Available arguments** in the Code Structure section for the full list of arguments.
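
Office31 has three domains, so the 6 tasks in the first example are the ordered (source, target) pairs of distinct domains. The sketch below enumerates them and shows how fixing `--source` or `--target` narrows the set. `select_tasks` is a hypothetical helper, and treating an unset `source` or `target` as "all domains" is an assumption based on the argument descriptions, not the project's actual code.

```python
from itertools import permutations
from typing import List, Optional, Tuple

# The three Office31 domains listed in this README.
DOMAINS = ["amazon", "dslr", "webcam"]


def select_tasks(source: Optional[str] = None,
                 target: Optional[str] = None) -> List[Tuple[str, str]]:
    """Return the (source, target) pairs to train on.

    Hypothetical helper (assumption): an unset source or target means
    "all domains", and same-domain pairs are never produced.
    """
    tasks = []
    # permutations(..., 2) yields ordered pairs with src != tgt: 3 * 2 = 6 tasks.
    for src, tgt in permutations(DOMAINS, 2):
        if source is not None and src != source:
            continue
        if target is not None and tgt != target:
            continue
        tasks.append((src, tgt))
    return tasks


if __name__ == "__main__":
    for src, tgt in select_tasks():
        print(f"{src} -> {tgt}")
```

Calling `select_tasks("amazon", "webcam")` leaves only the single task used in the DANN tuning example.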
## Code Structure

The main code of the project is in the `src/` directory.

- `main.py`: Entry point of the program
  - **Available arguments:**
    - `max_epochs`: Maximum number of training epochs
    - `patience`: Number of epochs without improvement to wait before stopping training (early stopping)
    - `batch_size`
    - `num_workers`
    - `trials_count`: Number of trials to run for each task
    - `initial_trial`: Index to start numbering the trials from
    - `download`: Whether to download the dataset
    - `root`: Path to the project root
    - `data_root`: Path to the data root
    - `results_root`: Path to the directory where results are stored
    - `model_names`: Space-separated names of the models to run. Options: DANN, CDAN, MMD, MCD, CORAL, SOURCE
    - `lr`: Learning rate**
    - `gamma`: Gamma value for ExponentialLR**
    - `hp_tune`: Set to `True` to run with different hyperparameter values (used for hyperparameter tuning)
    - `source`: Source domain to train on; if not specified, training runs for all available domains. Options: amazon, dslr, webcam
    - `target`: Target domain to train on; if not specified, training runs for all available domains. Options: amazon, dslr, webcam
    - `vishook_frequency`: Number of epochs between saved visualizations
    - `source_checkpoint_base_dir`: Path to the directory of the source-only trained model to use as the base; set to `None` to not use a source-trained model***
    - `source_checkpoint_trial_number`: Trial number of the source-only trained model to use
- `models.py`: Contains the models for adaptation
- `train.py`: Contains the base training iteration; the dataset is also loaded here
- `classifier_adapter.py`: Contains the ClassifierAdapter class, used to train a source-only model without adaptation
- `load_source.py`: Loads a source-only trained model to use as the base model for adaptation
- `source.py`: Contains the source model
- `train_source.py`: Contains the base source-only training iteration
- `utils.py`: Contains utility classes
- `vis_hook.py`: Contains the VizHook class, used for visualization

** You can also set different learning rates and gamma values for different models and tasks by editing `hp_map` in `main.py` directly.

*** To perform domain adaptation on a source-trained model, you must first train a source-only model with `--model_names SOURCE`.
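
The way `max_epochs`, `patience`, `lr`, and `gamma` interact can be sketched as follows. This is a minimal illustration, not the project's training loop: `run_schedule` and its `evaluate` callback are hypothetical names, a higher validation score is assumed to be better, and the learning rate follows the rule of PyTorch's `ExponentialLR` (multiplied by `gamma` once per epoch), assuming the scheduler is stepped once per epoch. Per-model values such as those in `hp_map` would presumably be passed in as `lr` and `gamma`.

```python
from typing import Callable, List, Tuple


def run_schedule(evaluate: Callable[[int, float], float],
                 max_epochs: int,
                 patience: int,
                 lr: float,
                 gamma: float) -> Tuple[int, float, List[float]]:
    """Early stopping with an ExponentialLR-style learning-rate decay.

    `evaluate(epoch, current_lr)` is a hypothetical stand-in for one training
    epoch followed by validation; it returns a score where higher is better
    (assumption). Returns (best_epoch, best_score, learning rate per epoch run).
    """
    best_score = float("-inf")
    best_epoch = -1
    epochs_without_improvement = 0
    lrs: List[float] = []
    for epoch in range(max_epochs):
        # ExponentialLR rule: lr_epoch = lr * gamma ** epoch
        current_lr = lr * gamma ** epoch
        lrs.append(current_lr)
        score = evaluate(epoch, current_lr)
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` consecutive epochs
    return best_epoch, best_score, lrs
```

For example, with `patience=3` and validation scores 0.5, 0.6, 0.65, 0.64, 0.63, 0.62, the best epoch is 2 (0-based) and the loop stops after epoch 5, before `max_epochs` is reached.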
## Acknowledgements

[PyTorch Adapt](https://github.com/KevinMusgrave/pytorch-adapt/tree/0b0fb63b04c9bd7e2cc6cf45314c7ee9d6e391c0)