| @@ -0,0 +1,9 @@ | |||
| datasets | |||
| results | |||
| model-weights | |||
| other | |||
| saved_models | |||
| __pycache__ | |||
| .vscode | |||
| venv | |||
| @@ -0,0 +1,100 @@ | |||
| # Domain-Adaptation | |||
| ## Prepare Environment | |||
| Install the requirements using conda from `requirements-conda.txt` (or using pip from `requirements.txt`). | |||
| *My setting:* conda 22.11.1 with Python 3.8 | |||
| If you have trouble with installing torch check [this link](https://pytorch.org/get-started/previous-versions/) | |||
| to find the currect version for your device. | |||
| ## Run | |||
| 1. Go to `src` directory. | |||
| 2. Run `main.py` with appropriate arguments. | |||
| Examples: | |||
| - Perform MMD and CORAL adaptation on all 6 domain adaptations tasks from office31 dataset: | |||
| ```bash | |||
| python main.py --model_names MMD CORAL --batch_size 32 | |||
| ``` | |||
| - Tuning parameters of DANN model | |||
| ```bash | |||
| python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True | |||
| ``` | |||
| Check "Available arguments" in "Code Structure" section for all the available arguments. | |||
| ## Code Structure | |||
| The main code for project is located in the `src/` directory. | |||
| - `main.py`: The entry to program | |||
| - **Available arguments:** | |||
| - `max_epochs`: Maximum number of epochs to run the training | |||
| - `patience`: Maximum number of epochs to continue if no improvement is seen (Early stopping parameter) | |||
| - `batch_size` | |||
| - `num_workers` | |||
| - `trials_count`: Number of trials to run each of the tasks | |||
| - `initial_trial`: The number to start indexing the trials from | |||
| - `download`: Whether to download the dataset or not | |||
| - `root`: Path to the root of project | |||
| - `data_root`: Path to the data root | |||
| - `results_root`: Path to the directory to store the results | |||
| - `model_names`: Names of models to run separated by space - available options: DANN, CDAN, MMD, MCD, CORAL, SOURCE | |||
| - `lr`: learning rate** | |||
| - `gamma`: Gamma value for ExponentialLR** | |||
| - `hp_tune`: Set true of you want to run for different hyperparameters, used for hyperparameter tuning | |||
| - `source`: The source domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam | |||
| - `target`: The target domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam | |||
| - `vishook_frequency`: Number of epochs to wait before save a visualization | |||
| - `source_checkpoint_base_dir`: Path to source-only trained model directory to use as base, set `None` to not use source-trained model*** | |||
| - `source_checkpoint_trial_number`: Trail number of source-only trained model to use | |||
| - `models.py`: Contains models for adaptation | |||
| - `train.py`: Contains base training iteration | |||
| - `classifier_adapter.py`: Contains ClassifierAdapter class which is used for training a source-only model without adaptation | |||
| - `load_source.py`: Load source-only trained model to use as base model for adaptation | |||
| - `source.py`: Contains source model | |||
| - `train_source.py`: Contains base source-only training iteration | |||
| - `utils.py`: Contains utility classes | |||
| - `vis_hook.py`: Contains VizHook class which is used for visualization | |||
| ** Use can also set different lr and gammas for different models and tasks by changing `hp_map` in `main.py` directly. | |||
| *** For perfoming domain adaptation on source-trained model, one must should train the model for source using option `--model_name SOURCE` first | |||
| ## Acknowledgements | |||
| [Pytorch Adapt](https://github.com/KevinMusgrave/pytorch-adapt/tree/0b0fb63b04c9bd7e2cc6cf45314c7ee9d6e391c0) | |||
| @@ -0,0 +1,211 @@ | |||
| # This file may be used to create an environment using: | |||
| # $ conda create --name <env> --file <this file> | |||
| # platform: linux-64 | |||
| _libgcc_mutex=0.1=main | |||
| _openmp_mutex=5.1=1_gnu | |||
| _pytorch_select=0.1=cpu_0 | |||
| anyio=3.5.0=py38h06a4308_0 | |||
| argon2-cffi=21.3.0=pyhd3eb1b0_0 | |||
| argon2-cffi-bindings=21.2.0=py38h7f8727e_0 | |||
| asttokens=2.0.5=pyhd3eb1b0_0 | |||
| attrs=22.1.0=py38h06a4308_0 | |||
| babel=2.11.0=py38h06a4308_0 | |||
| backcall=0.2.0=pyhd3eb1b0_0 | |||
| beautifulsoup4=4.11.1=py38h06a4308_0 | |||
| blas=1.0=mkl | |||
| bleach=4.1.0=pyhd3eb1b0_0 | |||
| brotli=1.0.9=h5eee18b_7 | |||
| brotli-bin=1.0.9=h5eee18b_7 | |||
| brotlipy=0.7.0=py38h27cfd23_1003 | |||
| ca-certificates=2022.10.11=h06a4308_0 | |||
| certifi=2022.12.7=py38h06a4308_0 | |||
| cffi=1.15.1=py38h5eee18b_3 | |||
| charset-normalizer=2.0.4=pyhd3eb1b0_0 | |||
| comm=0.1.2=py38h06a4308_0 | |||
| contourpy=1.0.5=py38hdb19cb5_0 | |||
| cryptography=38.0.1=py38h9ce1e76_0 | |||
| cudatoolkit=10.1.243=h8cb64d8_10 | |||
| cudnn=7.6.5.32=hc0a50b0_1 | |||
| cycler=0.11.0=pyhd3eb1b0_0 | |||
| dbus=1.13.18=hb2f20db_0 | |||
| debugpy=1.5.1=py38h295c915_0 | |||
| decorator=5.1.1=pyhd3eb1b0_0 | |||
| defusedxml=0.7.1=pyhd3eb1b0_0 | |||
| entrypoints=0.4=py38h06a4308_0 | |||
| executing=0.8.3=pyhd3eb1b0_0 | |||
| expat=2.4.9=h6a678d5_0 | |||
| filelock=3.9.0=pypi_0 | |||
| flit-core=3.6.0=pyhd3eb1b0_0 | |||
| fontconfig=2.14.1=h52c9d5c_1 | |||
| fonttools=4.25.0=pyhd3eb1b0_0 | |||
| freetype=2.12.1=h4a9f257_0 | |||
| gdown=4.6.0=pypi_0 | |||
| giflib=5.2.1=h7b6447c_0 | |||
| glib=2.69.1=he621ea3_2 | |||
| gst-plugins-base=1.14.0=h8213a91_2 | |||
| gstreamer=1.14.0=h28cd5cc_2 | |||
| h5py=3.2.1=pypi_0 | |||
| icu=58.2=he6710b0_3 | |||
| idna=3.4=py38h06a4308_0 | |||
| importlib-metadata=4.11.3=py38h06a4308_0 | |||
| importlib_resources=5.2.0=pyhd3eb1b0_1 | |||
| intel-openmp=2021.4.0=h06a4308_3561 | |||
| ipykernel=6.19.2=py38hb070fc8_0 | |||
| ipython=8.7.0=py38h06a4308_0 | |||
| ipython_genutils=0.2.0=pyhd3eb1b0_1 | |||
| ipywidgets=7.6.5=pyhd3eb1b0_1 | |||
| jedi=0.18.1=py38h06a4308_1 | |||
| jinja2=3.1.2=py38h06a4308_0 | |||
| joblib=1.2.0=pypi_0 | |||
| jpeg=9e=h7f8727e_0 | |||
| json5=0.9.6=pyhd3eb1b0_0 | |||
| jsonschema=4.16.0=py38h06a4308_0 | |||
| jupyter=1.0.0=py38h06a4308_8 | |||
| jupyter_client=7.4.8=py38h06a4308_0 | |||
| jupyter_console=6.4.4=py38h06a4308_0 | |||
| jupyter_core=5.1.1=py38h06a4308_0 | |||
| jupyter_server=1.23.4=py38h06a4308_0 | |||
| jupyterlab=3.5.2=py38h06a4308_0 | |||
| jupyterlab_pygments=0.1.2=py_0 | |||
| jupyterlab_server=2.16.5=py38h06a4308_0 | |||
| jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 | |||
| kiwisolver=1.4.4=py38h6a678d5_0 | |||
| krb5=1.19.2=hac12032_0 | |||
| lcms2=2.12=h3be6417_0 | |||
| ld_impl_linux-64=2.38=h1181459_1 | |||
| lerc=3.0=h295c915_0 | |||
| libbrotlicommon=1.0.9=h5eee18b_7 | |||
| libbrotlidec=1.0.9=h5eee18b_7 | |||
| libbrotlienc=1.0.9=h5eee18b_7 | |||
| libclang=10.0.1=default_hb85057a_2 | |||
| libdeflate=1.8=h7f8727e_5 | |||
| libedit=3.1.20221030=h5eee18b_0 | |||
| libevent=2.1.12=h8f2d780_0 | |||
| libffi=3.4.2=h6a678d5_6 | |||
| libgcc-ng=11.2.0=h1234567_1 | |||
| libgomp=11.2.0=h1234567_1 | |||
| libllvm10=10.0.1=hbcb73fb_5 | |||
| libpng=1.6.37=hbc83047_0 | |||
| libpq=12.9=h16c4e8d_3 | |||
| libsodium=1.0.18=h7b6447c_0 | |||
| libstdcxx-ng=11.2.0=h1234567_1 | |||
| libtiff=4.4.0=hecacb30_2 | |||
| libuuid=1.41.5=h5eee18b_0 | |||
| libuv=1.40.0=h7b6447c_0 | |||
| libwebp=1.2.4=h11a3e52_0 | |||
| libwebp-base=1.2.4=h5eee18b_0 | |||
| libxcb=1.15=h7f8727e_0 | |||
| libxkbcommon=1.0.1=hfa300c1_0 | |||
| libxml2=2.9.14=h74e7548_0 | |||
| libxslt=1.1.35=h4e12654_0 | |||
| llvmlite=0.39.1=pypi_0 | |||
| lxml=4.9.1=py38h1edc446_0 | |||
| lz4-c=1.9.4=h6a678d5_0 | |||
| markupsafe=2.1.1=py38h7f8727e_0 | |||
| matplotlib=3.4.2=pypi_0 | |||
| matplotlib-inline=0.1.6=py38h06a4308_0 | |||
| mistune=0.8.4=py38h7b6447c_1000 | |||
| mkl=2021.4.0=h06a4308_640 | |||
| mkl-service=2.4.0=py38h7f8727e_0 | |||
| mkl_fft=1.3.1=py38hd3c417c_0 | |||
| mkl_random=1.2.2=py38h51133e4_0 | |||
| munkres=1.1.4=py_0 | |||
| nbclassic=0.4.8=py38h06a4308_0 | |||
| nbclient=0.5.13=py38h06a4308_0 | |||
| nbconvert=6.5.4=py38h06a4308_0 | |||
| nbformat=5.7.0=py38h06a4308_0 | |||
| ncurses=6.3=h5eee18b_3 | |||
| nest-asyncio=1.5.6=py38h06a4308_0 | |||
| ninja=1.10.2=h06a4308_5 | |||
| ninja-base=1.10.2=hd09550d_5 | |||
| notebook=6.5.2=py38h06a4308_0 | |||
| notebook-shim=0.2.2=py38h06a4308_0 | |||
| nspr=4.33=h295c915_0 | |||
| nss=3.74=h0370c37_0 | |||
| numba=0.56.4=pypi_0 | |||
| numpy=1.22.4=pypi_0 | |||
| opencv-python=4.5.2.54=pypi_0 | |||
| openssl=1.1.1s=h7f8727e_0 | |||
| packaging=22.0=py38h06a4308_0 | |||
| pandas=1.5.2=pypi_0 | |||
| pandocfilters=1.5.0=pyhd3eb1b0_0 | |||
| parso=0.8.3=pyhd3eb1b0_0 | |||
| pcre=8.45=h295c915_0 | |||
| pexpect=4.8.0=pyhd3eb1b0_3 | |||
| pickleshare=0.7.5=pyhd3eb1b0_1003 | |||
| pillow=8.4.0=pypi_0 | |||
| pip=22.3.1=py38h06a4308_0 | |||
| pkgutil-resolve-name=1.3.10=py38h06a4308_0 | |||
| platformdirs=2.5.2=py38h06a4308_0 | |||
| ply=3.11=py38_0 | |||
| prometheus_client=0.14.1=py38h06a4308_0 | |||
| prompt-toolkit=3.0.36=py38h06a4308_0 | |||
| prompt_toolkit=3.0.36=hd3eb1b0_0 | |||
| psutil=5.9.0=py38h5eee18b_0 | |||
| ptyprocess=0.7.0=pyhd3eb1b0_2 | |||
| pure_eval=0.2.2=pyhd3eb1b0_0 | |||
| pycparser=2.21=pyhd3eb1b0_0 | |||
| pygments=2.11.2=pyhd3eb1b0_0 | |||
| pynndescent=0.5.8=pypi_0 | |||
| pyopenssl=22.0.0=pyhd3eb1b0_0 | |||
| pyparsing=3.0.9=py38h06a4308_0 | |||
| pyqt=5.15.7=py38h6a678d5_1 | |||
| pyqt5-sip=12.11.0=py38h6a678d5_1 | |||
| pyrsistent=0.18.0=py38heee7806_0 | |||
| pysocks=1.7.1=py38h06a4308_0 | |||
| python=3.8.15=h7a1cb2a_2 | |||
| python-dateutil=2.8.2=pyhd3eb1b0_0 | |||
| python-fastjsonschema=2.16.2=py38h06a4308_0 | |||
| pytorch-adapt=0.0.82=pypi_0 | |||
| pytorch-ignite=0.4.9=pypi_0 | |||
| pytorch-metric-learning=1.6.3=pypi_0 | |||
| pytz=2022.7=py38h06a4308_0 | |||
| pyyaml=6.0=pypi_0 | |||
| pyzmq=23.2.0=py38h6a678d5_0 | |||
| qt-main=5.15.2=h327a75a_7 | |||
| qt-webengine=5.15.9=hd2b0992_4 | |||
| qtconsole=5.3.2=py38h06a4308_0 | |||
| qtpy=2.2.0=py38h06a4308_0 | |||
| qtwebkit=5.212=h4eab89a_4 | |||
| readline=8.2=h5eee18b_0 | |||
| requests=2.28.1=py38h06a4308_0 | |||
| scikit-learn=1.0=pypi_0 | |||
| scipy=1.6.3=pypi_0 | |||
| seaborn=0.12.2=pypi_0 | |||
| send2trash=1.8.0=pyhd3eb1b0_1 | |||
| setuptools=65.5.0=py38h06a4308_0 | |||
| sip=6.6.2=py38h6a678d5_0 | |||
| six=1.16.0=pyhd3eb1b0_1 | |||
| sniffio=1.2.0=py38h06a4308_1 | |||
| soupsieve=2.3.2.post1=py38h06a4308_0 | |||
| sqlite=3.40.0=h5082296_0 | |||
| stack_data=0.2.0=pyhd3eb1b0_0 | |||
| terminado=0.17.1=py38h06a4308_0 | |||
| threadpoolctl=3.1.0=pypi_0 | |||
| timm=0.4.9=pypi_0 | |||
| tinycss2=1.2.1=py38h06a4308_0 | |||
| tk=8.6.12=h1ccaba5_0 | |||
| toml=0.10.2=pyhd3eb1b0_0 | |||
| tomli=2.0.1=py38h06a4308_0 | |||
| torch=1.8.1+cu101=pypi_0 | |||
| torchaudio=0.8.1=pypi_0 | |||
| torchmetrics=0.9.3=pypi_0 | |||
| torchvision=0.9.1+cu101=pypi_0 | |||
| tornado=6.2=py38h5eee18b_0 | |||
| tqdm=4.61.2=pypi_0 | |||
| traitlets=5.7.1=py38h06a4308_0 | |||
| typing-extensions=4.4.0=py38h06a4308_0 | |||
| typing_extensions=4.4.0=py38h06a4308_0 | |||
| umap-learn=0.5.3=pypi_0 | |||
| urllib3=1.26.13=py38h06a4308_0 | |||
| wcwidth=0.2.5=pyhd3eb1b0_0 | |||
| webencodings=0.5.1=py38_1 | |||
| websocket-client=0.58.0=py38h06a4308_4 | |||
| wheel=0.37.1=pyhd3eb1b0_0 | |||
| widgetsnbextension=3.5.2=py38h06a4308_0 | |||
| xz=5.2.8=h5eee18b_0 | |||
| yacs=0.1.8=pypi_0 | |||
| zeromq=4.3.4=h2531618_0 | |||
| zipp=3.11.0=py38h06a4308_0 | |||
| zlib=1.2.13=h5eee18b_0 | |||
| zstd=1.5.2=ha4553b6_0 | |||
| @@ -0,0 +1,134 @@ | |||
| anyio @ file:///tmp/build/80754af9/anyio_1644481698350/work/dist | |||
| argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work | |||
| argon2-cffi-bindings @ file:///tmp/build/80754af9/argon2-cffi-bindings_1644569684262/work | |||
| asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work | |||
| attrs @ file:///croot/attrs_1668696182826/work | |||
| Babel @ file:///croot/babel_1671781930836/work | |||
| backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work | |||
| beautifulsoup4 @ file:///opt/conda/conda-bld/beautifulsoup4_1650462163268/work | |||
| bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work | |||
| brotlipy==0.7.0 | |||
| certifi @ file:///croot/certifi_1671487769961/work/certifi | |||
| cffi @ file:///croot/cffi_1670423208954/work | |||
| charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work | |||
| comm @ file:///croot/comm_1671231121260/work | |||
| contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work | |||
| cryptography @ file:///croot/cryptography_1665612644927/work | |||
| cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work | |||
| debugpy @ file:///tmp/build/80754af9/debugpy_1637091796427/work | |||
| decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work | |||
| defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work | |||
| entrypoints @ file:///tmp/build/80754af9/entrypoints_1649926445639/work | |||
| executing @ file:///opt/conda/conda-bld/executing_1646925071911/work | |||
| fastjsonschema @ file:///opt/conda/conda-bld/python-fastjsonschema_1661371079312/work | |||
| filelock==3.9.0 | |||
| flit_core @ file:///opt/conda/conda-bld/flit-core_1644941570762/work/source/flit_core | |||
| fonttools==4.25.0 | |||
| gdown==4.6.0 | |||
| h5py==3.2.1 | |||
| idna @ file:///croot/idna_1666125576474/work | |||
| importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1648562408398/work | |||
| importlib-resources @ file:///tmp/build/80754af9/importlib_resources_1625135880749/work | |||
| ipykernel @ file:///croot/ipykernel_1671488378391/work | |||
| ipython @ file:///croot/ipython_1670919316550/work | |||
| ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work | |||
| ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1634143127070/work | |||
| jedi @ file:///tmp/build/80754af9/jedi_1644315233700/work | |||
| Jinja2 @ file:///croot/jinja2_1666908132255/work | |||
| joblib==1.2.0 | |||
| json5 @ file:///tmp/build/80754af9/json5_1624432770122/work | |||
| jsonschema @ file:///opt/conda/conda-bld/jsonschema_1663375472438/work | |||
| jupyter @ file:///tmp/abs_33h4eoipez/croots/recipe/jupyter_1659349046347/work | |||
| jupyter-console @ file:///croot/jupyter_console_1671541909316/work | |||
| jupyter-server @ file:///croot/jupyter_server_1671707632269/work | |||
| jupyter_client @ file:///croot/jupyter_client_1671703053786/work | |||
| jupyter_core @ file:///croot/jupyter_core_1672332224593/work | |||
| jupyterlab @ file:///croot/jupyterlab_1672132689850/work | |||
| jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work | |||
| jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work | |||
| jupyterlab_server @ file:///croot/jupyterlab_server_1672127357747/work | |||
| kiwisolver @ file:///croot/kiwisolver_1672387140495/work | |||
| llvmlite==0.39.1 | |||
| lxml @ file:///opt/conda/conda-bld/lxml_1657545139709/work | |||
| MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work | |||
| matplotlib==3.4.2 | |||
| matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work | |||
| mistune==0.8.4 | |||
| mkl-fft==1.3.1 | |||
| mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work | |||
| mkl-service==2.4.0 | |||
| munkres==1.1.4 | |||
| nbclassic @ file:///croot/nbclassic_1668174957779/work | |||
| nbclient @ file:///tmp/build/80754af9/nbclient_1650308366712/work | |||
| nbconvert @ file:///croot/nbconvert_1668450669124/work | |||
| nbformat @ file:///croot/nbformat_1670352325207/work | |||
| nest-asyncio @ file:///croot/nest-asyncio_1672387112409/work | |||
| notebook @ file:///croot/notebook_1668179881751/work | |||
| notebook_shim @ file:///croot/notebook-shim_1668160579331/work | |||
| numba==0.56.4 | |||
| numpy==1.22.4 | |||
| opencv-python==4.5.2.54 | |||
| packaging @ file:///croot/packaging_1671697413597/work | |||
| pandas==1.5.2 | |||
| pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work | |||
| parso @ file:///opt/conda/conda-bld/parso_1641458642106/work | |||
| pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work | |||
| pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work | |||
| Pillow==8.4.0 | |||
| pkgutil_resolve_name @ file:///opt/conda/conda-bld/pkgutil-resolve-name_1661463321198/work | |||
| platformdirs @ file:///opt/conda/conda-bld/platformdirs_1662711380096/work | |||
| ply==3.11 | |||
| prometheus-client @ file:///tmp/abs_d3zeliano1/croots/recipe/prometheus_client_1659455100375/work | |||
| prompt-toolkit @ file:///croot/prompt-toolkit_1672387306916/work | |||
| psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work | |||
| ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl | |||
| pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work | |||
| pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work | |||
| Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work | |||
| pynndescent==0.5.8 | |||
| pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work | |||
| pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work | |||
| PyQt5-sip==12.11.0 | |||
| pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110947380/work | |||
| PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work | |||
| python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work | |||
| pytorch-adapt==0.0.82 | |||
| pytorch-ignite==0.4.9 | |||
| pytorch-metric-learning==1.6.3 | |||
| pytz @ file:///croot/pytz_1671697431263/work | |||
| PyYAML==6.0 | |||
| pyzmq @ file:///opt/conda/conda-bld/pyzmq_1657724186960/work | |||
| qtconsole @ file:///opt/conda/conda-bld/qtconsole_1662018252641/work | |||
| QtPy @ file:///opt/conda/conda-bld/qtpy_1662014892439/work | |||
| requests @ file:///opt/conda/conda-bld/requests_1657734628632/work | |||
| scikit-learn==1.0 | |||
| scipy==1.6.3 | |||
| seaborn==0.12.2 | |||
| Send2Trash @ file:///tmp/build/80754af9/send2trash_1632406701022/work | |||
| sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work | |||
| six @ file:///tmp/build/80754af9/six_1644875935023/work | |||
| sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work | |||
| soupsieve @ file:///croot/soupsieve_1666296392845/work | |||
| stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work | |||
| terminado @ file:///croot/terminado_1671751832461/work | |||
| threadpoolctl==3.1.0 | |||
| timm==0.4.9 | |||
| tinycss2 @ file:///croot/tinycss2_1668168815555/work | |||
| toml @ file:///tmp/build/80754af9/toml_1616166611790/work | |||
| tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work | |||
| torch==1.8.1+cu101 | |||
| torchaudio==0.8.1 | |||
| torchmetrics==0.9.3 | |||
| torchvision==0.9.1+cu101 | |||
| tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work | |||
| tqdm==4.61.2 | |||
| traitlets @ file:///croot/traitlets_1671143879854/work | |||
| typing_extensions @ file:///croot/typing_extensions_1669924550328/work | |||
| umap-learn==0.5.3 | |||
| urllib3 @ file:///croot/urllib3_1670526988650/work | |||
| wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work | |||
| webencodings==0.5.1 | |||
| websocket-client @ file:///tmp/build/80754af9/websocket-client_1614804261064/work | |||
| widgetsnbextension @ file:///tmp/build/80754af9/widgetsnbextension_1645009353553/work | |||
| yacs==0.1.8 | |||
| zipp @ file:///croot/zipp_1672387121353/work | |||
| @@ -0,0 +1,21 @@ | |||
| from pytorch_adapt.adapters.base_adapter import BaseGCAdapter | |||
| from pytorch_adapt.adapters.utils import with_opt | |||
| from pytorch_adapt.hooks import ClassifierHook | |||
| class ClassifierAdapter(BaseGCAdapter): | |||
| """ | |||
| Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook]. | |||
| |Container|Required keys| | |||
| |---|---| | |||
| |models|```["G", "C"]```| | |||
| |optimizers|```["G", "C"]```| | |||
| """ | |||
| def init_hook(self, hook_kwargs): | |||
| opts = with_opt(list(self.optimizers.keys())) | |||
| self.hook = self.hook_cls(opts, **hook_kwargs) | |||
| @property | |||
| def hook_cls(self): | |||
| return ClassifierHook | |||
| @@ -0,0 +1,67 @@ | |||
| import torch | |||
| import os | |||
| from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||
| from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||
| from pytorch_adapt.models import Discriminator, office31C, office31G | |||
| from pytorch_adapt.containers import Misc | |||
| from pytorch_adapt.containers import LRSchedulers | |||
| from classifier_adapter import ClassifierAdapter | |||
| from utils import HP, DAModels | |||
| import copy | |||
| import matplotlib.pyplot as plt | |||
| import torch | |||
| import os | |||
| import gc | |||
| from datetime import datetime | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||
| from pytorch_adapt.frameworks.ignite import ( | |||
| CheckpointFnCreator, | |||
| IgniteValHookWrapper, | |||
| checkpoint_utils, | |||
| ) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| def get_source_trainer(checkpoint_dir): | |||
| G = office31G(pretrained=False).to(device) | |||
| C = office31C(pretrained=False).to(device) | |||
| optimizers = Optimizers((torch.optim.Adam, {"lr": 1e-4})) | |||
| lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99})) | |||
| models = Models({"G": G, "C": C}) | |||
| adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| checkpoint_fn = CheckpointFnCreator(dirname=checkpoint_dir, require_empty=False) | |||
| sourceAccuracyValidator = AccuracyValidator() | |||
| val_hooks = [ScoreHistory(sourceAccuracyValidator)] | |||
| new_trainer = Ignite( | |||
| adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device | |||
| ) | |||
| objs = [ | |||
| { | |||
| "engine": new_trainer.trainer, | |||
| **checkpoint_utils.adapter_to_dict(new_trainer.adapter), | |||
| } | |||
| ] | |||
| for to_load in objs: | |||
| checkpoint_fn.load_best_checkpoint(to_load) | |||
| return new_trainer | |||
| @@ -0,0 +1,172 @@ | |||
| import argparse | |||
| import logging | |||
| import os | |||
| from datetime import datetime | |||
| from train import train | |||
| from train_source import train_source | |||
| from utils import HP, DAModels | |||
| import tracemalloc | |||
| logging.basicConfig() | |||
| logging.getLogger("pytorch-adapt").setLevel(logging.WARNING) | |||
| DATASET_PAIRS = [("amazon", "webcam"), ("amazon", "dslr"), | |||
| ("webcam", "amazon"), ("webcam", "dslr"), | |||
| ("dslr", "amazon"), ("dslr", "webcam")] | |||
| def run_experiment_on_model(args, model_name, trial_number): | |||
| train_fn = train | |||
| if model_name == DAModels.SOURCE: | |||
| train_fn = train_source | |||
| base_output_dir = f"{args.results_root}/{model_name}/{trial_number}" | |||
| os.makedirs(base_output_dir, exist_ok=True) | |||
| hp_map = { | |||
| 'DANN': { | |||
| 'a2d': (5e-05, 0.99), | |||
| 'a2w': (5e-05, 0.99), | |||
| 'd2a': (0.0001, 0.9), | |||
| 'd2w': (5e-05, 0.99), | |||
| 'w2a': (0.0001, 0.99), | |||
| 'w2d': (0.0001, 0.99), | |||
| }, | |||
| 'CDAN': { | |||
| 'a2d': (1e-05, 1), | |||
| 'a2w': (1e-05, 1), | |||
| 'd2a': (1e-05, 0.99), | |||
| 'd2w': (1e-05, 0.99), | |||
| 'w2a': (0.0001, 0.99), | |||
| 'w2d': (5e-05, 0.99), | |||
| }, | |||
| 'MMD': { | |||
| 'a2d': (5e-05, 1), | |||
| 'a2w': (5e-05, 0.99), | |||
| 'd2a': (0.0001, 0.99), | |||
| 'd2w': (5e-05, 0.9), | |||
| 'w2a': (0.0001, 0.99), | |||
| 'w2d': (1e-05, 0.99), | |||
| }, | |||
| 'MCD': { | |||
| 'a2d': (1e-05, 0.9), | |||
| 'a2w': (0.0001, 1), | |||
| 'd2a': (1e-05, 0.9), | |||
| 'd2w': (1e-05, 0.99), | |||
| 'w2a': (1e-05, 0.9), | |||
| 'w2d': (5*1e-6, 0.99), | |||
| }, | |||
| 'CORAL': { | |||
| 'a2d': (1e-05, 0.99), | |||
| 'a2w': (1e-05, 1), | |||
| 'd2a': (5*1e-6, 0.99), | |||
| 'd2w': (0.0001, 0.99), | |||
| 'w2a': (1e-5, 0.99), | |||
| 'w2d': (0.0001, 0.99), | |||
| }, | |||
| } | |||
| if not args.hp_tune: | |||
| d = datetime.now() | |||
| results_file = f"{base_output_dir}/e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||
| with open(results_file, "w") as myfile: | |||
| myfile.write("pair, source_acc, target_acc, best_epoch, best_score, time, cur, peak, lr, gamma\n") | |||
| for source_domain, target_domain in DATASET_PAIRS: | |||
| pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||
| hp_parmas = hp_map[DAModels.CDAN.name][pair_name] | |||
| lr = args.lr if args.lr else hp_parmas[0] | |||
| gamma = args.gamma if args.gamma else hp_parmas[1] | |||
| hp = HP(lr=lr, gamma=gamma) | |||
| tracemalloc.start() | |||
| train_fn(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain) | |||
| cur, peak = tracemalloc.get_traced_memory() | |||
| tracemalloc.stop() | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write(f"{cur}, {peak}\n") | |||
| else: | |||
| # gamma_list = [1, 0.99, 0.9, 0.8] | |||
| # lr_list = [1e-5, 3*1e-5, 1e-4, 3*1e-4] | |||
| hp_values = { | |||
| 1e-4: [0.99, 0.9], | |||
| 1e-5: [0.99, 0.9], | |||
| 5*1e-5: [0.99, 0.9], | |||
| 5*1e-6: [0.99] | |||
| # 5*1e-4: [1, 0.99], | |||
| # 1e-3: [0.99, 0.9, 0.8], | |||
| } | |||
| d = datetime.now() | |||
| hp_file = f"{base_output_dir}/hp_e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||
| with open(hp_file, "w") as myfile: | |||
| myfile.write("lr, gamma, pair, source_acc, target_acc, best_epoch, best_score\n") | |||
| hp_best = None | |||
| hp_best_score = None | |||
| for lr, gamma_list in hp_values.items(): | |||
| # for lr in lr_list: | |||
| for gamma in gamma_list: | |||
| hp = HP(lr=lr, gamma=gamma) | |||
| print("HP:", hp) | |||
| for source_domain, target_domain in DATASET_PAIRS: | |||
| _, _, best_score = \ | |||
| train_fn(args, model_name, hp, base_output_dir, hp_file, source_domain, target_domain) | |||
| if best_score is not None and (hp_best_score is None or hp_best_score < best_score): | |||
| hp_best = hp | |||
| hp_best_score = best_score | |||
| with open(hp_file, "a") as myfile: | |||
| myfile.write(f"\nbest: {hp_best.lr}, {hp_best.gamma}\n") | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--max_epochs', default=60, type=int) | |||
| parser.add_argument('--patience', default=10, type=int) | |||
| parser.add_argument('--batch_size', default=32, type=int) | |||
| parser.add_argument('--num_workers', default=2, type=int) | |||
| parser.add_argument('--trials_count', default=3, type=int) | |||
| parser.add_argument('--initial_trial', default=0, type=int) | |||
| parser.add_argument('--download', default=False, type=bool) | |||
| parser.add_argument('--root', default="../") | |||
| parser.add_argument('--data_root', default="../datasets/pytorch-adapt/") | |||
| parser.add_argument('--results_root', default="../results/") | |||
| parser.add_argument('--model_names', default=["DANN"], nargs='+') | |||
| parser.add_argument('--lr', default=None, type=float) | |||
| parser.add_argument('--gamma', default=None, type=float) | |||
| parser.add_argument('--hp_tune', default=False, type=bool) | |||
| parser.add_argument('--source', default=None) | |||
| parser.add_argument('--target', default=None) | |||
| parser.add_argument('--vishook_frequency', default=5, type=int) | |||
| parser.add_argument('--source_checkpoint_base_dir', default=None) # default='../results/DAModels.SOURCE/' | |||
| parser.add_argument('--source_checkpoint_trial_number', default=-1, type=int) | |||
| args = parser.parse_args() | |||
| print(args) | |||
| for trial_number in range(args.initial_trial, args.initial_trial + args.trials_count): | |||
| if args.source_checkpoint_trial_number == -1: | |||
| args.source_checkpoint_trial_number = trial_number | |||
| for model_name in args.model_names: | |||
| try: | |||
| model_enum = DAModels(model_name) | |||
| except ValueError: | |||
| logging.warning(f"Model {model_name} not found. skipping...") | |||
| continue | |||
| run_experiment_on_model(args, model_enum, trial_number) | |||
| @@ -0,0 +1,85 @@ | |||
| import torch | |||
| import os | |||
| from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||
| from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||
| from pytorch_adapt.models import Discriminator, office31C, office31G | |||
| from pytorch_adapt.containers import Misc | |||
| from pytorch_adapt.layers import RandomizedDotProduct | |||
| from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss | |||
| from pytorch_adapt.utils import common_functions | |||
| from pytorch_adapt.containers import LRSchedulers | |||
| from classifier_adapter import ClassifierAdapter | |||
| from load_source import get_source_trainer | |||
| from utils import HP, DAModels | |||
| import copy | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| def get_model(model_name, hp: HP, source_checkpoint_dir, data_root, source_domain): | |||
| if source_checkpoint_dir: | |||
| source_trainer = get_source_trainer(source_checkpoint_dir) | |||
| G = copy.deepcopy(source_trainer.adapter.models["G"]) | |||
| C = copy.deepcopy(source_trainer.adapter.models["C"]) | |||
| D = Discriminator(in_size=2048, h=1024).to(device) | |||
| else: | |||
| weights_root = os.path.join(data_root, "weights") | |||
| G = office31G(pretrained=True, model_dir=weights_root).to(device) | |||
| C = office31C(domain=source_domain, pretrained=True, | |||
| model_dir=weights_root).to(device) | |||
| D = Discriminator(in_size=2048, h=1024).to(device) | |||
| optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr})) | |||
| lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma})) | |||
| if model_name == DAModels.DANN: | |||
| models = Models({"G": G, "C": C, "D": D}) | |||
| adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| elif model_name == DAModels.CDAN: | |||
| models = Models({"G": G, "C": C, "D": D}) | |||
| misc = Misc({"feature_combiner": RandomizedDotProduct([2048, 31], 2048)}) | |||
| adapter = CDAN(models=models, misc=misc, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| elif model_name == DAModels.MCD: | |||
| C1 = common_functions.reinit(copy.deepcopy(C)) | |||
| C_combined = MultipleModels(C, C1) | |||
| models = Models({"G": G, "C": C_combined}) | |||
| adapter= MCD(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| elif model_name == DAModels.SYMNET: | |||
| C1 = common_functions.reinit(copy.deepcopy(C)) | |||
| C_combined = MultipleModels(C, C1) | |||
| models = Models({"G": G, "C": C_combined}) | |||
| adapter= SymNets(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| elif model_name == DAModels.MMD: | |||
| models = Models({"G": G, "C": C}) | |||
| hook_kwargs = {"loss_fn": MMDLoss()} | |||
| adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs) | |||
| elif model_name == DAModels.CORAL: | |||
| models = Models({"G": G, "C": C}) | |||
| hook_kwargs = {"loss_fn": CORALLoss()} | |||
| adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs) | |||
| elif model_name == DAModels.SOURCE: | |||
| models = Models({"G": G, "C": C}) | |||
| adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| return adapter | |||
| @@ -0,0 +1,167 @@ | |||
| import logging | |||
| import matplotlib.pyplot as plt | |||
| import pandas as pd | |||
| import seaborn as sns | |||
| import torch | |||
| import umap | |||
| from tqdm import tqdm | |||
| import os | |||
| import gc | |||
| from datetime import datetime | |||
| from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner | |||
| from pytorch_adapt.adapters.base_adapter import BaseGCAdapter | |||
| from pytorch_adapt.adapters.utils import with_opt | |||
| from pytorch_adapt.hooks import ClassifierHook | |||
| from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.models import Discriminator, mnistC, mnistG, office31C, office31G | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||
| from pytorch_adapt.containers import Misc | |||
| from pytorch_adapt.layers import RandomizedDotProduct | |||
| from pytorch_adapt.layers import MultipleModels | |||
| from pytorch_adapt.utils import common_functions | |||
| from pytorch_adapt.containers import LRSchedulers | |||
| import copy | |||
| from pprint import pprint | |||
| PATIENCE = 5 | |||
| EPOCHS = 50 | |||
| BATCH_SIZE = 32 | |||
| NUM_WORKERS = 2 | |||
| TRIAL_COUNT = 5 | |||
| logging.basicConfig() | |||
| logging.getLogger("pytorch-adapt").setLevel(logging.WARNING) | |||
| class ClassifierAdapter(BaseGCAdapter): | |||
| """ | |||
| Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook]. | |||
| |Container|Required keys| | |||
| |---|---| | |||
| |models|```["G", "C"]```| | |||
| |optimizers|```["G", "C"]```| | |||
| """ | |||
| def init_hook(self, hook_kwargs): | |||
| opts = with_opt(list(self.optimizers.keys())) | |||
| self.hook = self.hook_cls(opts, **hook_kwargs) | |||
| @property | |||
| def hook_cls(self): | |||
| return ClassifierHook | |||
| root='/content/drive/MyDrive/Shared with Sabas/Bsc/' | |||
| # root="datasets/pytorch-adapt/" | |||
| data_root = os.path.join(root,'data') | |||
| batch_size=BATCH_SIZE | |||
| num_workers=NUM_WORKERS | |||
| device = torch.device("cuda") | |||
| model_dir = os.path.join(data_root, "weights") | |||
| DATASET_PAIRS = [("amazon", ["webcam", "dslr"]), | |||
| ("webcam", ["dslr", "amazon"]), | |||
| ("dslr", ["amazon", "webcam"]) | |||
| ] | |||
| MODEL_NAME = "base" | |||
| model_name = MODEL_NAME | |||
| pass_next= 0 | |||
| pass_trial = 0 | |||
| for trial_number in range(10, 10 + TRIAL_COUNT): | |||
| if pass_trial: | |||
| pass_trial -= 1 | |||
| continue | |||
| base_output_dir = f"{root}/results/vishook/{MODEL_NAME}/{trial_number}" | |||
| os.makedirs(base_output_dir, exist_ok=True) | |||
| d = datetime.now() | |||
| results_file = f"{base_output_dir}/{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||
| with open(results_file, "w") as myfile: | |||
| myfile.write("pair, source_acc, target_acc, best_epoch, time\n") | |||
| for source_domain, target_domains in DATASET_PAIRS: | |||
| datasets = get_office31([source_domain], [], folder=data_root) | |||
| dc = DataloaderCreator(batch_size=batch_size, | |||
| num_workers=num_workers, | |||
| ) | |||
| dataloaders = dc(**datasets) | |||
| G = office31G(pretrained=True, model_dir=model_dir) | |||
| C = office31C(domain=source_domain, pretrained=True, model_dir=model_dir) | |||
| optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0005})) | |||
| lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99})) | |||
| if model_name == "base": | |||
| models = Models({"G": G, "C": C}) | |||
| adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| checkpoint_fn = CheckpointFnCreator(dirname="saved_models", require_empty=False) | |||
| val_hooks = [ScoreHistory(AccuracyValidator())] | |||
| trainer = Ignite( | |||
| adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, | |||
| ) | |||
| early_stopper_kwargs = {"patience": PATIENCE} | |||
| start_time = datetime.now() | |||
| best_score, best_epoch = trainer.run( | |||
| datasets, dataloader_creator=dc, max_epochs=EPOCHS, early_stopper_kwargs=early_stopper_kwargs | |||
| ) | |||
| end_time = datetime.now() | |||
| training_time = end_time - start_time | |||
| for target_domain in target_domains: | |||
| if pass_next: | |||
| pass_next -= 1 | |||
| continue | |||
| pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||
| output_dir = os.path.join(base_output_dir, pair_name) | |||
| os.makedirs(output_dir, exist_ok=True) | |||
| print("output dir:", output_dir) | |||
| # print(f"best_score={best_score}, best_epoch={best_epoch}, training_time={training_time.seconds}") | |||
| plt.plot(val_hooks[0].score_history, label='source') | |||
| plt.title("val accuracy") | |||
| plt.legend() | |||
| plt.savefig(f"{output_dir}/val_accuracy.png") | |||
| plt.close('all') | |||
| datasets = get_office31([source_domain], [target_domain], folder=data_root, return_target_with_labels=True) | |||
| dc = DataloaderCreator(batch_size=64, num_workers=2, all_val=True) | |||
| validator = AccuracyValidator(key_map={"src_val": "src_val"}) | |||
| src_score = trainer.evaluate_best_model(datasets, validator, dc) | |||
| print("Source acc:", src_score) | |||
| validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||
| target_score = trainer.evaluate_best_model(datasets, validator, dc) | |||
| print("Target acc:", target_score) | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write(f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {training_time.seconds}\n") | |||
| del trainer | |||
| del G | |||
| del C | |||
| gc.collect() | |||
| torch.cuda.empty_cache() | |||
| @@ -0,0 +1,112 @@ | |||
| import matplotlib.pyplot as plt | |||
| import torch | |||
| import os | |||
| import gc | |||
| from datetime import datetime | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||
| from models import get_model | |||
| from utils import DAModels | |||
| from vis_hook import VizHook | |||
| def train(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain): | |||
| if args.source != None and args.source != source_domain: | |||
| return None, None, None | |||
| if args.target != None and args.target != target_domain: | |||
| return None, None, None | |||
| pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||
| output_dir = os.path.join(base_output_dir, pair_name) | |||
| os.makedirs(output_dir, exist_ok=True) | |||
| print("output dir:", output_dir) | |||
| datasets = get_office31([source_domain], [target_domain], | |||
| folder=args.data_root, | |||
| return_target_with_labels=True, | |||
| download=args.download) | |||
| dc = DataloaderCreator(batch_size=args.batch_size, | |||
| num_workers=args.num_workers, | |||
| train_names=["train"], | |||
| val_names=["src_train", "target_train", "src_val", "target_val", | |||
| "target_train_with_labels", "target_val_with_labels"]) | |||
| source_checkpoint_dir = None if not args.source_checkpoint_base_dir else \ | |||
| f"{args.source_checkpoint_base_dir}/{args.source_checkpoint_trial_number}/{pair_name}/saved_models" | |||
| print("source_checkpoint_dir", source_checkpoint_dir) | |||
| adapter = get_model(model_name, hp, source_checkpoint_dir, args.data_root, source_domain) | |||
| checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False) | |||
| sourceAccuracyValidator = AccuracyValidator() | |||
| validators = { | |||
| "entropy": EntropyValidator(), | |||
| "diversity": DiversityValidator(), | |||
| } | |||
| validator = ScoreHistory(MultipleValidators(validators)) | |||
| targetAccuracyValidator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||
| val_hooks = [ScoreHistory(sourceAccuracyValidator), | |||
| ScoreHistory(targetAccuracyValidator), | |||
| VizHook(output_dir=output_dir, frequency=args.vishook_frequency)] | |||
| trainer = Ignite( | |||
| adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn | |||
| ) | |||
| early_stopper_kwargs = {"patience": args.patience} | |||
| start_time = datetime.now() | |||
| best_score, best_epoch = trainer.run( | |||
| datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs | |||
| ) | |||
| end_time = datetime.now() | |||
| training_time = end_time - start_time | |||
| print(f"best_score={best_score}, best_epoch={best_epoch}") | |||
| plt.plot(val_hooks[0].score_history, label='source') | |||
| plt.plot(val_hooks[1].score_history, label='target') | |||
| plt.title("val accuracy") | |||
| plt.legend() | |||
| plt.savefig(f"{output_dir}/val_accuracy.png") | |||
| plt.close('all') | |||
| plt.plot(validator.score_history) | |||
| plt.title("score_history") | |||
| plt.savefig(f"{output_dir}/score_history.png") | |||
| plt.close('all') | |||
| src_score = trainer.evaluate_best_model(datasets, sourceAccuracyValidator, dc) | |||
| print("Source acc:", src_score) | |||
| target_score = trainer.evaluate_best_model(datasets, targetAccuracyValidator, dc) | |||
| print("Target acc:", target_score) | |||
| print("---------") | |||
| if args.hp_tune: | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n") | |||
| else: | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write( | |||
| f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, {hp.lr}, {hp.gamma},") | |||
| del adapter | |||
| gc.collect() | |||
| torch.cuda.empty_cache() | |||
| return src_score, target_score, best_score | |||
| @@ -0,0 +1,121 @@ | |||
| import torch | |||
| import os | |||
| from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||
| from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||
| from pytorch_adapt.models import Discriminator, office31C, office31G | |||
| from pytorch_adapt.containers import Misc | |||
| from pytorch_adapt.layers import RandomizedDotProduct | |||
| from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss | |||
| from pytorch_adapt.utils import common_functions | |||
| from pytorch_adapt.containers import LRSchedulers | |||
| from classifier_adapter import ClassifierAdapter | |||
| from utils import HP, DAModels | |||
| import copy | |||
| import matplotlib.pyplot as plt | |||
| import torch | |||
| import os | |||
| import gc | |||
| from datetime import datetime | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| def train_source(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain): | |||
| if args.source != None and args.source != source_domain: | |||
| return None, None, None | |||
| if args.target != None and args.target != target_domain: | |||
| return None, None, None | |||
| pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||
| output_dir = os.path.join(base_output_dir, pair_name) | |||
| os.makedirs(output_dir, exist_ok=True) | |||
| print("output dir:", output_dir) | |||
| datasets = get_office31([source_domain], [], | |||
| folder=args.data_root, | |||
| return_target_with_labels=True, | |||
| download=args.download) | |||
| dc = DataloaderCreator(batch_size=args.batch_size, | |||
| num_workers=args.num_workers, | |||
| ) | |||
| weights_root = os.path.join(args.data_root, "weights") | |||
| G = office31G(pretrained=True, model_dir=weights_root).to(device) | |||
| C = office31C(domain=source_domain, pretrained=True, | |||
| model_dir=weights_root).to(device) | |||
| optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr})) | |||
| lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma})) | |||
| models = Models({"G": G, "C": C}) | |||
| adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||
| # adapter = get_model(model_name, hp, args.data_root, source_domain) | |||
| print("checkpoint dir:", output_dir) | |||
| checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False) | |||
| sourceAccuracyValidator = AccuracyValidator() | |||
| val_hooks = [ScoreHistory(sourceAccuracyValidator)] | |||
| trainer = Ignite( | |||
| adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn | |||
| ) | |||
| early_stopper_kwargs = {"patience": args.patience} | |||
| start_time = datetime.now() | |||
| best_score, best_epoch = trainer.run( | |||
| datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs | |||
| ) | |||
| end_time = datetime.now() | |||
| training_time = end_time - start_time | |||
| print(f"best_score={best_score}, best_epoch={best_epoch}") | |||
| plt.plot(val_hooks[0].score_history, label='source') | |||
| plt.title("val accuracy") | |||
| plt.legend() | |||
| plt.savefig(f"{output_dir}/val_accuracy.png") | |||
| plt.close('all') | |||
| datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True) | |||
| dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True) | |||
| validator = AccuracyValidator(key_map={"src_val": "src_val"}) | |||
| src_score = trainer.evaluate_best_model(datasets, validator, dc) | |||
| print("Source acc:", src_score) | |||
| validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||
| target_score = trainer.evaluate_best_model(datasets, validator, dc) | |||
| print("Target acc:", target_score) | |||
| print("---------") | |||
| if args.hp_tune: | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n") | |||
| else: | |||
| with open(results_file, "a") as myfile: | |||
| myfile.write( | |||
| f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, ") | |||
| del adapter | |||
| gc.collect() | |||
| torch.cuda.empty_cache() | |||
| return src_score, target_score, best_score | |||
| @@ -0,0 +1,17 @@ | |||
| from dataclasses import dataclass | |||
| from enum import Enum | |||
| @dataclass | |||
| class HP: | |||
| lr:float = 0.0005 | |||
| gamma:float = 0.99 | |||
| class DAModels(Enum): | |||
| DANN = "DANN" | |||
| MCD = "MCD" | |||
| CDAN = "CDAN" | |||
| SOURCE = "SOURCE" | |||
| MMD = "MMD" | |||
| CORAL = "CORAL" | |||
| SYMNET = "SYMNET" | |||
| @@ -0,0 +1,47 @@ | |||
| import matplotlib.pyplot as plt | |||
| import pandas as pd | |||
| import seaborn as sns | |||
| import torch | |||
| import umap | |||
| from datetime import datetime | |||
| from pytorch_adapt.adapters import DANN | |||
| from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.models import Discriminator, office31C, office31G | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||
| class VizHook: | |||
| def __init__(self, **kwargs): | |||
| self.required_data = ["src_val", | |||
| "target_val", "target_val_with_labels"] | |||
| self.kwargs = kwargs | |||
| def __call__(self, epoch, src_val, target_val, target_val_with_labels, **kwargs): | |||
| accuracy_validator = AccuracyValidator() | |||
| accuracy = accuracy_validator.compute_score(src_val=src_val) | |||
| print("src_val accuracy:", accuracy) | |||
| accuracy_validator = AccuracyValidator() | |||
| accuracy = accuracy_validator.compute_score(src_val=target_val_with_labels) | |||
| print("target_val accuracy:", accuracy) | |||
| if epoch >= 2 and epoch % kwargs.get("frequency", 5) != 0: | |||
| return | |||
| features = [src_val["features"], target_val["features"]] | |||
| domain = [src_val["domain"], target_val["domain"]] | |||
| features = torch.cat(features, dim=0).cpu().numpy() | |||
| domain = torch.cat(domain, dim=0).cpu().numpy() | |||
| emb = umap.UMAP().fit_transform(features) | |||
| df = pd.DataFrame(emb).assign(domain=domain) | |||
| df["domain"] = df["domain"].replace({0: "Source", 1: "Target"}) | |||
| sns.set_theme(style="white", rc={"figure.figsize": (8, 6)}) | |||
| sns.scatterplot(data=df, x=0, y=1, hue="domain", s=10) | |||
| plt.savefig(f"{self.kwargs['output_dir']}/val_{epoch}.png") | |||
| plt.close('all') | |||
| @@ -0,0 +1,42 @@ | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import torchvision | |||
| from pytorch_adapt.datasets import get_office31 | |||
| # root="datasets/pytorch-adapt/" | |||
| mean = [0.485, 0.456, 0.406] | |||
| std = [0.229, 0.224, 0.225] | |||
| inv_normalize = torchvision.transforms.Normalize( | |||
| mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std] | |||
| ) | |||
| idx = 0 | |||
| def imshow(img, domain, figsize=(10, 6)): | |||
| img = inv_normalize(img) | |||
| npimg = img.numpy() | |||
| plt.figure(figsize=figsize) | |||
| plt.imshow(np.transpose(npimg, (1, 2, 0))) | |||
| plt.axis('off') | |||
| plt.savefig(f"office31-{idx}") | |||
| plt.show() | |||
| plt.close("all") | |||
| idx += 1 | |||
| def imshow_many(datasets, src, target): | |||
| d = datasets["train"] | |||
| for name in ["src_imgs", "target_imgs"]: | |||
| domains = src if name == "src_imgs" else target | |||
| if len(domains) == 0: | |||
| continue | |||
| print(domains) | |||
| imgs = [d[i][name] for i in np.random.choice(len(d), size=16, replace=False)] | |||
| imshow(torchvision.utils.make_grid(imgs)) | |||
| for src, target in [(["amazon"], ["dslr"]), (["webcam"], [])]: | |||
| datasets = get_office31(src, target,folder=root) | |||
| imshow_many(datasets, src, target) | |||
| @@ -0,0 +1,165 @@ | |||
| import itertools | |||
| import numpy as np | |||
| from pprint import pprint | |||
| domains = ["w", "d", "a"] | |||
| pairs = [[f"{d1}2{d2}" for d1 in domains] for d2 in domains] | |||
| pairs = list(itertools.chain.from_iterable(pairs)) | |||
| pairs.sort() | |||
| # print(pairs) | |||
| ROUND_FACTOR = 3 | |||
| all_accs_list = {} | |||
| files = [ | |||
| "all.txt", | |||
| "all2.txt", | |||
| "all3.txt", | |||
| # "all_step_scheduler.txt", | |||
| # "all_source_trained_1.txt", | |||
| # "all_source_trained_2_specific_hp.txt", | |||
| ] | |||
| for file in files: | |||
| with open(file) as f: | |||
| name = None | |||
| for line in f: | |||
| splitted = line.split(" ") | |||
| if splitted[0] == "##": # e.g. ## DANN | |||
| name = splitted[1].strip() | |||
| splitted = line.split(",") # a2w, acc1, acc2, | |||
| if splitted[0] in pairs: | |||
| pair = splitted[0] | |||
| name = name.lower() | |||
| acc = float(splitted[2].strip()) | |||
| if name not in all_accs_list: | |||
| all_accs_list[name] = {p:[] for p in pairs} | |||
| all_accs_list[name][pair].append(acc) | |||
| # all_accs_list format: {'dann': {'a2w': [acc1, acc2, acc3]}} | |||
| acc_means_by_model_name = {} | |||
| vars_by_model_name = {} | |||
| for name, pair_list in all_accs_list.items(): | |||
| accs = {p:[] for p in pairs} | |||
| vars = {p:[] for p in pairs} | |||
| for pair, acc_list in pair_list.items(): | |||
| if len(acc_list) > 0: | |||
| ## Calculate average and round | |||
| accs[pair] = round(100 * sum(acc_list) / len(acc_list), ROUND_FACTOR) | |||
| vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR) | |||
| print(vars[pair], "|||", acc_list) | |||
| acc_means_by_model_name[name] = accs | |||
| vars_by_model_name[name] = vars | |||
| # for name, acc_list in acc_means_by_model_name.items(): | |||
| # pprint(name) | |||
| # pprint(all_accs_list) | |||
| # pprint(acc_means_by_model_name) | |||
| # pprint(vars_by_model_name) | |||
| print() | |||
| latex_table = "" | |||
| header = [pair for pair in pairs if pair[0] != pair[-1]] | |||
| table = [] | |||
| var_table = [] | |||
| for name, acc_list in acc_means_by_model_name.items(): | |||
| if "target" in name: | |||
| continue | |||
| var_list = vars_by_model_name[name] | |||
| valid_accs = [] | |||
| table_row = [] | |||
| var_table_row = [] | |||
| for pair in pairs: | |||
| acc = acc_list[pair] | |||
| var = var_list[pair] | |||
| if pair[0] != pair[-1]: # exclude w2w, ... | |||
| table_row.append(acc) | |||
| var_table_row.append(var) | |||
| if acc != None: | |||
| valid_accs.append(acc) | |||
| acc_average = round(sum(valid_accs) / len(header), ROUND_FACTOR) | |||
| table_row.append(acc_average) | |||
| table.append(table_row) | |||
| var =round(np.var(valid_accs), ROUND_FACTOR) | |||
| print(var, "~~~~~~", valid_accs) | |||
| var_table_row.append(var) | |||
| var_table.append(var_table_row) | |||
| t = np.array(table) | |||
| t[t==None] = np.nan | |||
| # pprint(t) | |||
| col_max = t.max(axis=0) | |||
| pprint(table) | |||
| latex_table = "" | |||
| header = [pair for pair in pairs if pair[0] != pair[-1]] | |||
| name_map = {"base_source": "Source-Only"} | |||
| j = 0 | |||
| for name, acc_list in acc_means_by_model_name.items(): | |||
| if "target" in name: | |||
| continue | |||
| latex_name = name | |||
| if name in name_map: | |||
| latex_name= name_map[name] | |||
| latex_row = f"{latex_name.replace('_','-').upper()} &" | |||
| acc_sum = 0 | |||
| for i, acc in enumerate(table[j]): | |||
| if i == len(table[j]) - 1: | |||
| acc_str = f"${acc}$" | |||
| else: | |||
| acc_str = f"${acc} \pm {var_table[j][i]}$" | |||
| if acc == col_max[i]: | |||
| latex_row += f" \\underline{{{acc_str}}} &" | |||
| else: | |||
| latex_row += f" {acc_str} &" | |||
| latex_row = f"{latex_row[:-1]} \\\\ \hline" | |||
| latex_table += f"{latex_row}\n" | |||
| j += 1 | |||
| print(*header, sep=" & ") | |||
| print(latex_table) | |||
| data = np.array(table) | |||
| legend = [key for key in acc_means_by_model_name.keys()] | |||
| labels = [*header, "avg"] | |||
| data = np.array([[71.75, 75.94 ,67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]]) | |||
| legend = ["CDAN", "Source-only"] | |||
| labels = [*header, "avg"] | |||
| import matplotlib.pyplot as plt | |||
| # Assume your matrix is called 'data' | |||
| n, m = data.shape | |||
| # Create an array of x-coordinates for the bars | |||
| x = np.arange(m) | |||
| # Plot the bars for each row side by side | |||
| for i in range(n): | |||
| row = data[i, :] | |||
| plt.bar(x + (i-n/2)*0.3, row, width=0.25, align='center') | |||
| # Set x-axis tick labels and labels | |||
| plt.xticks(x, labels=labels) | |||
| # plt.xlabel("Task") | |||
| plt.ylabel("Accuracy") | |||
| # Add a legend | |||
| plt.legend(legend) | |||
| plt.show() | |||
| @@ -0,0 +1,167 @@ | |||
| import itertools | |||
| import numpy as np | |||
| from pprint import pprint | |||
| domains = ["w", "d", "a"] | |||
| pairs = [[f"{d1}2{d2}" for d1 in domains] for d2 in domains] | |||
| pairs = list(itertools.chain.from_iterable(pairs)) | |||
| pairs.sort() | |||
| # print(pairs) | |||
| ROUND_FACTOR = 2 | |||
| all_accs_list = {} | |||
| files = [ | |||
| # "all.txt", | |||
| # "all2.txt", | |||
| # "all3.txt", | |||
| # "all_step_scheduler.txt", | |||
| # "all_source_trained_1.txt", | |||
| "all_source_trained_2_specific_hp.txt", | |||
| ] | |||
| for file in files: | |||
| with open(file) as f: | |||
| name = None | |||
| for line in f: | |||
| splitted = line.split(" ") | |||
| if splitted[0] == "##": # e.g. ## DANN | |||
| name = splitted[1].strip() | |||
| splitted = line.split(",") # a2w, acc1, acc2, | |||
| if splitted[0] in pairs: | |||
| pair = splitted[0] | |||
| name = name.lower() | |||
| acc = float(splitted[5].strip()) #/ 60 / 60 | |||
| if name not in all_accs_list: | |||
| all_accs_list[name] = {p:[] for p in pairs} | |||
| all_accs_list[name][pair].append(acc) | |||
| # all_accs_list format: {'dann': {'a2w': [acc1, acc2, acc3]}} | |||
| acc_means_by_model_name = {} | |||
| vars_by_model_name = {} | |||
| for name, pair_list in all_accs_list.items(): | |||
| accs = {p:[] for p in pairs} | |||
| vars = {p:[] for p in pairs} | |||
| for pair, acc_list in pair_list.items(): | |||
| if len(acc_list) > 0: | |||
| ## Calculate average and round | |||
| accs[pair] = round(sum(acc_list) / len(acc_list), ROUND_FACTOR) | |||
| vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR) | |||
| print(vars[pair], "|||", acc_list) | |||
| acc_means_by_model_name[name] = accs | |||
| vars_by_model_name[name] = vars | |||
| # for name, acc_list in acc_means_by_model_name.items(): | |||
| # pprint(name) | |||
| # pprint(all_accs_list) | |||
| # pprint(acc_means_by_model_name) | |||
| # pprint(vars_by_model_name) | |||
| print() | |||
| latex_table = "" | |||
| header = [pair for pair in pairs if pair[0] != pair[-1]] | |||
| table = [] | |||
| var_table = [] | |||
| for name, acc_list in acc_means_by_model_name.items(): | |||
| if "target" in name or "source" in name: | |||
| continue | |||
| print("~~~~%%%%~~~", name) | |||
| var_list = vars_by_model_name[name] | |||
| valid_accs = [] | |||
| table_row = [] | |||
| var_table_row = [] | |||
| for pair in pairs: | |||
| acc = acc_list[pair] | |||
| var = var_list[pair] | |||
| if pair[0] != pair[-1]: # exclude w2w, ... | |||
| table_row.append(acc) | |||
| var_table_row.append(var) | |||
| if acc != None: | |||
| valid_accs.append(acc) | |||
| acc_average = round(sum(valid_accs) / len(header), ROUND_FACTOR) | |||
| table_row.append(acc_average) | |||
| table.append(table_row) | |||
| var =round(np.var(valid_accs), ROUND_FACTOR) | |||
| print(var, ">>>", valid_accs) | |||
| var_table_row.append(var) | |||
| var_table.append(var_table_row) | |||
| t = np.array(table) | |||
| t[t==None] = np.nan | |||
| # pprint(t) | |||
| col_max = t.min(axis=0) | |||
| pprint(table) | |||
| latex_table = "" | |||
| header = [pair for pair in pairs if pair[0] != pair[-1]] | |||
| name_map = {"base_source": "Source-Only"} | |||
| j = 0 | |||
| for name, acc_list in acc_means_by_model_name.items(): | |||
| if "target" in name or "source" in name: | |||
| continue | |||
| latex_name = name | |||
| if name in name_map: | |||
| latex_name= name_map[name] | |||
| latex_row = f"{latex_name.replace('_','-').upper()} &" | |||
| acc_sum = 0 | |||
| for i, acc in enumerate(table[j]): | |||
| if i == len(table[j]) - 1: | |||
| acc_str = f"${acc}$" | |||
| else: | |||
| acc_str = f"${acc}$" | |||
| if acc == col_max[i]: | |||
| latex_row += f" \\underline{{{acc_str}}} &" | |||
| else: | |||
| latex_row += f" {acc_str} &" | |||
| latex_row = f"{latex_row[:-1]} \\\\ \hline" | |||
| latex_table += f"{latex_row}\n" | |||
| j += 1 | |||
| print(*header, sep=" & ") | |||
| print(latex_table) | |||
| data = np.array(table) | |||
| legend = [key for key in acc_means_by_model_name.keys() if "source" not in key] | |||
| labels = [*header, "avg"] | |||
| # data = np.array([[71.75, 75.94 ,67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]]) | |||
| # legend = ["CDAN", "Source-only"] | |||
| # labels = [*header, "avg"] | |||
| import matplotlib.pyplot as plt | |||
| # Assume your matrix is called 'data' | |||
| n, m = data.shape | |||
| # Create an array of x-coordinates for the bars | |||
| x = np.arange(m) | |||
| # Plot the bars for each row side by side | |||
| for i in range(n): | |||
| row = data[i, :] | |||
| plt.bar(x + (i-n/2)*0.1, row, width=0.08, align='center') | |||
| # Set x-axis tick labels and labels | |||
| plt.xticks(x, labels=labels) | |||
| # plt.xlabel("Task") | |||
| plt.ylabel("Time (s)") | |||
| # Add a legend | |||
| plt.legend(legend) | |||
| plt.show() | |||
| @@ -0,0 +1,30 @@ | |||
| from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||
| from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||
| from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||
| from time import time | |||
| import multiprocessing as mp | |||
| data_root = "../datasets/pytorch-adapt/" | |||
| batch_size = 32 | |||
| for num_workers in range(2, mp.cpu_count(), 2): | |||
| datasets = get_office31(["amazon"], ["webcam"], | |||
| folder=data_root, | |||
| return_target_with_labels=True, | |||
| download=False) | |||
| dc = DataloaderCreator(batch_size=batch_size, | |||
| num_workers=num_workers, | |||
| train_names=["train"], | |||
| val_names=["src_train", "target_train", "src_val", "target_val", | |||
| "target_train_with_labels", "target_val_with_labels"]) | |||
| dataloaders = dc(**datasets) | |||
| train_loader = dataloaders["train"] | |||
| start = time() | |||
| for epoch in range(1, 3): | |||
| for i, data in enumerate(train_loader, 0): | |||
| pass | |||
| end = time() | |||
| print("Finish with:{} second, num_workers={}".format(end - start, num_workers)) | |||
| @@ -0,0 +1,924 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 1, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "device(type='cuda', index=0)" | |||
| ] | |||
| }, | |||
| "execution_count": 1, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "import torch\n", | |||
| "import os\n", | |||
| "\n", | |||
| "from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets\n", | |||
| "from pytorch_adapt.containers import Models, Optimizers, LRSchedulers\n", | |||
| "from pytorch_adapt.models import Discriminator, office31C, office31G\n", | |||
| "from pytorch_adapt.containers import Misc\n", | |||
| "from pytorch_adapt.layers import RandomizedDotProduct\n", | |||
| "from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss\n", | |||
| "from pytorch_adapt.utils import common_functions\n", | |||
| "from pytorch_adapt.containers import LRSchedulers\n", | |||
| "\n", | |||
| "from classifier_adapter import ClassifierAdapter\n", | |||
| "\n", | |||
| "from utils import HP, DAModels\n", | |||
| "\n", | |||
| "import copy\n", | |||
| "\n", | |||
| "import matplotlib.pyplot as plt\n", | |||
| "import torch\n", | |||
| "import os\n", | |||
| "import gc\n", | |||
| "from datetime import datetime\n", | |||
| "\n", | |||
| "from pytorch_adapt.datasets import DataloaderCreator, get_office31\n", | |||
| "from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n", | |||
| "from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators\n", | |||
| "\n", | |||
| "from models import get_model\n", | |||
| "from utils import DAModels\n", | |||
| "\n", | |||
| "from vis_hook import VizHook\n", | |||
| "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
| "device" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 2, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=5)\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "import argparse\n", | |||
| "parser = argparse.ArgumentParser()\n", | |||
| "parser.add_argument('--max_epochs', default=1, type=int)\n", | |||
| "parser.add_argument('--patience', default=2, type=int)\n", | |||
| "parser.add_argument('--batch_size', default=64, type=int)\n", | |||
| "parser.add_argument('--num_workers', default=1, type=int)\n", | |||
| "parser.add_argument('--trials_count', default=1, type=int)\n", | |||
| "parser.add_argument('--initial_trial', default=0, type=int)\n", | |||
| "parser.add_argument('--download', default=False, type=bool)\n", | |||
| "parser.add_argument('--root', default=\"./\")\n", | |||
| "parser.add_argument('--data_root', default=\"./datasets/pytorch-adapt/\")\n", | |||
| "parser.add_argument('--results_root', default=\"./results/\")\n", | |||
| "parser.add_argument('--model_names', default=[\"DANN\"], nargs='+')\n", | |||
| "parser.add_argument('--lr', default=0.0001, type=float)\n", | |||
| "parser.add_argument('--gamma', default=0.99, type=float)\n", | |||
| "parser.add_argument('--hp_tune', default=False, type=bool)\n", | |||
| "parser.add_argument('--source', default=None)\n", | |||
| "parser.add_argument('--target', default=None) \n", | |||
| "parser.add_argument('--vishook_frequency', default=5, type=int)\n", | |||
| " \n", | |||
| "\n", | |||
| "args = parser.parse_args(\"\")\n", | |||
| "print(args)\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 45, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "source_domain = 'amazon'\n", | |||
| "target_domain = 'webcam'\n", | |||
| "datasets = get_office31([source_domain], [],\n", | |||
| " folder=args.data_root,\n", | |||
| " return_target_with_labels=True,\n", | |||
| " download=args.download)\n", | |||
| "\n", | |||
| "dc = DataloaderCreator(batch_size=args.batch_size,\n", | |||
| " num_workers=args.num_workers,\n", | |||
| " )\n", | |||
| "\n", | |||
| "weights_root = os.path.join(args.data_root, \"weights\")\n", | |||
| "\n", | |||
| "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n", | |||
| "C = office31C(domain=source_domain, pretrained=True,\n", | |||
| " model_dir=weights_root).to(device)\n", | |||
| "\n", | |||
| "\n", | |||
| "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||
| "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||
| "\n", | |||
| "models = Models({\"G\": G, \"C\": C})\n", | |||
| "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "cuda:0\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "f28aaf5a334d4f91a9beb21e714c43a5", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/35] 3%|2 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "7131d8b9099c4d0a95155595919c55f5", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "best_score=None, best_epoch=None\n" | |||
| ] | |||
| }, | |||
| { | |||
| "ename": "AttributeError", | |||
| "evalue": "'Namespace' object has no attribute 'dataroot'", | |||
| "output_type": "error", | |||
| "traceback": [ | |||
| "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
| "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |||
| "Cell \u001b[0;32mIn[39], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m plt\u001b[39m.\u001b[39msavefig(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00moutput_dir\u001b[39m}\u001b[39;00m\u001b[39m/val_accuracy.png\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 29\u001b[0m plt\u001b[39m.\u001b[39mclose(\u001b[39m'\u001b[39m\u001b[39mall\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m---> 31\u001b[0m datasets \u001b[39m=\u001b[39m get_office31([source_domain], [target_domain], folder\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39;49mdataroot, return_target_with_labels\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 32\u001b[0m dc \u001b[39m=\u001b[39m DataloaderCreator(batch_size\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mbatch_size, num_workers\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mnum_workers, all_val\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 34\u001b[0m validator \u001b[39m=\u001b[39m AccuracyValidator(key_map\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m})\n", | |||
| "\u001b[0;31mAttributeError\u001b[0m: 'Namespace' object has no attribute 'dataroot'" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "output_dir = \"tmp\"\n", | |||
| "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||
| "\n", | |||
| "sourceAccuracyValidator = AccuracyValidator()\n", | |||
| "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||
| "\n", | |||
| "trainer = Ignite(\n", | |||
| " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||
| ")\n", | |||
| "print(trainer.device)\n", | |||
| "\n", | |||
| "early_stopper_kwargs = {\"patience\": args.patience}\n", | |||
| "\n", | |||
| "start_time = datetime.now()\n", | |||
| "\n", | |||
| "best_score, best_epoch = trainer.run(\n", | |||
| " datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||
| ")\n", | |||
| "\n", | |||
| "end_time = datetime.now()\n", | |||
| "training_time = end_time - start_time\n", | |||
| "\n", | |||
| "print(f\"best_score={best_score}, best_epoch={best_epoch}\")\n", | |||
| "\n", | |||
| "plt.plot(val_hooks[0].score_history, label='source')\n", | |||
| "plt.title(\"val accuracy\")\n", | |||
| "plt.legend()\n", | |||
| "plt.savefig(f\"{output_dir}/val_accuracy.png\")\n", | |||
| "plt.close('all')\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "173d54ab994d4abda6e4f0897ad96c49", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Source acc: 0.868794322013855\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "13b4a1ccc3b34456b68c73357d14bc21", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/3] 33%|###3 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Target acc: 0.74842768907547\n", | |||
| "---------\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||
| "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||
| "src_score = trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Source acc:\", src_score)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||
| "target_score = trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Target acc:\", target_score)\n", | |||
| "print(\"---------\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "C2 = copy.deepcopy(C) " | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 93, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "cuda:0\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "source_domain = 'amazon'\n", | |||
| "target_domain = 'webcam'\n", | |||
| "G = office31G(pretrained=False).to(device)\n", | |||
| "C = office31C(pretrained=False).to(device)\n", | |||
| "\n", | |||
| "\n", | |||
| "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||
| "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||
| "\n", | |||
| "models = Models({\"G\": G, \"C\": C})\n", | |||
| "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||
| "\n", | |||
| "\n", | |||
| "output_dir = \"tmp\"\n", | |||
| "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||
| "\n", | |||
| "sourceAccuracyValidator = AccuracyValidator()\n", | |||
| "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||
| "\n", | |||
| "new_trainer = Ignite(\n", | |||
| " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||
| ")\n", | |||
| "print(trainer.device)\n", | |||
| "\n", | |||
| "from pytorch_adapt.frameworks.ignite import (\n", | |||
| " CheckpointFnCreator,\n", | |||
| " IgniteValHookWrapper,\n", | |||
| " checkpoint_utils,\n", | |||
| ")\n", | |||
| "\n", | |||
| "objs = [\n", | |||
| " {\n", | |||
| " \"engine\": new_trainer.trainer,\n", | |||
| " \"validator\": new_trainer.validator,\n", | |||
| " \"val_hook0\": val_hooks[0],\n", | |||
| " **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n", | |||
| " }\n", | |||
| " ]\n", | |||
| " \n", | |||
| "# best_score, best_epoch = trainer.run(\n", | |||
| "# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||
| "# )\n", | |||
| "\n", | |||
| "for to_load in objs:\n", | |||
| " checkpoint_fn.load_best_checkpoint(to_load)\n", | |||
| "\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 94, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "32f01ff7ea254739909e4567a133b00a", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Source acc: 0.868794322013855\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "cef345c05e5e46eb9fc0e1cc40b02435", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/3] 33%|###3 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Target acc: 0.74842768907547\n", | |||
| "---------\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||
| "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||
| "src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Source acc:\", src_score)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||
| "target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Target acc:\", target_score)\n", | |||
| "print(\"---------\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 89, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "\n", | |||
| "datasets = get_office31([source_domain], [target_domain],\n", | |||
| " folder=args.data_root,\n", | |||
| " return_target_with_labels=True,\n", | |||
| " download=args.download)\n", | |||
| " \n", | |||
| "dc = DataloaderCreator(batch_size=args.batch_size,\n", | |||
| " num_workers=args.num_workers,\n", | |||
| " train_names=[\"train\"],\n", | |||
| " val_names=[\"src_train\", \"target_train\", \"src_val\", \"target_val\",\n", | |||
| " \"target_train_with_labels\", \"target_val_with_labels\"])\n", | |||
| "\n", | |||
| "G = new_trainer.adapter.models[\"G\"]\n", | |||
| "C = new_trainer.adapter.models[\"C\"]\n", | |||
| "D = Discriminator(in_size=2048, h=1024).to(device)\n", | |||
| "\n", | |||
| "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.001}))\n", | |||
| "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n", | |||
| "# lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.MultiStepLR, {\"milestones\": [2, 5, 10, 20, 40], \"gamma\": hp.gamma}))\n", | |||
| "\n", | |||
| "models = Models({\"G\": G, \"C\": C, \"D\": D})\n", | |||
| "adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 90, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "cuda:0\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "bf490d18567444149070191e100f8c45", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/3] 33%|###3 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "525920fcd19d4178a4bada48932c8fb1", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "bd1158f548e746cdab88d608b22ab65c", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "196fa120037b48fdb4e9a879e7e7c79b", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/3] 33%|###3 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "6795edb658a84309b1a03bcea6a24643", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "output_dir = \"tmp\"\n", | |||
| "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||
| "\n", | |||
| "sourceAccuracyValidator = AccuracyValidator()\n", | |||
| "targetAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||
| "val_hooks = [ScoreHistory(sourceAccuracyValidator), ScoreHistory(targetAccuracyValidator)]\n", | |||
| "\n", | |||
| "trainer = Ignite(\n", | |||
| " adapter, val_hooks=val_hooks, device=device\n", | |||
| ")\n", | |||
| "print(trainer.device)\n", | |||
| "\n", | |||
| "best_score, best_epoch = trainer.run(\n", | |||
| " datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs, check_initial_score=True\n", | |||
| ")\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 91, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "ScoreHistory(\n", | |||
| " validator=AccuracyValidator(required_data=['src_val'])\n", | |||
| " latest_score=0.30319148302078247\n", | |||
| " best_score=0.868794322013855\n", | |||
| " best_epoch=0\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| "execution_count": 91, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "val_hooks[0]" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 92, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "ScoreHistory(\n", | |||
| " validator=AccuracyValidator(required_data=['target_val_with_labels'])\n", | |||
| " latest_score=0.2515723407268524\n", | |||
| " best_score=0.74842768907547\n", | |||
| " best_epoch=0\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| "execution_count": 92, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "val_hooks[1]" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 86, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "del trainer" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 87, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "21169" | |||
| ] | |||
| }, | |||
| "execution_count": 87, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "import gc\n", | |||
| "gc.collect()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 88, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "torch.cuda.empty_cache()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 95, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "args.vishook_frequency = 133" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 96, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=133)" | |||
| ] | |||
| }, | |||
| "execution_count": 96, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| }, | |||
| { | |||
| "ename": "", | |||
| "evalue": "", | |||
| "output_type": "error", | |||
| "traceback": [ | |||
| "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details." | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "args" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "-----" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 3, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CDAN/2000/a2d/saved_models\"" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "source_domain = 'amazon'\n", | |||
| "target_domain = 'dslr'\n", | |||
| "G = office31G(pretrained=False).to(device)\n", | |||
| "C = office31C(pretrained=False).to(device)\n", | |||
| "\n", | |||
| "\n", | |||
| "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||
| "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||
| "\n", | |||
| "models = Models({\"G\": G, \"C\": C})\n", | |||
| "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||
| "\n", | |||
| "\n", | |||
| "output_dir = \"tmp\"\n", | |||
| "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||
| "\n", | |||
| "sourceAccuracyValidator = AccuracyValidator()\n", | |||
| "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||
| "\n", | |||
| "new_trainer = Ignite(\n", | |||
| " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||
| ")\n", | |||
| "\n", | |||
| "from pytorch_adapt.frameworks.ignite import (\n", | |||
| " CheckpointFnCreator,\n", | |||
| " IgniteValHookWrapper,\n", | |||
| " checkpoint_utils,\n", | |||
| ")\n", | |||
| "\n", | |||
| "objs = [\n", | |||
| " {\n", | |||
| " \"engine\": new_trainer.trainer,\n", | |||
| " \"validator\": new_trainer.validator,\n", | |||
| " \"val_hook0\": val_hooks[0],\n", | |||
| " **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n", | |||
| " }\n", | |||
| " ]\n", | |||
| " \n", | |||
| "# best_score, best_epoch = trainer.run(\n", | |||
| "# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||
| "# )\n", | |||
| "\n", | |||
| "for to_load in objs:\n", | |||
| " checkpoint_fn.load_best_checkpoint(to_load)\n", | |||
| "\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "926966dd640e4979ade6a45cf0fcdd49", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/9] 11%|#1 |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Source acc: 0.868794322013855\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "64cd5cfb052c4f52af9af1a63a4c0087", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "[1/2] 50%|##### |it [00:00<?]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Target acc: 0.7200000286102295\n", | |||
| "---------\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "\n", | |||
| "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||
| "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||
| "src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Source acc:\", src_score)\n", | |||
| "\n", | |||
| "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||
| "target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||
| "print(\"Target acc:\", target_score)\n", | |||
| "print(\"---------\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 10, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "source_domain = 'amazon'\n", | |||
| "target_domain = 'dslr'\n", | |||
| "G = new_trainer.adapter.models[\"G\"]\n", | |||
| "C = new_trainer.adapter.models[\"C\"]\n", | |||
| "\n", | |||
| "G.fc = C.net[:6]\n", | |||
| "C.net = C.net[6:]\n", | |||
| "\n", | |||
| "\n", | |||
| "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||
| "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||
| "\n", | |||
| "models = Models({\"G\": G, \"C\": C})\n", | |||
| "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||
| "\n", | |||
| "\n", | |||
| "output_dir = \"tmp\"\n", | |||
| "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||
| "\n", | |||
| "sourceAccuracyValidator = AccuracyValidator()\n", | |||
| "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||
| "\n", | |||
| "more_new_trainer = Ignite(\n", | |||
| " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||
| ")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 13, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "from pytorch_adapt.hooks import FeaturesAndLogitsHook\n", | |||
| "\n", | |||
| "h1 = FeaturesAndLogitsHook()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 19, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "ename": "KeyError", | |||
| "evalue": "in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG", | |||
| "output_type": "error", | |||
| "traceback": [ | |||
| "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
| "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", | |||
| "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m h1(datasets)\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:80\u001b[0m, in \u001b[0;36mBaseFeaturesHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 78\u001b[0m func \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_detached \u001b[39mif\u001b[39;00m detach \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_with_grad\n\u001b[1;32m 79\u001b[0m in_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39min_keys, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 80\u001b[0m func(inputs, outputs, domain, in_keys)\n\u001b[1;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcheck_outputs_requires_grad(outputs)\n\u001b[1;32m 83\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, {}\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:106\u001b[0m, in \u001b[0;36mBaseFeaturesHook.mode_with_grad\u001b[0;34m(self, inputs, outputs, domain, in_keys)\u001b[0m\n\u001b[1;32m 104\u001b[0m output_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_out_keys(), \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 105\u001b[0m output_vals \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_kwargs(inputs, output_keys)\n\u001b[0;32m--> 106\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 107\u001b[0m outputs, output_keys, output_vals, inputs, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_name, in_keys, domain\n\u001b[1;32m 108\u001b[0m )\n\u001b[1;32m 109\u001b[0m \u001b[39mreturn\u001b[39;00m output_keys, output_vals\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:133\u001b[0m, in \u001b[0;36mBaseFeaturesHook.add_if_new\u001b[0;34m(self, outputs, full_key, output_vals, inputs, model_name, in_keys, domain)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39madd_if_new\u001b[39m(\n\u001b[1;32m 131\u001b[0m \u001b[39mself\u001b[39m, outputs, full_key, output_vals, inputs, model_name, in_keys, domain\n\u001b[1;32m 132\u001b[0m ):\n\u001b[0;32m--> 133\u001b[0m c_f\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 134\u001b[0m outputs,\n\u001b[1;32m 135\u001b[0m full_key,\n\u001b[1;32m 136\u001b[0m output_vals,\n\u001b[1;32m 137\u001b[0m inputs,\n\u001b[1;32m 138\u001b[0m model_name,\n\u001b[1;32m 139\u001b[0m in_keys,\n\u001b[1;32m 140\u001b[0m logger\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlogger,\n\u001b[1;32m 141\u001b[0m )\n", | |||
| "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/utils/common_functions.py:96\u001b[0m, in \u001b[0;36madd_if_new\u001b[0;34m(d, key, x, kwargs, model_name, in_keys, other_args, logger)\u001b[0m\n\u001b[1;32m 94\u001b[0m condition \u001b[39m=\u001b[39m is_none\n\u001b[1;32m 95\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(condition(y) \u001b[39mfor\u001b[39;00m y \u001b[39min\u001b[39;00m x):\n\u001b[0;32m---> 96\u001b[0m model \u001b[39m=\u001b[39m kwargs[model_name]\n\u001b[1;32m 97\u001b[0m input_vals \u001b[39m=\u001b[39m [kwargs[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m in_keys] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(other_args\u001b[39m.\u001b[39mvalues())\n\u001b[1;32m 98\u001b[0m new_x \u001b[39m=\u001b[39m try_use_model(model, model_name, input_vals)\n", | |||
| "\u001b[0;31mKeyError\u001b[0m: in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "h1(datasets)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "cdtrans", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.15 (default, Nov 24 2022, 15:19:38) \n[GCC 11.2.0]" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| "interpreter": { | |||
| "hash": "959b82c3a41427bdf7d14d4ba7335271e0c50cfcddd70501934b27dcc36968ad" | |||
| } | |||
| } | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||
| @@ -0,0 +1,101 @@ | |||
| { | |||
| "amazon": { | |||
| "back_pack": 92, | |||
| "bike": 82, | |||
| "bike_helmet": 72, | |||
| "bookcase": 82, | |||
| "bottle": 36, | |||
| "calculator": 94, | |||
| "desk_chair": 91, | |||
| "desk_lamp": 97, | |||
| "desktop_computer": 97, | |||
| "file_cabinet": 81, | |||
| "headphones": 99, | |||
| "keyboard": 100, | |||
| "laptop_computer": 100, | |||
| "letter_tray": 98, | |||
| "mobile_phone": 100, | |||
| "monitor": 99, | |||
| "mouse": 100, | |||
| "mug": 94, | |||
| "paper_notebook": 96, | |||
| "pen": 95, | |||
| "phone": 93, | |||
| "printer": 100, | |||
| "projector": 98, | |||
| "punchers": 98, | |||
| "ring_binder": 90, | |||
| "ruler": 75, | |||
| "scissors": 100, | |||
| "speaker": 99, | |||
| "stapler": 99, | |||
| "tape_dispenser": 96, | |||
| "trash_can": 64 | |||
| }, | |||
| "dslr": { | |||
| "back_pack": 12, | |||
| "bike": 21, | |||
| "bike_helmet": 24, | |||
| "bookcase": 12, | |||
| "bottle": 16, | |||
| "calculator": 12, | |||
| "desk_chair": 13, | |||
| "desk_lamp": 14, | |||
| "desktop_computer": 15, | |||
| "file_cabinet": 15, | |||
| "headphones": 13, | |||
| "keyboard": 10, | |||
| "laptop_computer": 24, | |||
| "letter_tray": 16, | |||
| "mobile_phone": 31, | |||
| "monitor": 22, | |||
| "mouse": 12, | |||
| "mug": 8, | |||
| "paper_notebook": 10, | |||
| "pen": 10, | |||
| "phone": 13, | |||
| "printer": 15, | |||
| "projector": 23, | |||
| "punchers": 18, | |||
| "ring_binder": 10, | |||
| "ruler": 7, | |||
| "scissors": 18, | |||
| "speaker": 26, | |||
| "stapler": 21, | |||
| "tape_dispenser": 22, | |||
| "trash_can": 15 | |||
| }, | |||
| "webcam": { | |||
| "back_pack": 29, | |||
| "bike": 21, | |||
| "bike_helmet": 28, | |||
| "bookcase": 12, | |||
| "bottle": 16, | |||
| "calculator": 31, | |||
| "desk_chair": 40, | |||
| "desk_lamp": 18, | |||
| "desktop_computer": 21, | |||
| "file_cabinet": 19, | |||
| "headphones": 27, | |||
| "keyboard": 27, | |||
| "laptop_computer": 30, | |||
| "letter_tray": 19, | |||
| "mobile_phone": 30, | |||
| "monitor": 43, | |||
| "mouse": 30, | |||
| "mug": 27, | |||
| "paper_notebook": 28, | |||
| "pen": 32, | |||
| "phone": 16, | |||
| "printer": 20, | |||
| "projector": 30, | |||
| "punchers": 27, | |||
| "ring_binder": 40, | |||
| "ruler": 11, | |||
| "scissors": 25, | |||
| "speaker": 30, | |||
| "stapler": 24, | |||
| "tape_dispenser": 23, | |||
| "trash_can": 21 | |||
| } | |||
| } | |||
| @@ -0,0 +1,5 @@ | |||
| { | |||
| "amazon": 2817, | |||
| "dslr": 498, | |||
| "webcam": 795 | |||
| } | |||