{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda', index=0)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "import torch\n", "import os\n", "\n", "from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets\n", "from pytorch_adapt.containers import Models, Optimizers, LRSchedulers\n", "from pytorch_adapt.models import Discriminator, office31C, office31G\n", "from pytorch_adapt.containers import Misc\n", "from pytorch_adapt.layers import RandomizedDotProduct\n", "from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss\n", "from pytorch_adapt.utils import common_functions\n", "from pytorch_adapt.containers import LRSchedulers\n", "\n", "from classifier_adapter import ClassifierAdapter\n", "\n", "from utils import HP, DAModels\n", "\n", "import copy\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import os\n", "import gc\n", "from datetime import datetime\n", "\n", "from pytorch_adapt.datasets import DataloaderCreator, get_office31\n", "from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n", "from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators\n", "\n", "from models import get_model\n", "from utils import DAModels\n", "\n", "from vis_hook import VizHook\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, 
vishook_frequency=5)\n" ] } ], "source": [
"import argparse\n",
"\n",
"\n",
"def str2bool(value):\n",
"    \"\"\"Convert a command-line token such as 'true' or 'false' into a bool.\n",
"\n",
"    argparse's type=bool calls bool() on the raw string, so any non-empty\n",
"    value (including 'False') would be parsed as True; this converter\n",
"    interprets the text instead.\n",
"\n",
"    Raises:\n",
"        argparse.ArgumentTypeError: if the token is not a recognised boolean.\n",
"    \"\"\"\n",
"    if isinstance(value, bool):\n",
"        return value\n",
"    token = str(value).strip().lower()\n",
"    if token in ('true', 't', 'yes', 'y', '1'):\n",
"        return True\n",
"    if token in ('false', 'f', 'no', 'n', '0'):\n",
"        return False\n",
"    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')\n",
"\n",
"\n",
"# Run configuration: every option carries a default, so parsing an empty\n",
"# argument list (below) is enough to run the notebook.\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument('--max_epochs', default=1, type=int)\n",
"parser.add_argument('--patience', default=2, type=int)\n",
"parser.add_argument('--batch_size', default=64, type=int)\n",
"parser.add_argument('--num_workers', default=1, type=int)\n",
"parser.add_argument('--trials_count', default=1, type=int)\n",
"parser.add_argument('--initial_trial', default=0, type=int)\n",
"# Boolean flags go through str2bool so that '--download False' really disables it.\n",
"parser.add_argument('--download', default=False, type=str2bool)\n",
"parser.add_argument('--root', default=\"./\")\n",
"parser.add_argument('--data_root', default=\"./datasets/pytorch-adapt/\")\n",
"parser.add_argument('--results_root', default=\"./results/\")\n",
"parser.add_argument('--model_names', default=[\"DANN\"], nargs='+')\n",
"parser.add_argument('--lr', default=0.0001, type=float)\n",
"parser.add_argument('--gamma', default=0.99, type=float)\n",
"parser.add_argument('--hp_tune', default=False, type=str2bool)\n",
"parser.add_argument('--source', default=None)\n",
"parser.add_argument('--target', default=None)\n",
"parser.add_argument('--vishook_frequency', default=5, type=int)\n",
"\n",
"# Parse an explicit empty argument list instead of sys.argv, so the values\n",
"# used in the notebook are exactly the defaults declared above.\n",
"args = parser.parse_args([])\n",
"print(args)\n" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [
"source_domain = 'amazon'\n",
"target_domain = 'webcam'\n",
"datasets = get_office31([source_domain], [],\n",
" folder=args.data_root,\n",
" return_target_with_labels=True,\n",
" download=args.download)\n",
"\n",
"dc = DataloaderCreator(batch_size=args.batch_size,\n",
" num_workers=args.num_workers,\n",
" )\n",
"\n",
"weights_root = os.path.join(args.data_root, \"weights\")\n",
"\n",
"G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
"C = office31C(domain=source_domain, pretrained=True,\n",
" model_dir=weights_root).to(device)\n",
"\n",
"\n",
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
"lr_schedulers =
LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n", "\n", "models = Models({\"G\": G, \"C\": C})\n", "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Train `adapter` through the Ignite wrapper, checkpointing under\n",
"# `output_dir`/saved_models and plotting the validation-score history.\n",
"output_dir = \"tmp\"\n",
"# The accuracy plot is written straight into `output_dir`, so make sure it exists.\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
"\n",
"# The ScoreHistory hook's `score_history` is what gets plotted at the end of this cell.\n",
"sourceAccuracyValidator = AccuracyValidator()\n",
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
"\n",
"trainer = Ignite(\n",
"    adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
")\n",
"print(trainer.device)\n",
"\n",
"early_stopper_kwargs = {\"patience\": args.patience}\n",
"\n",
"start_time = datetime.now()\n",
"\n",
"best_score, best_epoch = trainer.run(\n",
"    datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n",
")\n",
"\n",
"end_time = datetime.now()\n",
"training_time = end_time - start_time\n",
"\n",
"print(f\"best_score={best_score}, best_epoch={best_epoch}\")\n",
"# Report the wall-clock duration that was previously computed but never shown.\n",
"print(f\"training_time={training_time}\")\n",
"\n",
"plt.plot(val_hooks[0].score_history, label='source')\n",
"plt.title(\"val accuracy\")\n",
"plt.legend()\n",
"plt.savefig(f\"{output_dir}/val_accuracy.png\")\n",
"plt.close('all')\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "args" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "-----" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"# NOTE(review): machine-specific absolute path; no other cell in this notebook reads `path`.\n",
"path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CDAN/2000/a2d/saved_models\"" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
"# Imports first, so the cell's dependencies are visible before any of them is used.\n",
"from pytorch_adapt.frameworks.ignite import (\n",
"    CheckpointFnCreator,\n",
"    IgniteValHookWrapper,\n",
"    checkpoint_utils,\n",
")\n",
"\n",
"# Build fresh models (pretrained=False) and a new trainer, then call\n",
"# load_best_checkpoint on them using the checkpoints under `output_dir`/saved_models.\n",
"source_domain = 'amazon'\n",
"target_domain = 'dslr'\n",
"G = office31G(pretrained=False).to(device)\n",
"C = office31C(pretrained=False).to(device)\n",
"\n",
"optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
"lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n",
"\n",
"models = Models({\"G\": G, \"C\": C})\n",
"adapter = ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
"\n",
"output_dir = \"tmp\"\n",
"checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
"\n",
"sourceAccuracyValidator = AccuracyValidator()\n",
"val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
"\n",
"new_trainer = Ignite(\n",
"    adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
")\n",
"\n",
"# Objects handed to the checkpoint loader: the engine, the validator, the\n",
"# score-history hook and every entry produced by adapter_to_dict.\n",
"objs = [\n",
"    {\n",
"        \"engine\": new_trainer.trainer,\n",
"        \"validator\": new_trainer.validator,\n",
"        \"val_hook0\": val_hooks[0],\n",
"        **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n",
"    }\n",
"]\n",
"\n",
"for to_load in objs:\n",
"    checkpoint_fn.load_best_checkpoint(to_load)\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# NOTE(review): `h1` is not defined in any cell of this notebook; define or import the\n",
"# hook before running this cell. The last recorded run failed with KeyError: 'G' inside\n",
"# FeaturesAndLogitsHook (model lookup `kwargs[model_name]`), i.e. the dict passed in had\n",
"# no 'G' entry. Presumably the hook expects a batch dict that also carries the models,\n",
"# not `datasets` -- TODO confirm.\n",
"h1(datasets)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ],
"metadata": { "kernelspec": { "display_name": "cdtrans", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.15 (default, Nov 24 2022,
15:19:38) \n[GCC 11.2.0]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "959b82c3a41427bdf7d14d4ba7335271e0c50cfcddd70501934b27dcc36968ad" } } }, "nbformat": 4, "nbformat_minor": 2 }