datasets | |||||
results | |||||
model-weights | |||||
other | |||||
saved_models | |||||
__pycache__ | |||||
.vscode | |||||
venv |
# Domain-Adaptation | |||||
## Prepare Environment | |||||
Install the requirements using conda from `requirements-conda.txt` (or using pip from `requirements.txt`). | |||||
*My setting:* conda 22.11.1 with Python 3.8 | |||||
If you have trouble with installing torch check [this link](https://pytorch.org/get-started/previous-versions/) | |||||
to find the currect version for your device. | |||||
## Run | |||||
1. Go to `src` directory. | |||||
2. Run `main.py` with appropriate arguments. | |||||
Examples: | |||||
- Perform MMD and CORAL adaptation on all 6 domain adaptations tasks from office31 dataset: | |||||
```bash | |||||
python main.py --model_names MMD CORAL --batch_size 32 | |||||
``` | |||||
- Tuning parameters of DANN model | |||||
```bash | |||||
python main.py --max_epochs 10 --patience 3 --trials_count 1 --model_names DANN --num_workers 2 --batch_size 32 --source amazon --target webcam --hp_tune True | |||||
``` | |||||
Check "Available arguments" in "Code Structure" section for all the available arguments. | |||||
## Code Structure | |||||
The main code for project is located in the `src/` directory. | |||||
- `main.py`: The entry to program | |||||
- **Available arguments:** | |||||
- `max_epochs`: Maximum number of epochs to run the training | |||||
- `patience`: Maximum number of epochs to continue if no improvement is seen (Early stopping parameter) | |||||
- `batch_size` | |||||
- `num_workers` | |||||
- `trials_count`: Number of trials to run each of the tasks | |||||
- `initial_trial`: The number to start indexing the trials from | |||||
- `download`: Whether to download the dataset or not | |||||
- `root`: Path to the root of project | |||||
- `data_root`: Path to the data root | |||||
- `results_root`: Path to the directory to store the results | |||||
- `model_names`: Names of models to run separated by space - available options: DANN, CDAN, MMD, MCD, CORAL, SOURCE | |||||
- `lr`: learning rate** | |||||
- `gamma`: Gamma value for ExponentialLR** | |||||
- `hp_tune`: Set true of you want to run for different hyperparameters, used for hyperparameter tuning | |||||
- `source`: The source domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam | |||||
- `target`: The target domain to run the training for, training will run for all the available domains if not specified - available options: amazon, dslr, webcam | |||||
- `vishook_frequency`: Number of epochs to wait before save a visualization | |||||
- `source_checkpoint_base_dir`: Path to source-only trained model directory to use as base, set `None` to not use source-trained model*** | |||||
- `source_checkpoint_trial_number`: Trail number of source-only trained model to use | |||||
- `models.py`: Contains models for adaptation | |||||
- `train.py`: Contains base training iteration | |||||
- `classifier_adapter.py`: Contains ClassifierAdapter class which is used for training a source-only model without adaptation | |||||
- `load_source.py`: Load source-only trained model to use as base model for adaptation | |||||
- `source.py`: Contains source model | |||||
- `train_source.py`: Contains base source-only training iteration | |||||
- `utils.py`: Contains utility classes | |||||
- `vis_hook.py`: Contains VizHook class which is used for visualization | |||||
** Use can also set different lr and gammas for different models and tasks by changing `hp_map` in `main.py` directly. | |||||
*** For perfoming domain adaptation on source-trained model, one must should train the model for source using option `--model_name SOURCE` first | |||||
## Acknowledgements | |||||
[Pytorch Adapt](https://github.com/KevinMusgrave/pytorch-adapt/tree/0b0fb63b04c9bd7e2cc6cf45314c7ee9d6e391c0) |
# This file may be used to create an environment using: | |||||
# $ conda create --name <env> --file <this file> | |||||
# platform: linux-64 | |||||
_libgcc_mutex=0.1=main | |||||
_openmp_mutex=5.1=1_gnu | |||||
_pytorch_select=0.1=cpu_0 | |||||
anyio=3.5.0=py38h06a4308_0 | |||||
argon2-cffi=21.3.0=pyhd3eb1b0_0 | |||||
argon2-cffi-bindings=21.2.0=py38h7f8727e_0 | |||||
asttokens=2.0.5=pyhd3eb1b0_0 | |||||
attrs=22.1.0=py38h06a4308_0 | |||||
babel=2.11.0=py38h06a4308_0 | |||||
backcall=0.2.0=pyhd3eb1b0_0 | |||||
beautifulsoup4=4.11.1=py38h06a4308_0 | |||||
blas=1.0=mkl | |||||
bleach=4.1.0=pyhd3eb1b0_0 | |||||
brotli=1.0.9=h5eee18b_7 | |||||
brotli-bin=1.0.9=h5eee18b_7 | |||||
brotlipy=0.7.0=py38h27cfd23_1003 | |||||
ca-certificates=2022.10.11=h06a4308_0 | |||||
certifi=2022.12.7=py38h06a4308_0 | |||||
cffi=1.15.1=py38h5eee18b_3 | |||||
charset-normalizer=2.0.4=pyhd3eb1b0_0 | |||||
comm=0.1.2=py38h06a4308_0 | |||||
contourpy=1.0.5=py38hdb19cb5_0 | |||||
cryptography=38.0.1=py38h9ce1e76_0 | |||||
cudatoolkit=10.1.243=h8cb64d8_10 | |||||
cudnn=7.6.5.32=hc0a50b0_1 | |||||
cycler=0.11.0=pyhd3eb1b0_0 | |||||
dbus=1.13.18=hb2f20db_0 | |||||
debugpy=1.5.1=py38h295c915_0 | |||||
decorator=5.1.1=pyhd3eb1b0_0 | |||||
defusedxml=0.7.1=pyhd3eb1b0_0 | |||||
entrypoints=0.4=py38h06a4308_0 | |||||
executing=0.8.3=pyhd3eb1b0_0 | |||||
expat=2.4.9=h6a678d5_0 | |||||
filelock=3.9.0=pypi_0 | |||||
flit-core=3.6.0=pyhd3eb1b0_0 | |||||
fontconfig=2.14.1=h52c9d5c_1 | |||||
fonttools=4.25.0=pyhd3eb1b0_0 | |||||
freetype=2.12.1=h4a9f257_0 | |||||
gdown=4.6.0=pypi_0 | |||||
giflib=5.2.1=h7b6447c_0 | |||||
glib=2.69.1=he621ea3_2 | |||||
gst-plugins-base=1.14.0=h8213a91_2 | |||||
gstreamer=1.14.0=h28cd5cc_2 | |||||
h5py=3.2.1=pypi_0 | |||||
icu=58.2=he6710b0_3 | |||||
idna=3.4=py38h06a4308_0 | |||||
importlib-metadata=4.11.3=py38h06a4308_0 | |||||
importlib_resources=5.2.0=pyhd3eb1b0_1 | |||||
intel-openmp=2021.4.0=h06a4308_3561 | |||||
ipykernel=6.19.2=py38hb070fc8_0 | |||||
ipython=8.7.0=py38h06a4308_0 | |||||
ipython_genutils=0.2.0=pyhd3eb1b0_1 | |||||
ipywidgets=7.6.5=pyhd3eb1b0_1 | |||||
jedi=0.18.1=py38h06a4308_1 | |||||
jinja2=3.1.2=py38h06a4308_0 | |||||
joblib=1.2.0=pypi_0 | |||||
jpeg=9e=h7f8727e_0 | |||||
json5=0.9.6=pyhd3eb1b0_0 | |||||
jsonschema=4.16.0=py38h06a4308_0 | |||||
jupyter=1.0.0=py38h06a4308_8 | |||||
jupyter_client=7.4.8=py38h06a4308_0 | |||||
jupyter_console=6.4.4=py38h06a4308_0 | |||||
jupyter_core=5.1.1=py38h06a4308_0 | |||||
jupyter_server=1.23.4=py38h06a4308_0 | |||||
jupyterlab=3.5.2=py38h06a4308_0 | |||||
jupyterlab_pygments=0.1.2=py_0 | |||||
jupyterlab_server=2.16.5=py38h06a4308_0 | |||||
jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 | |||||
kiwisolver=1.4.4=py38h6a678d5_0 | |||||
krb5=1.19.2=hac12032_0 | |||||
lcms2=2.12=h3be6417_0 | |||||
ld_impl_linux-64=2.38=h1181459_1 | |||||
lerc=3.0=h295c915_0 | |||||
libbrotlicommon=1.0.9=h5eee18b_7 | |||||
libbrotlidec=1.0.9=h5eee18b_7 | |||||
libbrotlienc=1.0.9=h5eee18b_7 | |||||
libclang=10.0.1=default_hb85057a_2 | |||||
libdeflate=1.8=h7f8727e_5 | |||||
libedit=3.1.20221030=h5eee18b_0 | |||||
libevent=2.1.12=h8f2d780_0 | |||||
libffi=3.4.2=h6a678d5_6 | |||||
libgcc-ng=11.2.0=h1234567_1 | |||||
libgomp=11.2.0=h1234567_1 | |||||
libllvm10=10.0.1=hbcb73fb_5 | |||||
libpng=1.6.37=hbc83047_0 | |||||
libpq=12.9=h16c4e8d_3 | |||||
libsodium=1.0.18=h7b6447c_0 | |||||
libstdcxx-ng=11.2.0=h1234567_1 | |||||
libtiff=4.4.0=hecacb30_2 | |||||
libuuid=1.41.5=h5eee18b_0 | |||||
libuv=1.40.0=h7b6447c_0 | |||||
libwebp=1.2.4=h11a3e52_0 | |||||
libwebp-base=1.2.4=h5eee18b_0 | |||||
libxcb=1.15=h7f8727e_0 | |||||
libxkbcommon=1.0.1=hfa300c1_0 | |||||
libxml2=2.9.14=h74e7548_0 | |||||
libxslt=1.1.35=h4e12654_0 | |||||
llvmlite=0.39.1=pypi_0 | |||||
lxml=4.9.1=py38h1edc446_0 | |||||
lz4-c=1.9.4=h6a678d5_0 | |||||
markupsafe=2.1.1=py38h7f8727e_0 | |||||
matplotlib=3.4.2=pypi_0 | |||||
matplotlib-inline=0.1.6=py38h06a4308_0 | |||||
mistune=0.8.4=py38h7b6447c_1000 | |||||
mkl=2021.4.0=h06a4308_640 | |||||
mkl-service=2.4.0=py38h7f8727e_0 | |||||
mkl_fft=1.3.1=py38hd3c417c_0 | |||||
mkl_random=1.2.2=py38h51133e4_0 | |||||
munkres=1.1.4=py_0 | |||||
nbclassic=0.4.8=py38h06a4308_0 | |||||
nbclient=0.5.13=py38h06a4308_0 | |||||
nbconvert=6.5.4=py38h06a4308_0 | |||||
nbformat=5.7.0=py38h06a4308_0 | |||||
ncurses=6.3=h5eee18b_3 | |||||
nest-asyncio=1.5.6=py38h06a4308_0 | |||||
ninja=1.10.2=h06a4308_5 | |||||
ninja-base=1.10.2=hd09550d_5 | |||||
notebook=6.5.2=py38h06a4308_0 | |||||
notebook-shim=0.2.2=py38h06a4308_0 | |||||
nspr=4.33=h295c915_0 | |||||
nss=3.74=h0370c37_0 | |||||
numba=0.56.4=pypi_0 | |||||
numpy=1.22.4=pypi_0 | |||||
opencv-python=4.5.2.54=pypi_0 | |||||
openssl=1.1.1s=h7f8727e_0 | |||||
packaging=22.0=py38h06a4308_0 | |||||
pandas=1.5.2=pypi_0 | |||||
pandocfilters=1.5.0=pyhd3eb1b0_0 | |||||
parso=0.8.3=pyhd3eb1b0_0 | |||||
pcre=8.45=h295c915_0 | |||||
pexpect=4.8.0=pyhd3eb1b0_3 | |||||
pickleshare=0.7.5=pyhd3eb1b0_1003 | |||||
pillow=8.4.0=pypi_0 | |||||
pip=22.3.1=py38h06a4308_0 | |||||
pkgutil-resolve-name=1.3.10=py38h06a4308_0 | |||||
platformdirs=2.5.2=py38h06a4308_0 | |||||
ply=3.11=py38_0 | |||||
prometheus_client=0.14.1=py38h06a4308_0 | |||||
prompt-toolkit=3.0.36=py38h06a4308_0 | |||||
prompt_toolkit=3.0.36=hd3eb1b0_0 | |||||
psutil=5.9.0=py38h5eee18b_0 | |||||
ptyprocess=0.7.0=pyhd3eb1b0_2 | |||||
pure_eval=0.2.2=pyhd3eb1b0_0 | |||||
pycparser=2.21=pyhd3eb1b0_0 | |||||
pygments=2.11.2=pyhd3eb1b0_0 | |||||
pynndescent=0.5.8=pypi_0 | |||||
pyopenssl=22.0.0=pyhd3eb1b0_0 | |||||
pyparsing=3.0.9=py38h06a4308_0 | |||||
pyqt=5.15.7=py38h6a678d5_1 | |||||
pyqt5-sip=12.11.0=py38h6a678d5_1 | |||||
pyrsistent=0.18.0=py38heee7806_0 | |||||
pysocks=1.7.1=py38h06a4308_0 | |||||
python=3.8.15=h7a1cb2a_2 | |||||
python-dateutil=2.8.2=pyhd3eb1b0_0 | |||||
python-fastjsonschema=2.16.2=py38h06a4308_0 | |||||
pytorch-adapt=0.0.82=pypi_0 | |||||
pytorch-ignite=0.4.9=pypi_0 | |||||
pytorch-metric-learning=1.6.3=pypi_0 | |||||
pytz=2022.7=py38h06a4308_0 | |||||
pyyaml=6.0=pypi_0 | |||||
pyzmq=23.2.0=py38h6a678d5_0 | |||||
qt-main=5.15.2=h327a75a_7 | |||||
qt-webengine=5.15.9=hd2b0992_4 | |||||
qtconsole=5.3.2=py38h06a4308_0 | |||||
qtpy=2.2.0=py38h06a4308_0 | |||||
qtwebkit=5.212=h4eab89a_4 | |||||
readline=8.2=h5eee18b_0 | |||||
requests=2.28.1=py38h06a4308_0 | |||||
scikit-learn=1.0=pypi_0 | |||||
scipy=1.6.3=pypi_0 | |||||
seaborn=0.12.2=pypi_0 | |||||
send2trash=1.8.0=pyhd3eb1b0_1 | |||||
setuptools=65.5.0=py38h06a4308_0 | |||||
sip=6.6.2=py38h6a678d5_0 | |||||
six=1.16.0=pyhd3eb1b0_1 | |||||
sniffio=1.2.0=py38h06a4308_1 | |||||
soupsieve=2.3.2.post1=py38h06a4308_0 | |||||
sqlite=3.40.0=h5082296_0 | |||||
stack_data=0.2.0=pyhd3eb1b0_0 | |||||
terminado=0.17.1=py38h06a4308_0 | |||||
threadpoolctl=3.1.0=pypi_0 | |||||
timm=0.4.9=pypi_0 | |||||
tinycss2=1.2.1=py38h06a4308_0 | |||||
tk=8.6.12=h1ccaba5_0 | |||||
toml=0.10.2=pyhd3eb1b0_0 | |||||
tomli=2.0.1=py38h06a4308_0 | |||||
torch=1.8.1+cu101=pypi_0 | |||||
torchaudio=0.8.1=pypi_0 | |||||
torchmetrics=0.9.3=pypi_0 | |||||
torchvision=0.9.1+cu101=pypi_0 | |||||
tornado=6.2=py38h5eee18b_0 | |||||
tqdm=4.61.2=pypi_0 | |||||
traitlets=5.7.1=py38h06a4308_0 | |||||
typing-extensions=4.4.0=py38h06a4308_0 | |||||
typing_extensions=4.4.0=py38h06a4308_0 | |||||
umap-learn=0.5.3=pypi_0 | |||||
urllib3=1.26.13=py38h06a4308_0 | |||||
wcwidth=0.2.5=pyhd3eb1b0_0 | |||||
webencodings=0.5.1=py38_1 | |||||
websocket-client=0.58.0=py38h06a4308_4 | |||||
wheel=0.37.1=pyhd3eb1b0_0 | |||||
widgetsnbextension=3.5.2=py38h06a4308_0 | |||||
xz=5.2.8=h5eee18b_0 | |||||
yacs=0.1.8=pypi_0 | |||||
zeromq=4.3.4=h2531618_0 | |||||
zipp=3.11.0=py38h06a4308_0 | |||||
zlib=1.2.13=h5eee18b_0 | |||||
zstd=1.5.2=ha4553b6_0 |
anyio @ file:///tmp/build/80754af9/anyio_1644481698350/work/dist | |||||
argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work | |||||
argon2-cffi-bindings @ file:///tmp/build/80754af9/argon2-cffi-bindings_1644569684262/work | |||||
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work | |||||
attrs @ file:///croot/attrs_1668696182826/work | |||||
Babel @ file:///croot/babel_1671781930836/work | |||||
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work | |||||
beautifulsoup4 @ file:///opt/conda/conda-bld/beautifulsoup4_1650462163268/work | |||||
bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work | |||||
brotlipy==0.7.0 | |||||
certifi @ file:///croot/certifi_1671487769961/work/certifi | |||||
cffi @ file:///croot/cffi_1670423208954/work | |||||
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work | |||||
comm @ file:///croot/comm_1671231121260/work | |||||
contourpy @ file:///opt/conda/conda-bld/contourpy_1663827406301/work | |||||
cryptography @ file:///croot/cryptography_1665612644927/work | |||||
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work | |||||
debugpy @ file:///tmp/build/80754af9/debugpy_1637091796427/work | |||||
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work | |||||
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work | |||||
entrypoints @ file:///tmp/build/80754af9/entrypoints_1649926445639/work | |||||
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work | |||||
fastjsonschema @ file:///opt/conda/conda-bld/python-fastjsonschema_1661371079312/work | |||||
filelock==3.9.0 | |||||
flit_core @ file:///opt/conda/conda-bld/flit-core_1644941570762/work/source/flit_core | |||||
fonttools==4.25.0 | |||||
gdown==4.6.0 | |||||
h5py==3.2.1 | |||||
idna @ file:///croot/idna_1666125576474/work | |||||
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1648562408398/work | |||||
importlib-resources @ file:///tmp/build/80754af9/importlib_resources_1625135880749/work | |||||
ipykernel @ file:///croot/ipykernel_1671488378391/work | |||||
ipython @ file:///croot/ipython_1670919316550/work | |||||
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work | |||||
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1634143127070/work | |||||
jedi @ file:///tmp/build/80754af9/jedi_1644315233700/work | |||||
Jinja2 @ file:///croot/jinja2_1666908132255/work | |||||
joblib==1.2.0 | |||||
json5 @ file:///tmp/build/80754af9/json5_1624432770122/work | |||||
jsonschema @ file:///opt/conda/conda-bld/jsonschema_1663375472438/work | |||||
jupyter @ file:///tmp/abs_33h4eoipez/croots/recipe/jupyter_1659349046347/work | |||||
jupyter-console @ file:///croot/jupyter_console_1671541909316/work | |||||
jupyter-server @ file:///croot/jupyter_server_1671707632269/work | |||||
jupyter_client @ file:///croot/jupyter_client_1671703053786/work | |||||
jupyter_core @ file:///croot/jupyter_core_1672332224593/work | |||||
jupyterlab @ file:///croot/jupyterlab_1672132689850/work | |||||
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work | |||||
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work | |||||
jupyterlab_server @ file:///croot/jupyterlab_server_1672127357747/work | |||||
kiwisolver @ file:///croot/kiwisolver_1672387140495/work | |||||
llvmlite==0.39.1 | |||||
lxml @ file:///opt/conda/conda-bld/lxml_1657545139709/work | |||||
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work | |||||
matplotlib==3.4.2 | |||||
matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work | |||||
mistune==0.8.4 | |||||
mkl-fft==1.3.1 | |||||
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work | |||||
mkl-service==2.4.0 | |||||
munkres==1.1.4 | |||||
nbclassic @ file:///croot/nbclassic_1668174957779/work | |||||
nbclient @ file:///tmp/build/80754af9/nbclient_1650308366712/work | |||||
nbconvert @ file:///croot/nbconvert_1668450669124/work | |||||
nbformat @ file:///croot/nbformat_1670352325207/work | |||||
nest-asyncio @ file:///croot/nest-asyncio_1672387112409/work | |||||
notebook @ file:///croot/notebook_1668179881751/work | |||||
notebook_shim @ file:///croot/notebook-shim_1668160579331/work | |||||
numba==0.56.4 | |||||
numpy==1.22.4 | |||||
opencv-python==4.5.2.54 | |||||
packaging @ file:///croot/packaging_1671697413597/work | |||||
pandas==1.5.2 | |||||
pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work | |||||
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work | |||||
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work | |||||
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work | |||||
Pillow==8.4.0 | |||||
pkgutil_resolve_name @ file:///opt/conda/conda-bld/pkgutil-resolve-name_1661463321198/work | |||||
platformdirs @ file:///opt/conda/conda-bld/platformdirs_1662711380096/work | |||||
ply==3.11 | |||||
prometheus-client @ file:///tmp/abs_d3zeliano1/croots/recipe/prometheus_client_1659455100375/work | |||||
prompt-toolkit @ file:///croot/prompt-toolkit_1672387306916/work | |||||
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work | |||||
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl | |||||
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work | |||||
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work | |||||
Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work | |||||
pynndescent==0.5.8 | |||||
pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work | |||||
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work | |||||
PyQt5-sip==12.11.0 | |||||
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110947380/work | |||||
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work | |||||
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work | |||||
pytorch-adapt==0.0.82 | |||||
pytorch-ignite==0.4.9 | |||||
pytorch-metric-learning==1.6.3 | |||||
pytz @ file:///croot/pytz_1671697431263/work | |||||
PyYAML==6.0 | |||||
pyzmq @ file:///opt/conda/conda-bld/pyzmq_1657724186960/work | |||||
qtconsole @ file:///opt/conda/conda-bld/qtconsole_1662018252641/work | |||||
QtPy @ file:///opt/conda/conda-bld/qtpy_1662014892439/work | |||||
requests @ file:///opt/conda/conda-bld/requests_1657734628632/work | |||||
scikit-learn==1.0 | |||||
scipy==1.6.3 | |||||
seaborn==0.12.2 | |||||
Send2Trash @ file:///tmp/build/80754af9/send2trash_1632406701022/work | |||||
sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work | |||||
six @ file:///tmp/build/80754af9/six_1644875935023/work | |||||
sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work | |||||
soupsieve @ file:///croot/soupsieve_1666296392845/work | |||||
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work | |||||
terminado @ file:///croot/terminado_1671751832461/work | |||||
threadpoolctl==3.1.0 | |||||
timm==0.4.9 | |||||
tinycss2 @ file:///croot/tinycss2_1668168815555/work | |||||
toml @ file:///tmp/build/80754af9/toml_1616166611790/work | |||||
tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work | |||||
torch==1.8.1+cu101 | |||||
torchaudio==0.8.1 | |||||
torchmetrics==0.9.3 | |||||
torchvision==0.9.1+cu101 | |||||
tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work | |||||
tqdm==4.61.2 | |||||
traitlets @ file:///croot/traitlets_1671143879854/work | |||||
typing_extensions @ file:///croot/typing_extensions_1669924550328/work | |||||
umap-learn==0.5.3 | |||||
urllib3 @ file:///croot/urllib3_1670526988650/work | |||||
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work | |||||
webencodings==0.5.1 | |||||
websocket-client @ file:///tmp/build/80754af9/websocket-client_1614804261064/work | |||||
widgetsnbextension @ file:///tmp/build/80754af9/widgetsnbextension_1645009353553/work | |||||
yacs==0.1.8 | |||||
zipp @ file:///croot/zipp_1672387121353/work |
from pytorch_adapt.adapters.base_adapter import BaseGCAdapter | |||||
from pytorch_adapt.adapters.utils import with_opt | |||||
from pytorch_adapt.hooks import ClassifierHook | |||||
class ClassifierAdapter(BaseGCAdapter): | |||||
""" | |||||
Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook]. | |||||
|Container|Required keys| | |||||
|---|---| | |||||
|models|```["G", "C"]```| | |||||
|optimizers|```["G", "C"]```| | |||||
""" | |||||
def init_hook(self, hook_kwargs): | |||||
opts = with_opt(list(self.optimizers.keys())) | |||||
self.hook = self.hook_cls(opts, **hook_kwargs) | |||||
@property | |||||
def hook_cls(self): | |||||
return ClassifierHook |
import torch | |||||
import os | |||||
from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||||
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||||
from pytorch_adapt.models import Discriminator, office31C, office31G | |||||
from pytorch_adapt.containers import Misc | |||||
from pytorch_adapt.containers import LRSchedulers | |||||
from classifier_adapter import ClassifierAdapter | |||||
from utils import HP, DAModels | |||||
import copy | |||||
import matplotlib.pyplot as plt | |||||
import torch | |||||
import os | |||||
import gc | |||||
from datetime import datetime | |||||
from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||||
from pytorch_adapt.frameworks.ignite import ( | |||||
CheckpointFnCreator, | |||||
IgniteValHookWrapper, | |||||
checkpoint_utils, | |||||
) | |||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
def get_source_trainer(checkpoint_dir): | |||||
G = office31G(pretrained=False).to(device) | |||||
C = office31C(pretrained=False).to(device) | |||||
optimizers = Optimizers((torch.optim.Adam, {"lr": 1e-4})) | |||||
lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99})) | |||||
models = Models({"G": G, "C": C}) | |||||
adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
checkpoint_fn = CheckpointFnCreator(dirname=checkpoint_dir, require_empty=False) | |||||
sourceAccuracyValidator = AccuracyValidator() | |||||
val_hooks = [ScoreHistory(sourceAccuracyValidator)] | |||||
new_trainer = Ignite( | |||||
adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device | |||||
) | |||||
objs = [ | |||||
{ | |||||
"engine": new_trainer.trainer, | |||||
**checkpoint_utils.adapter_to_dict(new_trainer.adapter), | |||||
} | |||||
] | |||||
for to_load in objs: | |||||
checkpoint_fn.load_best_checkpoint(to_load) | |||||
return new_trainer |
import argparse | |||||
import logging | |||||
import os | |||||
from datetime import datetime | |||||
from train import train | |||||
from train_source import train_source | |||||
from utils import HP, DAModels | |||||
import tracemalloc | |||||
logging.basicConfig() | |||||
logging.getLogger("pytorch-adapt").setLevel(logging.WARNING) | |||||
DATASET_PAIRS = [("amazon", "webcam"), ("amazon", "dslr"), | |||||
("webcam", "amazon"), ("webcam", "dslr"), | |||||
("dslr", "amazon"), ("dslr", "webcam")] | |||||
def run_experiment_on_model(args, model_name, trial_number): | |||||
train_fn = train | |||||
if model_name == DAModels.SOURCE: | |||||
train_fn = train_source | |||||
base_output_dir = f"{args.results_root}/{model_name}/{trial_number}" | |||||
os.makedirs(base_output_dir, exist_ok=True) | |||||
hp_map = { | |||||
'DANN': { | |||||
'a2d': (5e-05, 0.99), | |||||
'a2w': (5e-05, 0.99), | |||||
'd2a': (0.0001, 0.9), | |||||
'd2w': (5e-05, 0.99), | |||||
'w2a': (0.0001, 0.99), | |||||
'w2d': (0.0001, 0.99), | |||||
}, | |||||
'CDAN': { | |||||
'a2d': (1e-05, 1), | |||||
'a2w': (1e-05, 1), | |||||
'd2a': (1e-05, 0.99), | |||||
'd2w': (1e-05, 0.99), | |||||
'w2a': (0.0001, 0.99), | |||||
'w2d': (5e-05, 0.99), | |||||
}, | |||||
'MMD': { | |||||
'a2d': (5e-05, 1), | |||||
'a2w': (5e-05, 0.99), | |||||
'd2a': (0.0001, 0.99), | |||||
'd2w': (5e-05, 0.9), | |||||
'w2a': (0.0001, 0.99), | |||||
'w2d': (1e-05, 0.99), | |||||
}, | |||||
'MCD': { | |||||
'a2d': (1e-05, 0.9), | |||||
'a2w': (0.0001, 1), | |||||
'd2a': (1e-05, 0.9), | |||||
'd2w': (1e-05, 0.99), | |||||
'w2a': (1e-05, 0.9), | |||||
'w2d': (5*1e-6, 0.99), | |||||
}, | |||||
'CORAL': { | |||||
'a2d': (1e-05, 0.99), | |||||
'a2w': (1e-05, 1), | |||||
'd2a': (5*1e-6, 0.99), | |||||
'd2w': (0.0001, 0.99), | |||||
'w2a': (1e-5, 0.99), | |||||
'w2d': (0.0001, 0.99), | |||||
}, | |||||
} | |||||
if not args.hp_tune: | |||||
d = datetime.now() | |||||
results_file = f"{base_output_dir}/e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||||
with open(results_file, "w") as myfile: | |||||
myfile.write("pair, source_acc, target_acc, best_epoch, best_score, time, cur, peak, lr, gamma\n") | |||||
for source_domain, target_domain in DATASET_PAIRS: | |||||
pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||||
hp_parmas = hp_map[DAModels.CDAN.name][pair_name] | |||||
lr = args.lr if args.lr else hp_parmas[0] | |||||
gamma = args.gamma if args.gamma else hp_parmas[1] | |||||
hp = HP(lr=lr, gamma=gamma) | |||||
tracemalloc.start() | |||||
train_fn(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain) | |||||
cur, peak = tracemalloc.get_traced_memory() | |||||
tracemalloc.stop() | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write(f"{cur}, {peak}\n") | |||||
else: | |||||
# gamma_list = [1, 0.99, 0.9, 0.8] | |||||
# lr_list = [1e-5, 3*1e-5, 1e-4, 3*1e-4] | |||||
hp_values = { | |||||
1e-4: [0.99, 0.9], | |||||
1e-5: [0.99, 0.9], | |||||
5*1e-5: [0.99, 0.9], | |||||
5*1e-6: [0.99] | |||||
# 5*1e-4: [1, 0.99], | |||||
# 1e-3: [0.99, 0.9, 0.8], | |||||
} | |||||
d = datetime.now() | |||||
hp_file = f"{base_output_dir}/hp_e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||||
with open(hp_file, "w") as myfile: | |||||
myfile.write("lr, gamma, pair, source_acc, target_acc, best_epoch, best_score\n") | |||||
hp_best = None | |||||
hp_best_score = None | |||||
for lr, gamma_list in hp_values.items(): | |||||
# for lr in lr_list: | |||||
for gamma in gamma_list: | |||||
hp = HP(lr=lr, gamma=gamma) | |||||
print("HP:", hp) | |||||
for source_domain, target_domain in DATASET_PAIRS: | |||||
_, _, best_score = \ | |||||
train_fn(args, model_name, hp, base_output_dir, hp_file, source_domain, target_domain) | |||||
if best_score is not None and (hp_best_score is None or hp_best_score < best_score): | |||||
hp_best = hp | |||||
hp_best_score = best_score | |||||
with open(hp_file, "a") as myfile: | |||||
myfile.write(f"\nbest: {hp_best.lr}, {hp_best.gamma}\n") | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--max_epochs', default=60, type=int) | |||||
parser.add_argument('--patience', default=10, type=int) | |||||
parser.add_argument('--batch_size', default=32, type=int) | |||||
parser.add_argument('--num_workers', default=2, type=int) | |||||
parser.add_argument('--trials_count', default=3, type=int) | |||||
parser.add_argument('--initial_trial', default=0, type=int) | |||||
parser.add_argument('--download', default=False, type=bool) | |||||
parser.add_argument('--root', default="../") | |||||
parser.add_argument('--data_root', default="../datasets/pytorch-adapt/") | |||||
parser.add_argument('--results_root', default="../results/") | |||||
parser.add_argument('--model_names', default=["DANN"], nargs='+') | |||||
parser.add_argument('--lr', default=None, type=float) | |||||
parser.add_argument('--gamma', default=None, type=float) | |||||
parser.add_argument('--hp_tune', default=False, type=bool) | |||||
parser.add_argument('--source', default=None) | |||||
parser.add_argument('--target', default=None) | |||||
parser.add_argument('--vishook_frequency', default=5, type=int) | |||||
parser.add_argument('--source_checkpoint_base_dir', default=None) # default='../results/DAModels.SOURCE/' | |||||
parser.add_argument('--source_checkpoint_trial_number', default=-1, type=int) | |||||
args = parser.parse_args() | |||||
print(args) | |||||
for trial_number in range(args.initial_trial, args.initial_trial + args.trials_count): | |||||
if args.source_checkpoint_trial_number == -1: | |||||
args.source_checkpoint_trial_number = trial_number | |||||
for model_name in args.model_names: | |||||
try: | |||||
model_enum = DAModels(model_name) | |||||
except ValueError: | |||||
logging.warning(f"Model {model_name} not found. skipping...") | |||||
continue | |||||
run_experiment_on_model(args, model_enum, trial_number) |
import torch | |||||
import os | |||||
from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||||
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||||
from pytorch_adapt.models import Discriminator, office31C, office31G | |||||
from pytorch_adapt.containers import Misc | |||||
from pytorch_adapt.layers import RandomizedDotProduct | |||||
from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss | |||||
from pytorch_adapt.utils import common_functions | |||||
from pytorch_adapt.containers import LRSchedulers | |||||
from classifier_adapter import ClassifierAdapter | |||||
from load_source import get_source_trainer | |||||
from utils import HP, DAModels | |||||
import copy | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
def get_model(model_name, hp: HP, source_checkpoint_dir, data_root, source_domain): | |||||
if source_checkpoint_dir: | |||||
source_trainer = get_source_trainer(source_checkpoint_dir) | |||||
G = copy.deepcopy(source_trainer.adapter.models["G"]) | |||||
C = copy.deepcopy(source_trainer.adapter.models["C"]) | |||||
D = Discriminator(in_size=2048, h=1024).to(device) | |||||
else: | |||||
weights_root = os.path.join(data_root, "weights") | |||||
G = office31G(pretrained=True, model_dir=weights_root).to(device) | |||||
C = office31C(domain=source_domain, pretrained=True, | |||||
model_dir=weights_root).to(device) | |||||
D = Discriminator(in_size=2048, h=1024).to(device) | |||||
optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr})) | |||||
lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma})) | |||||
if model_name == DAModels.DANN: | |||||
models = Models({"G": G, "C": C, "D": D}) | |||||
adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
elif model_name == DAModels.CDAN: | |||||
models = Models({"G": G, "C": C, "D": D}) | |||||
misc = Misc({"feature_combiner": RandomizedDotProduct([2048, 31], 2048)}) | |||||
adapter = CDAN(models=models, misc=misc, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
elif model_name == DAModels.MCD: | |||||
C1 = common_functions.reinit(copy.deepcopy(C)) | |||||
C_combined = MultipleModels(C, C1) | |||||
models = Models({"G": G, "C": C_combined}) | |||||
adapter= MCD(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
elif model_name == DAModels.SYMNET: | |||||
C1 = common_functions.reinit(copy.deepcopy(C)) | |||||
C_combined = MultipleModels(C, C1) | |||||
models = Models({"G": G, "C": C_combined}) | |||||
adapter= SymNets(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
elif model_name == DAModels.MMD: | |||||
models = Models({"G": G, "C": C}) | |||||
hook_kwargs = {"loss_fn": MMDLoss()} | |||||
adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs) | |||||
elif model_name == DAModels.CORAL: | |||||
models = Models({"G": G, "C": C}) | |||||
hook_kwargs = {"loss_fn": CORALLoss()} | |||||
adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs) | |||||
elif model_name == DAModels.SOURCE: | |||||
models = Models({"G": G, "C": C}) | |||||
adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
return adapter | |||||
import logging | |||||
import matplotlib.pyplot as plt | |||||
import pandas as pd | |||||
import seaborn as sns | |||||
import torch | |||||
import umap | |||||
from tqdm import tqdm | |||||
import os | |||||
import gc | |||||
from datetime import datetime | |||||
from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner | |||||
from pytorch_adapt.adapters.base_adapter import BaseGCAdapter | |||||
from pytorch_adapt.adapters.utils import with_opt | |||||
from pytorch_adapt.hooks import ClassifierHook | |||||
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||||
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.models import Discriminator, mnistC, mnistG, office31C, office31G | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||||
from pytorch_adapt.containers import Misc | |||||
from pytorch_adapt.layers import RandomizedDotProduct | |||||
from pytorch_adapt.layers import MultipleModels | |||||
from pytorch_adapt.utils import common_functions | |||||
from pytorch_adapt.containers import LRSchedulers | |||||
import copy | |||||
from pprint import pprint | |||||
PATIENCE = 5 | |||||
EPOCHS = 50 | |||||
BATCH_SIZE = 32 | |||||
NUM_WORKERS = 2 | |||||
TRIAL_COUNT = 5 | |||||
logging.basicConfig() | |||||
logging.getLogger("pytorch-adapt").setLevel(logging.WARNING) | |||||
class ClassifierAdapter(BaseGCAdapter): | |||||
""" | |||||
Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook]. | |||||
|Container|Required keys| | |||||
|---|---| | |||||
|models|```["G", "C"]```| | |||||
|optimizers|```["G", "C"]```| | |||||
""" | |||||
def init_hook(self, hook_kwargs): | |||||
opts = with_opt(list(self.optimizers.keys())) | |||||
self.hook = self.hook_cls(opts, **hook_kwargs) | |||||
@property | |||||
def hook_cls(self): | |||||
return ClassifierHook | |||||
root='/content/drive/MyDrive/Shared with Sabas/Bsc/' | |||||
# root="datasets/pytorch-adapt/" | |||||
data_root = os.path.join(root,'data') | |||||
batch_size=BATCH_SIZE | |||||
num_workers=NUM_WORKERS | |||||
device = torch.device("cuda") | |||||
model_dir = os.path.join(data_root, "weights") | |||||
DATASET_PAIRS = [("amazon", ["webcam", "dslr"]), | |||||
("webcam", ["dslr", "amazon"]), | |||||
("dslr", ["amazon", "webcam"]) | |||||
] | |||||
MODEL_NAME = "base" | |||||
model_name = MODEL_NAME | |||||
pass_next= 0 | |||||
pass_trial = 0 | |||||
for trial_number in range(10, 10 + TRIAL_COUNT): | |||||
if pass_trial: | |||||
pass_trial -= 1 | |||||
continue | |||||
base_output_dir = f"{root}/results/vishook/{MODEL_NAME}/{trial_number}" | |||||
os.makedirs(base_output_dir, exist_ok=True) | |||||
d = datetime.now() | |||||
results_file = f"{base_output_dir}/{d.strftime('%Y%m%d-%H:%M:%S')}.txt" | |||||
with open(results_file, "w") as myfile: | |||||
myfile.write("pair, source_acc, target_acc, best_epoch, time\n") | |||||
for source_domain, target_domains in DATASET_PAIRS: | |||||
datasets = get_office31([source_domain], [], folder=data_root) | |||||
dc = DataloaderCreator(batch_size=batch_size, | |||||
num_workers=num_workers, | |||||
) | |||||
dataloaders = dc(**datasets) | |||||
G = office31G(pretrained=True, model_dir=model_dir) | |||||
C = office31C(domain=source_domain, pretrained=True, model_dir=model_dir) | |||||
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0005})) | |||||
lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99})) | |||||
if model_name == "base": | |||||
models = Models({"G": G, "C": C}) | |||||
adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
checkpoint_fn = CheckpointFnCreator(dirname="saved_models", require_empty=False) | |||||
val_hooks = [ScoreHistory(AccuracyValidator())] | |||||
trainer = Ignite( | |||||
adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, | |||||
) | |||||
early_stopper_kwargs = {"patience": PATIENCE} | |||||
start_time = datetime.now() | |||||
best_score, best_epoch = trainer.run( | |||||
datasets, dataloader_creator=dc, max_epochs=EPOCHS, early_stopper_kwargs=early_stopper_kwargs | |||||
) | |||||
end_time = datetime.now() | |||||
training_time = end_time - start_time | |||||
for target_domain in target_domains: | |||||
if pass_next: | |||||
pass_next -= 1 | |||||
continue | |||||
pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||||
output_dir = os.path.join(base_output_dir, pair_name) | |||||
os.makedirs(output_dir, exist_ok=True) | |||||
print("output dir:", output_dir) | |||||
# print(f"best_score={best_score}, best_epoch={best_epoch}, training_time={training_time.seconds}") | |||||
plt.plot(val_hooks[0].score_history, label='source') | |||||
plt.title("val accuracy") | |||||
plt.legend() | |||||
plt.savefig(f"{output_dir}/val_accuracy.png") | |||||
plt.close('all') | |||||
datasets = get_office31([source_domain], [target_domain], folder=data_root, return_target_with_labels=True) | |||||
dc = DataloaderCreator(batch_size=64, num_workers=2, all_val=True) | |||||
validator = AccuracyValidator(key_map={"src_val": "src_val"}) | |||||
src_score = trainer.evaluate_best_model(datasets, validator, dc) | |||||
print("Source acc:", src_score) | |||||
validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||||
target_score = trainer.evaluate_best_model(datasets, validator, dc) | |||||
print("Target acc:", target_score) | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write(f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {training_time.seconds}\n") | |||||
del trainer | |||||
del G | |||||
del C | |||||
gc.collect() | |||||
torch.cuda.empty_cache() |
import matplotlib.pyplot as plt | |||||
import torch | |||||
import os | |||||
import gc | |||||
from datetime import datetime | |||||
from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||||
from models import get_model | |||||
from utils import DAModels | |||||
from vis_hook import VizHook | |||||
def train(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain): | |||||
if args.source != None and args.source != source_domain: | |||||
return None, None, None | |||||
if args.target != None and args.target != target_domain: | |||||
return None, None, None | |||||
pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||||
output_dir = os.path.join(base_output_dir, pair_name) | |||||
os.makedirs(output_dir, exist_ok=True) | |||||
print("output dir:", output_dir) | |||||
datasets = get_office31([source_domain], [target_domain], | |||||
folder=args.data_root, | |||||
return_target_with_labels=True, | |||||
download=args.download) | |||||
dc = DataloaderCreator(batch_size=args.batch_size, | |||||
num_workers=args.num_workers, | |||||
train_names=["train"], | |||||
val_names=["src_train", "target_train", "src_val", "target_val", | |||||
"target_train_with_labels", "target_val_with_labels"]) | |||||
source_checkpoint_dir = None if not args.source_checkpoint_base_dir else \ | |||||
f"{args.source_checkpoint_base_dir}/{args.source_checkpoint_trial_number}/{pair_name}/saved_models" | |||||
print("source_checkpoint_dir", source_checkpoint_dir) | |||||
adapter = get_model(model_name, hp, source_checkpoint_dir, args.data_root, source_domain) | |||||
checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False) | |||||
sourceAccuracyValidator = AccuracyValidator() | |||||
validators = { | |||||
"entropy": EntropyValidator(), | |||||
"diversity": DiversityValidator(), | |||||
} | |||||
validator = ScoreHistory(MultipleValidators(validators)) | |||||
targetAccuracyValidator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||||
val_hooks = [ScoreHistory(sourceAccuracyValidator), | |||||
ScoreHistory(targetAccuracyValidator), | |||||
VizHook(output_dir=output_dir, frequency=args.vishook_frequency)] | |||||
trainer = Ignite( | |||||
adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn | |||||
) | |||||
early_stopper_kwargs = {"patience": args.patience} | |||||
start_time = datetime.now() | |||||
best_score, best_epoch = trainer.run( | |||||
datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs | |||||
) | |||||
end_time = datetime.now() | |||||
training_time = end_time - start_time | |||||
print(f"best_score={best_score}, best_epoch={best_epoch}") | |||||
plt.plot(val_hooks[0].score_history, label='source') | |||||
plt.plot(val_hooks[1].score_history, label='target') | |||||
plt.title("val accuracy") | |||||
plt.legend() | |||||
plt.savefig(f"{output_dir}/val_accuracy.png") | |||||
plt.close('all') | |||||
plt.plot(validator.score_history) | |||||
plt.title("score_history") | |||||
plt.savefig(f"{output_dir}/score_history.png") | |||||
plt.close('all') | |||||
src_score = trainer.evaluate_best_model(datasets, sourceAccuracyValidator, dc) | |||||
print("Source acc:", src_score) | |||||
target_score = trainer.evaluate_best_model(datasets, targetAccuracyValidator, dc) | |||||
print("Target acc:", target_score) | |||||
print("---------") | |||||
if args.hp_tune: | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n") | |||||
else: | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write( | |||||
f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, {hp.lr}, {hp.gamma},") | |||||
del adapter | |||||
gc.collect() | |||||
torch.cuda.empty_cache() | |||||
return src_score, target_score, best_score |
import torch | |||||
import os | |||||
from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets | |||||
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||||
from pytorch_adapt.models import Discriminator, office31C, office31G | |||||
from pytorch_adapt.containers import Misc | |||||
from pytorch_adapt.layers import RandomizedDotProduct | |||||
from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss | |||||
from pytorch_adapt.utils import common_functions | |||||
from pytorch_adapt.containers import LRSchedulers | |||||
from classifier_adapter import ClassifierAdapter | |||||
from utils import HP, DAModels | |||||
import copy | |||||
import matplotlib.pyplot as plt | |||||
import torch | |||||
import os | |||||
import gc | |||||
from datetime import datetime | |||||
from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
def train_source(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain): | |||||
if args.source != None and args.source != source_domain: | |||||
return None, None, None | |||||
if args.target != None and args.target != target_domain: | |||||
return None, None, None | |||||
pair_name = f"{source_domain[0]}2{target_domain[0]}" | |||||
output_dir = os.path.join(base_output_dir, pair_name) | |||||
os.makedirs(output_dir, exist_ok=True) | |||||
print("output dir:", output_dir) | |||||
datasets = get_office31([source_domain], [], | |||||
folder=args.data_root, | |||||
return_target_with_labels=True, | |||||
download=args.download) | |||||
dc = DataloaderCreator(batch_size=args.batch_size, | |||||
num_workers=args.num_workers, | |||||
) | |||||
weights_root = os.path.join(args.data_root, "weights") | |||||
G = office31G(pretrained=True, model_dir=weights_root).to(device) | |||||
C = office31C(domain=source_domain, pretrained=True, | |||||
model_dir=weights_root).to(device) | |||||
optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr})) | |||||
lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma})) | |||||
models = Models({"G": G, "C": C}) | |||||
adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers) | |||||
# adapter = get_model(model_name, hp, args.data_root, source_domain) | |||||
print("checkpoint dir:", output_dir) | |||||
checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False) | |||||
sourceAccuracyValidator = AccuracyValidator() | |||||
val_hooks = [ScoreHistory(sourceAccuracyValidator)] | |||||
trainer = Ignite( | |||||
adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn | |||||
) | |||||
early_stopper_kwargs = {"patience": args.patience} | |||||
start_time = datetime.now() | |||||
best_score, best_epoch = trainer.run( | |||||
datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs | |||||
) | |||||
end_time = datetime.now() | |||||
training_time = end_time - start_time | |||||
print(f"best_score={best_score}, best_epoch={best_epoch}") | |||||
plt.plot(val_hooks[0].score_history, label='source') | |||||
plt.title("val accuracy") | |||||
plt.legend() | |||||
plt.savefig(f"{output_dir}/val_accuracy.png") | |||||
plt.close('all') | |||||
datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True) | |||||
dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True) | |||||
validator = AccuracyValidator(key_map={"src_val": "src_val"}) | |||||
src_score = trainer.evaluate_best_model(datasets, validator, dc) | |||||
print("Source acc:", src_score) | |||||
validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"}) | |||||
target_score = trainer.evaluate_best_model(datasets, validator, dc) | |||||
print("Target acc:", target_score) | |||||
print("---------") | |||||
if args.hp_tune: | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n") | |||||
else: | |||||
with open(results_file, "a") as myfile: | |||||
myfile.write( | |||||
f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, ") | |||||
del adapter | |||||
gc.collect() | |||||
torch.cuda.empty_cache() | |||||
return src_score, target_score, best_score |
from dataclasses import dataclass | |||||
from enum import Enum | |||||
@dataclass | |||||
class HP: | |||||
lr:float = 0.0005 | |||||
gamma:float = 0.99 | |||||
class DAModels(Enum): | |||||
DANN = "DANN" | |||||
MCD = "MCD" | |||||
CDAN = "CDAN" | |||||
SOURCE = "SOURCE" | |||||
MMD = "MMD" | |||||
CORAL = "CORAL" | |||||
SYMNET = "SYMNET" |
import matplotlib.pyplot as plt | |||||
import pandas as pd | |||||
import seaborn as sns | |||||
import torch | |||||
import umap | |||||
from datetime import datetime | |||||
from pytorch_adapt.adapters import DANN | |||||
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers | |||||
from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.models import Discriminator, office31C, office31G | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory | |||||
class VizHook: | |||||
def __init__(self, **kwargs): | |||||
self.required_data = ["src_val", | |||||
"target_val", "target_val_with_labels"] | |||||
self.kwargs = kwargs | |||||
def __call__(self, epoch, src_val, target_val, target_val_with_labels, **kwargs): | |||||
accuracy_validator = AccuracyValidator() | |||||
accuracy = accuracy_validator.compute_score(src_val=src_val) | |||||
print("src_val accuracy:", accuracy) | |||||
accuracy_validator = AccuracyValidator() | |||||
accuracy = accuracy_validator.compute_score(src_val=target_val_with_labels) | |||||
print("target_val accuracy:", accuracy) | |||||
if epoch >= 2 and epoch % kwargs.get("frequency", 5) != 0: | |||||
return | |||||
features = [src_val["features"], target_val["features"]] | |||||
domain = [src_val["domain"], target_val["domain"]] | |||||
features = torch.cat(features, dim=0).cpu().numpy() | |||||
domain = torch.cat(domain, dim=0).cpu().numpy() | |||||
emb = umap.UMAP().fit_transform(features) | |||||
df = pd.DataFrame(emb).assign(domain=domain) | |||||
df["domain"] = df["domain"].replace({0: "Source", 1: "Target"}) | |||||
sns.set_theme(style="white", rc={"figure.figsize": (8, 6)}) | |||||
sns.scatterplot(data=df, x=0, y=1, hue="domain", s=10) | |||||
plt.savefig(f"{self.kwargs['output_dir']}/val_{epoch}.png") | |||||
plt.close('all') |
import matplotlib.pyplot as plt | |||||
import numpy as np | |||||
import torchvision | |||||
from pytorch_adapt.datasets import get_office31 | |||||
# root="datasets/pytorch-adapt/" | |||||
mean = [0.485, 0.456, 0.406] | |||||
std = [0.229, 0.224, 0.225] | |||||
inv_normalize = torchvision.transforms.Normalize( | |||||
mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std] | |||||
) | |||||
idx = 0 | |||||
def imshow(img, domain, figsize=(10, 6)): | |||||
img = inv_normalize(img) | |||||
npimg = img.numpy() | |||||
plt.figure(figsize=figsize) | |||||
plt.imshow(np.transpose(npimg, (1, 2, 0))) | |||||
plt.axis('off') | |||||
plt.savefig(f"office31-{idx}") | |||||
plt.show() | |||||
plt.close("all") | |||||
idx += 1 | |||||
def imshow_many(datasets, src, target): | |||||
d = datasets["train"] | |||||
for name in ["src_imgs", "target_imgs"]: | |||||
domains = src if name == "src_imgs" else target | |||||
if len(domains) == 0: | |||||
continue | |||||
print(domains) | |||||
imgs = [d[i][name] for i in np.random.choice(len(d), size=16, replace=False)] | |||||
imshow(torchvision.utils.make_grid(imgs)) | |||||
for src, target in [(["amazon"], ["dslr"]), (["webcam"], [])]: | |||||
datasets = get_office31(src, target,folder=root) | |||||
imshow_many(datasets, src, target) |
import itertools | |||||
import numpy as np | |||||
from pprint import pprint | |||||
domains = ["w", "d", "a"] | |||||
pairs = [[f"{d1}2{d2}" for d1 in domains] for d2 in domains] | |||||
pairs = list(itertools.chain.from_iterable(pairs)) | |||||
pairs.sort() | |||||
# print(pairs) | |||||
ROUND_FACTOR = 3 | |||||
all_accs_list = {} | |||||
files = [ | |||||
"all.txt", | |||||
"all2.txt", | |||||
"all3.txt", | |||||
# "all_step_scheduler.txt", | |||||
# "all_source_trained_1.txt", | |||||
# "all_source_trained_2_specific_hp.txt", | |||||
] | |||||
for file in files: | |||||
with open(file) as f: | |||||
name = None | |||||
for line in f: | |||||
splitted = line.split(" ") | |||||
if splitted[0] == "##": # e.g. ## DANN | |||||
name = splitted[1].strip() | |||||
splitted = line.split(",") # a2w, acc1, acc2, | |||||
if splitted[0] in pairs: | |||||
pair = splitted[0] | |||||
name = name.lower() | |||||
acc = float(splitted[2].strip()) | |||||
if name not in all_accs_list: | |||||
all_accs_list[name] = {p:[] for p in pairs} | |||||
all_accs_list[name][pair].append(acc) | |||||
# all_accs_list format: {'dann': {'a2w': [acc1, acc2, acc3]}} | |||||
acc_means_by_model_name = {} | |||||
vars_by_model_name = {} | |||||
for name, pair_list in all_accs_list.items(): | |||||
accs = {p:[] for p in pairs} | |||||
vars = {p:[] for p in pairs} | |||||
for pair, acc_list in pair_list.items(): | |||||
if len(acc_list) > 0: | |||||
## Calculate average and round | |||||
accs[pair] = round(100 * sum(acc_list) / len(acc_list), ROUND_FACTOR) | |||||
vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR) | |||||
print(vars[pair], "|||", acc_list) | |||||
acc_means_by_model_name[name] = accs | |||||
vars_by_model_name[name] = vars | |||||
# for name, acc_list in acc_means_by_model_name.items(): | |||||
# pprint(name) | |||||
# pprint(all_accs_list) | |||||
# pprint(acc_means_by_model_name) | |||||
# pprint(vars_by_model_name) | |||||
print() | |||||
latex_table = "" | |||||
header = [pair for pair in pairs if pair[0] != pair[-1]] | |||||
table = [] | |||||
var_table = [] | |||||
for name, acc_list in acc_means_by_model_name.items(): | |||||
if "target" in name: | |||||
continue | |||||
var_list = vars_by_model_name[name] | |||||
valid_accs = [] | |||||
table_row = [] | |||||
var_table_row = [] | |||||
for pair in pairs: | |||||
acc = acc_list[pair] | |||||
var = var_list[pair] | |||||
if pair[0] != pair[-1]: # exclude w2w, ... | |||||
table_row.append(acc) | |||||
var_table_row.append(var) | |||||
if acc != None: | |||||
valid_accs.append(acc) | |||||
acc_average = round(sum(valid_accs) / len(header), ROUND_FACTOR) | |||||
table_row.append(acc_average) | |||||
table.append(table_row) | |||||
var =round(np.var(valid_accs), ROUND_FACTOR) | |||||
print(var, "~~~~~~", valid_accs) | |||||
var_table_row.append(var) | |||||
var_table.append(var_table_row) | |||||
t = np.array(table) | |||||
t[t==None] = np.nan | |||||
# pprint(t) | |||||
col_max = t.max(axis=0) | |||||
pprint(table) | |||||
latex_table = "" | |||||
header = [pair for pair in pairs if pair[0] != pair[-1]] | |||||
name_map = {"base_source": "Source-Only"} | |||||
j = 0 | |||||
for name, acc_list in acc_means_by_model_name.items(): | |||||
if "target" in name: | |||||
continue | |||||
latex_name = name | |||||
if name in name_map: | |||||
latex_name= name_map[name] | |||||
latex_row = f"{latex_name.replace('_','-').upper()} &" | |||||
acc_sum = 0 | |||||
for i, acc in enumerate(table[j]): | |||||
if i == len(table[j]) - 1: | |||||
acc_str = f"${acc}$" | |||||
else: | |||||
acc_str = f"${acc} \pm {var_table[j][i]}$" | |||||
if acc == col_max[i]: | |||||
latex_row += f" \\underline{{{acc_str}}} &" | |||||
else: | |||||
latex_row += f" {acc_str} &" | |||||
latex_row = f"{latex_row[:-1]} \\\\ \hline" | |||||
latex_table += f"{latex_row}\n" | |||||
j += 1 | |||||
print(*header, sep=" & ") | |||||
print(latex_table) | |||||
data = np.array(table) | |||||
legend = [key for key in acc_means_by_model_name.keys()] | |||||
labels = [*header, "avg"] | |||||
data = np.array([[71.75, 75.94 ,67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]]) | |||||
legend = ["CDAN", "Source-only"] | |||||
labels = [*header, "avg"] | |||||
import matplotlib.pyplot as plt | |||||
# Assume your matrix is called 'data' | |||||
n, m = data.shape | |||||
# Create an array of x-coordinates for the bars | |||||
x = np.arange(m) | |||||
# Plot the bars for each row side by side | |||||
for i in range(n): | |||||
row = data[i, :] | |||||
plt.bar(x + (i-n/2)*0.3, row, width=0.25, align='center') | |||||
# Set x-axis tick labels and labels | |||||
plt.xticks(x, labels=labels) | |||||
# plt.xlabel("Task") | |||||
plt.ylabel("Accuracy") | |||||
# Add a legend | |||||
plt.legend(legend) | |||||
plt.show() |
import itertools | |||||
import numpy as np | |||||
from pprint import pprint | |||||
domains = ["w", "d", "a"] | |||||
pairs = [[f"{d1}2{d2}" for d1 in domains] for d2 in domains] | |||||
pairs = list(itertools.chain.from_iterable(pairs)) | |||||
pairs.sort() | |||||
# print(pairs) | |||||
ROUND_FACTOR = 2 | |||||
all_accs_list = {} | |||||
files = [ | |||||
# "all.txt", | |||||
# "all2.txt", | |||||
# "all3.txt", | |||||
# "all_step_scheduler.txt", | |||||
# "all_source_trained_1.txt", | |||||
"all_source_trained_2_specific_hp.txt", | |||||
] | |||||
for file in files: | |||||
with open(file) as f: | |||||
name = None | |||||
for line in f: | |||||
splitted = line.split(" ") | |||||
if splitted[0] == "##": # e.g. ## DANN | |||||
name = splitted[1].strip() | |||||
splitted = line.split(",") # a2w, acc1, acc2, | |||||
if splitted[0] in pairs: | |||||
pair = splitted[0] | |||||
name = name.lower() | |||||
acc = float(splitted[5].strip()) #/ 60 / 60 | |||||
if name not in all_accs_list: | |||||
all_accs_list[name] = {p:[] for p in pairs} | |||||
all_accs_list[name][pair].append(acc) | |||||
# all_accs_list format: {'dann': {'a2w': [acc1, acc2, acc3]}} | |||||
acc_means_by_model_name = {} | |||||
vars_by_model_name = {} | |||||
for name, pair_list in all_accs_list.items(): | |||||
accs = {p:[] for p in pairs} | |||||
vars = {p:[] for p in pairs} | |||||
for pair, acc_list in pair_list.items(): | |||||
if len(acc_list) > 0: | |||||
## Calculate average and round | |||||
accs[pair] = round(sum(acc_list) / len(acc_list), ROUND_FACTOR) | |||||
vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR) | |||||
print(vars[pair], "|||", acc_list) | |||||
acc_means_by_model_name[name] = accs | |||||
vars_by_model_name[name] = vars | |||||
# for name, acc_list in acc_means_by_model_name.items(): | |||||
# pprint(name) | |||||
# pprint(all_accs_list) | |||||
# pprint(acc_means_by_model_name) | |||||
# pprint(vars_by_model_name) | |||||
print() | |||||
latex_table = "" | |||||
header = [pair for pair in pairs if pair[0] != pair[-1]] | |||||
table = [] | |||||
var_table = [] | |||||
for name, acc_list in acc_means_by_model_name.items(): | |||||
if "target" in name or "source" in name: | |||||
continue | |||||
print("~~~~%%%%~~~", name) | |||||
var_list = vars_by_model_name[name] | |||||
valid_accs = [] | |||||
table_row = [] | |||||
var_table_row = [] | |||||
for pair in pairs: | |||||
acc = acc_list[pair] | |||||
var = var_list[pair] | |||||
if pair[0] != pair[-1]: # exclude w2w, ... | |||||
table_row.append(acc) | |||||
var_table_row.append(var) | |||||
if acc != None: | |||||
valid_accs.append(acc) | |||||
acc_average = round(sum(valid_accs) / len(header), ROUND_FACTOR) | |||||
table_row.append(acc_average) | |||||
table.append(table_row) | |||||
var =round(np.var(valid_accs), ROUND_FACTOR) | |||||
print(var, ">>>", valid_accs) | |||||
var_table_row.append(var) | |||||
var_table.append(var_table_row) | |||||
t = np.array(table) | |||||
t[t==None] = np.nan | |||||
# pprint(t) | |||||
col_max = t.min(axis=0) | |||||
pprint(table) | |||||
latex_table = "" | |||||
header = [pair for pair in pairs if pair[0] != pair[-1]] | |||||
name_map = {"base_source": "Source-Only"} | |||||
j = 0 | |||||
for name, acc_list in acc_means_by_model_name.items(): | |||||
if "target" in name or "source" in name: | |||||
continue | |||||
latex_name = name | |||||
if name in name_map: | |||||
latex_name= name_map[name] | |||||
latex_row = f"{latex_name.replace('_','-').upper()} &" | |||||
acc_sum = 0 | |||||
for i, acc in enumerate(table[j]): | |||||
if i == len(table[j]) - 1: | |||||
acc_str = f"${acc}$" | |||||
else: | |||||
acc_str = f"${acc}$" | |||||
if acc == col_max[i]: | |||||
latex_row += f" \\underline{{{acc_str}}} &" | |||||
else: | |||||
latex_row += f" {acc_str} &" | |||||
latex_row = f"{latex_row[:-1]} \\\\ \hline" | |||||
latex_table += f"{latex_row}\n" | |||||
j += 1 | |||||
print(*header, sep=" & ") | |||||
print(latex_table) | |||||
data = np.array(table) | |||||
legend = [key for key in acc_means_by_model_name.keys() if "source" not in key] | |||||
labels = [*header, "avg"] | |||||
# data = np.array([[71.75, 75.94 ,67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]]) | |||||
# legend = ["CDAN", "Source-only"] | |||||
# labels = [*header, "avg"] | |||||
import matplotlib.pyplot as plt | |||||
# Assume your matrix is called 'data' | |||||
n, m = data.shape | |||||
# Create an array of x-coordinates for the bars | |||||
x = np.arange(m) | |||||
# Plot the bars for each row side by side | |||||
for i in range(n): | |||||
row = data[i, :] | |||||
plt.bar(x + (i-n/2)*0.1, row, width=0.08, align='center') | |||||
# Set x-axis tick labels and labels | |||||
plt.xticks(x, labels=labels) | |||||
# plt.xlabel("Task") | |||||
plt.ylabel("Time (s)") | |||||
# Add a legend | |||||
plt.legend(legend) | |||||
plt.show() |
from pytorch_adapt.datasets import DataloaderCreator, get_office31 | |||||
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite | |||||
from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators | |||||
from time import time | |||||
import multiprocessing as mp | |||||
data_root = "../datasets/pytorch-adapt/" | |||||
batch_size = 32 | |||||
for num_workers in range(2, mp.cpu_count(), 2): | |||||
datasets = get_office31(["amazon"], ["webcam"], | |||||
folder=data_root, | |||||
return_target_with_labels=True, | |||||
download=False) | |||||
dc = DataloaderCreator(batch_size=batch_size, | |||||
num_workers=num_workers, | |||||
train_names=["train"], | |||||
val_names=["src_train", "target_train", "src_val", "target_val", | |||||
"target_train_with_labels", "target_val_with_labels"]) | |||||
dataloaders = dc(**datasets) | |||||
train_loader = dataloaders["train"] | |||||
start = time() | |||||
for epoch in range(1, 3): | |||||
for i, data in enumerate(train_loader, 0): | |||||
pass | |||||
end = time() | |||||
print("Finish with:{} second, num_workers={}".format(end - start, num_workers)) |
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"device(type='cuda', index=0)" | |||||
] | |||||
}, | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"import torch\n", | |||||
"import os\n", | |||||
"\n", | |||||
"from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets\n", | |||||
"from pytorch_adapt.containers import Models, Optimizers, LRSchedulers\n", | |||||
"from pytorch_adapt.models import Discriminator, office31C, office31G\n", | |||||
"from pytorch_adapt.containers import Misc\n", | |||||
"from pytorch_adapt.layers import RandomizedDotProduct\n", | |||||
"from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss\n", | |||||
"from pytorch_adapt.utils import common_functions\n", | |||||
"from pytorch_adapt.containers import LRSchedulers\n", | |||||
"\n", | |||||
"from classifier_adapter import ClassifierAdapter\n", | |||||
"\n", | |||||
"from utils import HP, DAModels\n", | |||||
"\n", | |||||
"import copy\n", | |||||
"\n", | |||||
"import matplotlib.pyplot as plt\n", | |||||
"import torch\n", | |||||
"import os\n", | |||||
"import gc\n", | |||||
"from datetime import datetime\n", | |||||
"\n", | |||||
"from pytorch_adapt.datasets import DataloaderCreator, get_office31\n", | |||||
"from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n", | |||||
"from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators\n", | |||||
"\n", | |||||
"from models import get_model\n", | |||||
"from utils import DAModels\n", | |||||
"\n", | |||||
"from vis_hook import VizHook\n", | |||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||||
"device" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=5)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"import argparse\n", | |||||
"parser = argparse.ArgumentParser()\n", | |||||
"parser.add_argument('--max_epochs', default=1, type=int)\n", | |||||
"parser.add_argument('--patience', default=2, type=int)\n", | |||||
"parser.add_argument('--batch_size', default=64, type=int)\n", | |||||
"parser.add_argument('--num_workers', default=1, type=int)\n", | |||||
"parser.add_argument('--trials_count', default=1, type=int)\n", | |||||
"parser.add_argument('--initial_trial', default=0, type=int)\n", | |||||
"parser.add_argument('--download', default=False, type=bool)\n", | |||||
"parser.add_argument('--root', default=\"./\")\n", | |||||
"parser.add_argument('--data_root', default=\"./datasets/pytorch-adapt/\")\n", | |||||
"parser.add_argument('--results_root', default=\"./results/\")\n", | |||||
"parser.add_argument('--model_names', default=[\"DANN\"], nargs='+')\n", | |||||
"parser.add_argument('--lr', default=0.0001, type=float)\n", | |||||
"parser.add_argument('--gamma', default=0.99, type=float)\n", | |||||
"parser.add_argument('--hp_tune', default=False, type=bool)\n", | |||||
"parser.add_argument('--source', default=None)\n", | |||||
"parser.add_argument('--target', default=None) \n", | |||||
"parser.add_argument('--vishook_frequency', default=5, type=int)\n", | |||||
" \n", | |||||
"\n", | |||||
"args = parser.parse_args(\"\")\n", | |||||
"print(args)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 45, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"source_domain = 'amazon'\n", | |||||
"target_domain = 'webcam'\n", | |||||
"datasets = get_office31([source_domain], [],\n", | |||||
" folder=args.data_root,\n", | |||||
" return_target_with_labels=True,\n", | |||||
" download=args.download)\n", | |||||
"\n", | |||||
"dc = DataloaderCreator(batch_size=args.batch_size,\n", | |||||
" num_workers=args.num_workers,\n", | |||||
" )\n", | |||||
"\n", | |||||
"weights_root = os.path.join(args.data_root, \"weights\")\n", | |||||
"\n", | |||||
"G = office31G(pretrained=True, model_dir=weights_root).to(device)\n", | |||||
"C = office31C(domain=source_domain, pretrained=True,\n", | |||||
" model_dir=weights_root).to(device)\n", | |||||
"\n", | |||||
"\n", | |||||
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||||
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||||
"\n", | |||||
"models = Models({\"G\": G, \"C\": C})\n", | |||||
"adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"cuda:0\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "f28aaf5a334d4f91a9beb21e714c43a5", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/35] 3%|2 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "7131d8b9099c4d0a95155595919c55f5", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"best_score=None, best_epoch=None\n" | |||||
] | |||||
}, | |||||
{ | |||||
"ename": "AttributeError", | |||||
"evalue": "'Namespace' object has no attribute 'dataroot'", | |||||
"output_type": "error", | |||||
"traceback": [ | |||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |||||
"Cell \u001b[0;32mIn[39], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m plt\u001b[39m.\u001b[39msavefig(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00moutput_dir\u001b[39m}\u001b[39;00m\u001b[39m/val_accuracy.png\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 29\u001b[0m plt\u001b[39m.\u001b[39mclose(\u001b[39m'\u001b[39m\u001b[39mall\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m---> 31\u001b[0m datasets \u001b[39m=\u001b[39m get_office31([source_domain], [target_domain], folder\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39;49mdataroot, return_target_with_labels\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 32\u001b[0m dc \u001b[39m=\u001b[39m DataloaderCreator(batch_size\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mbatch_size, num_workers\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mnum_workers, all_val\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 34\u001b[0m validator \u001b[39m=\u001b[39m AccuracyValidator(key_map\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m})\n", | |||||
"\u001b[0;31mAttributeError\u001b[0m: 'Namespace' object has no attribute 'dataroot'" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"output_dir = \"tmp\"\n", | |||||
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||||
"\n", | |||||
"sourceAccuracyValidator = AccuracyValidator()\n", | |||||
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||||
"\n", | |||||
"trainer = Ignite(\n", | |||||
" adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||||
")\n", | |||||
"print(trainer.device)\n", | |||||
"\n", | |||||
"early_stopper_kwargs = {\"patience\": args.patience}\n", | |||||
"\n", | |||||
"start_time = datetime.now()\n", | |||||
"\n", | |||||
"best_score, best_epoch = trainer.run(\n", | |||||
" datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||||
")\n", | |||||
"\n", | |||||
"end_time = datetime.now()\n", | |||||
"training_time = end_time - start_time\n", | |||||
"\n", | |||||
"print(f\"best_score={best_score}, best_epoch={best_epoch}\")\n", | |||||
"\n", | |||||
"plt.plot(val_hooks[0].score_history, label='source')\n", | |||||
"plt.title(\"val accuracy\")\n", | |||||
"plt.legend()\n", | |||||
"plt.savefig(f\"{output_dir}/val_accuracy.png\")\n", | |||||
"plt.close('all')\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "173d54ab994d4abda6e4f0897ad96c49", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Source acc: 0.868794322013855\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "13b4a1ccc3b34456b68c73357d14bc21", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/3] 33%|###3 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Target acc: 0.74842768907547\n", | |||||
"---------\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||||
"dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||||
"src_score = trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Source acc:\", src_score)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||||
"target_score = trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Target acc:\", target_score)\n", | |||||
"print(\"---------\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"C2 = copy.deepcopy(C) " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 93, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"cuda:0\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"source_domain = 'amazon'\n", | |||||
"target_domain = 'webcam'\n", | |||||
"G = office31G(pretrained=False).to(device)\n", | |||||
"C = office31C(pretrained=False).to(device)\n", | |||||
"\n", | |||||
"\n", | |||||
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||||
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||||
"\n", | |||||
"models = Models({\"G\": G, \"C\": C})\n", | |||||
"adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||||
"\n", | |||||
"\n", | |||||
"output_dir = \"tmp\"\n", | |||||
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||||
"\n", | |||||
"sourceAccuracyValidator = AccuracyValidator()\n", | |||||
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||||
"\n", | |||||
"new_trainer = Ignite(\n", | |||||
" adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||||
")\n", | |||||
"print(trainer.device)\n", | |||||
"\n", | |||||
"from pytorch_adapt.frameworks.ignite import (\n", | |||||
" CheckpointFnCreator,\n", | |||||
" IgniteValHookWrapper,\n", | |||||
" checkpoint_utils,\n", | |||||
")\n", | |||||
"\n", | |||||
"objs = [\n", | |||||
" {\n", | |||||
" \"engine\": new_trainer.trainer,\n", | |||||
" \"validator\": new_trainer.validator,\n", | |||||
" \"val_hook0\": val_hooks[0],\n", | |||||
" **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n", | |||||
" }\n", | |||||
" ]\n", | |||||
" \n", | |||||
"# best_score, best_epoch = trainer.run(\n", | |||||
"# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||||
"# )\n", | |||||
"\n", | |||||
"for to_load in objs:\n", | |||||
" checkpoint_fn.load_best_checkpoint(to_load)\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 94, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "32f01ff7ea254739909e4567a133b00a", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Source acc: 0.868794322013855\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "cef345c05e5e46eb9fc0e1cc40b02435", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/3] 33%|###3 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Target acc: 0.74842768907547\n", | |||||
"---------\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||||
"dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||||
"src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Source acc:\", src_score)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||||
"target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Target acc:\", target_score)\n", | |||||
"print(\"---------\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 89, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"\n", | |||||
"datasets = get_office31([source_domain], [target_domain],\n", | |||||
" folder=args.data_root,\n", | |||||
" return_target_with_labels=True,\n", | |||||
" download=args.download)\n", | |||||
" \n", | |||||
"dc = DataloaderCreator(batch_size=args.batch_size,\n", | |||||
" num_workers=args.num_workers,\n", | |||||
" train_names=[\"train\"],\n", | |||||
" val_names=[\"src_train\", \"target_train\", \"src_val\", \"target_val\",\n", | |||||
" \"target_train_with_labels\", \"target_val_with_labels\"])\n", | |||||
"\n", | |||||
"G = new_trainer.adapter.models[\"G\"]\n", | |||||
"C = new_trainer.adapter.models[\"C\"]\n", | |||||
"D = Discriminator(in_size=2048, h=1024).to(device)\n", | |||||
"\n", | |||||
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.001}))\n", | |||||
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n", | |||||
"# lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.MultiStepLR, {\"milestones\": [2, 5, 10, 20, 40], \"gamma\": hp.gamma}))\n", | |||||
"\n", | |||||
"models = Models({\"G\": G, \"C\": C, \"D\": D})\n", | |||||
"adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 90, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"cuda:0\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "bf490d18567444149070191e100f8c45", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/3] 33%|###3 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "525920fcd19d4178a4bada48932c8fb1", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "bd1158f548e746cdab88d608b22ab65c", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "196fa120037b48fdb4e9a879e7e7c79b", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/3] 33%|###3 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "6795edb658a84309b1a03bcea6a24643", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"output_dir = \"tmp\"\n", | |||||
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||||
"\n", | |||||
"sourceAccuracyValidator = AccuracyValidator()\n", | |||||
"targetAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||||
"val_hooks = [ScoreHistory(sourceAccuracyValidator), ScoreHistory(targetAccuracyValidator)]\n", | |||||
"\n", | |||||
"trainer = Ignite(\n", | |||||
" adapter, val_hooks=val_hooks, device=device\n", | |||||
")\n", | |||||
"print(trainer.device)\n", | |||||
"\n", | |||||
"best_score, best_epoch = trainer.run(\n", | |||||
" datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs, check_initial_score=True\n", | |||||
")\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 91, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"ScoreHistory(\n", | |||||
" validator=AccuracyValidator(required_data=['src_val'])\n", | |||||
" latest_score=0.30319148302078247\n", | |||||
" best_score=0.868794322013855\n", | |||||
" best_epoch=0\n", | |||||
")" | |||||
] | |||||
}, | |||||
"execution_count": 91, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"val_hooks[0]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 92, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"ScoreHistory(\n", | |||||
" validator=AccuracyValidator(required_data=['target_val_with_labels'])\n", | |||||
" latest_score=0.2515723407268524\n", | |||||
" best_score=0.74842768907547\n", | |||||
" best_epoch=0\n", | |||||
")" | |||||
] | |||||
}, | |||||
"execution_count": 92, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"val_hooks[1]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 86, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"del trainer" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 87, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"21169" | |||||
] | |||||
}, | |||||
"execution_count": 87, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"import gc\n", | |||||
"gc.collect()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 88, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"torch.cuda.empty_cache()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 95, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"args.vishook_frequency = 133" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 96, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=133)" | |||||
] | |||||
}, | |||||
"execution_count": 96, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"ename": "", | |||||
"evalue": "", | |||||
"output_type": "error", | |||||
"traceback": [ | |||||
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details." | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"args" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"attachments": {}, | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"-----" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CDAN/2000/a2d/saved_models\"" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"source_domain = 'amazon'\n", | |||||
"target_domain = 'dslr'\n", | |||||
"G = office31G(pretrained=False).to(device)\n", | |||||
"C = office31C(pretrained=False).to(device)\n", | |||||
"\n", | |||||
"\n", | |||||
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||||
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||||
"\n", | |||||
"models = Models({\"G\": G, \"C\": C})\n", | |||||
"adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||||
"\n", | |||||
"\n", | |||||
"output_dir = \"tmp\"\n", | |||||
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||||
"\n", | |||||
"sourceAccuracyValidator = AccuracyValidator()\n", | |||||
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||||
"\n", | |||||
"new_trainer = Ignite(\n", | |||||
" adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||||
")\n", | |||||
"\n", | |||||
"from pytorch_adapt.frameworks.ignite import (\n", | |||||
" CheckpointFnCreator,\n", | |||||
" IgniteValHookWrapper,\n", | |||||
" checkpoint_utils,\n", | |||||
")\n", | |||||
"\n", | |||||
"objs = [\n", | |||||
" {\n", | |||||
" \"engine\": new_trainer.trainer,\n", | |||||
" \"validator\": new_trainer.validator,\n", | |||||
" \"val_hook0\": val_hooks[0],\n", | |||||
" **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n", | |||||
" }\n", | |||||
" ]\n", | |||||
" \n", | |||||
"# best_score, best_epoch = trainer.run(\n", | |||||
"# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n", | |||||
"# )\n", | |||||
"\n", | |||||
"for to_load in objs:\n", | |||||
" checkpoint_fn.load_best_checkpoint(to_load)\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "926966dd640e4979ade6a45cf0fcdd49", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/9] 11%|#1 |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Source acc: 0.868794322013855\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "64cd5cfb052c4f52af9af1a63a4c0087", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"[1/2] 50%|##### |it [00:00<?]" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Target acc: 0.7200000286102295\n", | |||||
"---------\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n", | |||||
"dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n", | |||||
"src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Source acc:\", src_score)\n", | |||||
"\n", | |||||
"validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n", | |||||
"target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n", | |||||
"print(\"Target acc:\", target_score)\n", | |||||
"print(\"---------\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"source_domain = 'amazon'\n", | |||||
"target_domain = 'dslr'\n", | |||||
"G = new_trainer.adapter.models[\"G\"]\n", | |||||
"C = new_trainer.adapter.models[\"C\"]\n", | |||||
"\n", | |||||
"G.fc = C.net[:6]\n", | |||||
"C.net = C.net[6:]\n", | |||||
"\n", | |||||
"\n", | |||||
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n", | |||||
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", | |||||
"\n", | |||||
"models = Models({\"G\": G, \"C\": C})\n", | |||||
"adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n", | |||||
"\n", | |||||
"\n", | |||||
"output_dir = \"tmp\"\n", | |||||
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n", | |||||
"\n", | |||||
"sourceAccuracyValidator = AccuracyValidator()\n", | |||||
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n", | |||||
"\n", | |||||
"more_new_trainer = Ignite(\n", | |||||
" adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n", | |||||
")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 13, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from pytorch_adapt.hooks import FeaturesAndLogitsHook\n", | |||||
"\n", | |||||
"h1 = FeaturesAndLogitsHook()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 19, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"ename": "KeyError", | |||||
"evalue": "in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG", | |||||
"output_type": "error", | |||||
"traceback": [ | |||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", | |||||
"Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m h1(datasets)\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:80\u001b[0m, in \u001b[0;36mBaseFeaturesHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 78\u001b[0m func \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_detached \u001b[39mif\u001b[39;00m detach \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_with_grad\n\u001b[1;32m 79\u001b[0m in_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39min_keys, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 80\u001b[0m func(inputs, outputs, domain, in_keys)\n\u001b[1;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcheck_outputs_requires_grad(outputs)\n\u001b[1;32m 83\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, {}\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:106\u001b[0m, in \u001b[0;36mBaseFeaturesHook.mode_with_grad\u001b[0;34m(self, inputs, outputs, domain, in_keys)\u001b[0m\n\u001b[1;32m 104\u001b[0m output_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_out_keys(), \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 105\u001b[0m output_vals \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_kwargs(inputs, output_keys)\n\u001b[0;32m--> 106\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 107\u001b[0m outputs, output_keys, output_vals, inputs, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_name, in_keys, domain\n\u001b[1;32m 108\u001b[0m )\n\u001b[1;32m 109\u001b[0m \u001b[39mreturn\u001b[39;00m output_keys, output_vals\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:133\u001b[0m, in \u001b[0;36mBaseFeaturesHook.add_if_new\u001b[0;34m(self, outputs, full_key, output_vals, inputs, model_name, in_keys, domain)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39madd_if_new\u001b[39m(\n\u001b[1;32m 131\u001b[0m \u001b[39mself\u001b[39m, outputs, full_key, output_vals, inputs, model_name, in_keys, domain\n\u001b[1;32m 132\u001b[0m ):\n\u001b[0;32m--> 133\u001b[0m c_f\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 134\u001b[0m outputs,\n\u001b[1;32m 135\u001b[0m full_key,\n\u001b[1;32m 136\u001b[0m output_vals,\n\u001b[1;32m 137\u001b[0m inputs,\n\u001b[1;32m 138\u001b[0m model_name,\n\u001b[1;32m 139\u001b[0m in_keys,\n\u001b[1;32m 140\u001b[0m logger\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlogger,\n\u001b[1;32m 141\u001b[0m )\n", | |||||
"File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/utils/common_functions.py:96\u001b[0m, in \u001b[0;36madd_if_new\u001b[0;34m(d, key, x, kwargs, model_name, in_keys, other_args, logger)\u001b[0m\n\u001b[1;32m 94\u001b[0m condition \u001b[39m=\u001b[39m is_none\n\u001b[1;32m 95\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(condition(y) \u001b[39mfor\u001b[39;00m y \u001b[39min\u001b[39;00m x):\n\u001b[0;32m---> 96\u001b[0m model \u001b[39m=\u001b[39m kwargs[model_name]\n\u001b[1;32m 97\u001b[0m input_vals \u001b[39m=\u001b[39m [kwargs[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m in_keys] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(other_args\u001b[39m.\u001b[39mvalues())\n\u001b[1;32m 98\u001b[0m new_x \u001b[39m=\u001b[39m try_use_model(model, model_name, input_vals)\n", | |||||
"\u001b[0;31mKeyError\u001b[0m: in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"h1(datasets)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "cdtrans", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.15 (default, Nov 24 2022, 15:19:38) \n[GCC 11.2.0]" | |||||
}, | |||||
"orig_nbformat": 4, | |||||
"vscode": { | |||||
"interpreter": { | |||||
"hash": "959b82c3a41427bdf7d14d4ba7335271e0c50cfcddd70501934b27dcc36968ad" | |||||
} | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
{ | |||||
"amazon": { | |||||
"back_pack": 92, | |||||
"bike": 82, | |||||
"bike_helmet": 72, | |||||
"bookcase": 82, | |||||
"bottle": 36, | |||||
"calculator": 94, | |||||
"desk_chair": 91, | |||||
"desk_lamp": 97, | |||||
"desktop_computer": 97, | |||||
"file_cabinet": 81, | |||||
"headphones": 99, | |||||
"keyboard": 100, | |||||
"laptop_computer": 100, | |||||
"letter_tray": 98, | |||||
"mobile_phone": 100, | |||||
"monitor": 99, | |||||
"mouse": 100, | |||||
"mug": 94, | |||||
"paper_notebook": 96, | |||||
"pen": 95, | |||||
"phone": 93, | |||||
"printer": 100, | |||||
"projector": 98, | |||||
"punchers": 98, | |||||
"ring_binder": 90, | |||||
"ruler": 75, | |||||
"scissors": 100, | |||||
"speaker": 99, | |||||
"stapler": 99, | |||||
"tape_dispenser": 96, | |||||
"trash_can": 64 | |||||
}, | |||||
"dslr": { | |||||
"back_pack": 12, | |||||
"bike": 21, | |||||
"bike_helmet": 24, | |||||
"bookcase": 12, | |||||
"bottle": 16, | |||||
"calculator": 12, | |||||
"desk_chair": 13, | |||||
"desk_lamp": 14, | |||||
"desktop_computer": 15, | |||||
"file_cabinet": 15, | |||||
"headphones": 13, | |||||
"keyboard": 10, | |||||
"laptop_computer": 24, | |||||
"letter_tray": 16, | |||||
"mobile_phone": 31, | |||||
"monitor": 22, | |||||
"mouse": 12, | |||||
"mug": 8, | |||||
"paper_notebook": 10, | |||||
"pen": 10, | |||||
"phone": 13, | |||||
"printer": 15, | |||||
"projector": 23, | |||||
"punchers": 18, | |||||
"ring_binder": 10, | |||||
"ruler": 7, | |||||
"scissors": 18, | |||||
"speaker": 26, | |||||
"stapler": 21, | |||||
"tape_dispenser": 22, | |||||
"trash_can": 15 | |||||
}, | |||||
"webcam": { | |||||
"back_pack": 29, | |||||
"bike": 21, | |||||
"bike_helmet": 28, | |||||
"bookcase": 12, | |||||
"bottle": 16, | |||||
"calculator": 31, | |||||
"desk_chair": 40, | |||||
"desk_lamp": 18, | |||||
"desktop_computer": 21, | |||||
"file_cabinet": 19, | |||||
"headphones": 27, | |||||
"keyboard": 27, | |||||
"laptop_computer": 30, | |||||
"letter_tray": 19, | |||||
"mobile_phone": 30, | |||||
"monitor": 43, | |||||
"mouse": 30, | |||||
"mug": 27, | |||||
"paper_notebook": 28, | |||||
"pen": 32, | |||||
"phone": 16, | |||||
"printer": 20, | |||||
"projector": 30, | |||||
"punchers": 27, | |||||
"ring_binder": 40, | |||||
"ruler": 11, | |||||
"scissors": 25, | |||||
"speaker": 30, | |||||
"stapler": 24, | |||||
"tape_dispenser": 23, | |||||
"trash_can": 21 | |||||
} | |||||
} |
{ | |||||
"amazon": 2817, | |||||
"dslr": 498, | |||||
"webcam": 795 | |||||
} |