You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

test-save-load.ipynb 34KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "data": {
  10. "text/plain": [
  11. "device(type='cuda', index=0)"
  12. ]
  13. },
  14. "execution_count": 1,
  15. "metadata": {},
  16. "output_type": "execute_result"
  17. }
  18. ],
  19. "source": [
  20. "\n",
  21. "import torch\n",
  22. "import os\n",
  23. "\n",
  24. "from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets\n",
  25. "from pytorch_adapt.containers import Models, Optimizers, LRSchedulers\n",
  26. "from pytorch_adapt.models import Discriminator, office31C, office31G\n",
  27. "from pytorch_adapt.containers import Misc\n",
  28. "from pytorch_adapt.layers import RandomizedDotProduct\n",
  29. "from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss\n",
  30. "from pytorch_adapt.utils import common_functions\n",
  31. "from pytorch_adapt.containers import LRSchedulers\n",
  32. "\n",
  33. "from classifier_adapter import ClassifierAdapter\n",
  34. "\n",
  35. "from utils import HP, DAModels\n",
  36. "\n",
  37. "import copy\n",
  38. "\n",
  39. "import matplotlib.pyplot as plt\n",
  40. "import torch\n",
  41. "import os\n",
  42. "import gc\n",
  43. "from datetime import datetime\n",
  44. "\n",
  45. "from pytorch_adapt.datasets import DataloaderCreator, get_office31\n",
  46. "from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n",
  47. "from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators\n",
  48. "\n",
  49. "from models import get_model\n",
  50. "from utils import DAModels\n",
  51. "\n",
  52. "from vis_hook import VizHook\n",
  53. "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  54. "device"
  55. ]
  56. },
  57. {
  58. "cell_type": "code",
  59. "execution_count": 2,
  60. "metadata": {},
  61. "outputs": [
  62. {
  63. "name": "stdout",
  64. "output_type": "stream",
  65. "text": [
  66. "Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=5)\n"
  67. ]
  68. }
  69. ],
  70. "source": [
  71. "import argparse\n",
  72. "parser = argparse.ArgumentParser()\n",
  73. "parser.add_argument('--max_epochs', default=1, type=int)\n",
  74. "parser.add_argument('--patience', default=2, type=int)\n",
  75. "parser.add_argument('--batch_size', default=64, type=int)\n",
  76. "parser.add_argument('--num_workers', default=1, type=int)\n",
  77. "parser.add_argument('--trials_count', default=1, type=int)\n",
  78. "parser.add_argument('--initial_trial', default=0, type=int)\n",
  79. "parser.add_argument('--download', default=False, type=bool)\n",
  80. "parser.add_argument('--root', default=\"./\")\n",
  81. "parser.add_argument('--data_root', default=\"./datasets/pytorch-adapt/\")\n",
  82. "parser.add_argument('--results_root', default=\"./results/\")\n",
  83. "parser.add_argument('--model_names', default=[\"DANN\"], nargs='+')\n",
  84. "parser.add_argument('--lr', default=0.0001, type=float)\n",
  85. "parser.add_argument('--gamma', default=0.99, type=float)\n",
  86. "parser.add_argument('--hp_tune', default=False, type=bool)\n",
  87. "parser.add_argument('--source', default=None)\n",
  88. "parser.add_argument('--target', default=None) \n",
  89. "parser.add_argument('--vishook_frequency', default=5, type=int)\n",
  90. " \n",
  91. "\n",
  92. "args = parser.parse_args(\"\")\n",
  93. "print(args)\n"
  94. ]
  95. },
  96. {
  97. "cell_type": "code",
  98. "execution_count": 45,
  99. "metadata": {},
  100. "outputs": [],
  101. "source": [
  102. "source_domain = 'amazon'\n",
  103. "target_domain = 'webcam'\n",
  104. "datasets = get_office31([source_domain], [],\n",
  105. " folder=args.data_root,\n",
  106. " return_target_with_labels=True,\n",
  107. " download=args.download)\n",
  108. "\n",
  109. "dc = DataloaderCreator(batch_size=args.batch_size,\n",
  110. " num_workers=args.num_workers,\n",
  111. " )\n",
  112. "\n",
  113. "weights_root = os.path.join(args.data_root, \"weights\")\n",
  114. "\n",
  115. "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
  116. "C = office31C(domain=source_domain, pretrained=True,\n",
  117. " model_dir=weights_root).to(device)\n",
  118. "\n",
  119. "\n",
  120. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
  121. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n",
  122. "\n",
  123. "models = Models({\"G\": G, \"C\": C})\n",
  124. "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)"
  125. ]
  126. },
  127. {
  128. "cell_type": "code",
  129. "execution_count": null,
  130. "metadata": {},
  131. "outputs": [
  132. {
  133. "name": "stdout",
  134. "output_type": "stream",
  135. "text": [
  136. "cuda:0\n"
  137. ]
  138. },
  139. {
  140. "data": {
  141. "application/vnd.jupyter.widget-view+json": {
  142. "model_id": "f28aaf5a334d4f91a9beb21e714c43a5",
  143. "version_major": 2,
  144. "version_minor": 0
  145. },
  146. "text/plain": [
  147. "[1/35] 3%|2 |it [00:00<?]"
  148. ]
  149. },
  150. "metadata": {},
  151. "output_type": "display_data"
  152. },
  153. {
  154. "data": {
  155. "application/vnd.jupyter.widget-view+json": {
  156. "model_id": "7131d8b9099c4d0a95155595919c55f5",
  157. "version_major": 2,
  158. "version_minor": 0
  159. },
  160. "text/plain": [
  161. "[1/9] 11%|#1 |it [00:00<?]"
  162. ]
  163. },
  164. "metadata": {},
  165. "output_type": "display_data"
  166. },
  167. {
  168. "name": "stdout",
  169. "output_type": "stream",
  170. "text": [
  171. "best_score=None, best_epoch=None\n"
  172. ]
  173. },
  174. {
  175. "ename": "AttributeError",
  176. "evalue": "'Namespace' object has no attribute 'dataroot'",
  177. "output_type": "error",
  178. "traceback": [
  179. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  180. "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
  181. "Cell \u001b[0;32mIn[39], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m plt\u001b[39m.\u001b[39msavefig(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00moutput_dir\u001b[39m}\u001b[39;00m\u001b[39m/val_accuracy.png\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 29\u001b[0m plt\u001b[39m.\u001b[39mclose(\u001b[39m'\u001b[39m\u001b[39mall\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m---> 31\u001b[0m datasets \u001b[39m=\u001b[39m get_office31([source_domain], [target_domain], folder\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39;49mdataroot, return_target_with_labels\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 32\u001b[0m dc \u001b[39m=\u001b[39m DataloaderCreator(batch_size\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mbatch_size, num_workers\u001b[39m=\u001b[39margs\u001b[39m.\u001b[39mnum_workers, all_val\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 34\u001b[0m validator \u001b[39m=\u001b[39m AccuracyValidator(key_map\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m})\n",
  182. "\u001b[0;31mAttributeError\u001b[0m: 'Namespace' object has no attribute 'dataroot'"
  183. ]
  184. }
  185. ],
  186. "source": [
  187. "\n",
  188. "output_dir = \"tmp\"\n",
  189. "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
  190. "\n",
  191. "sourceAccuracyValidator = AccuracyValidator()\n",
  192. "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
  193. "\n",
  194. "trainer = Ignite(\n",
  195. " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
  196. ")\n",
  197. "print(trainer.device)\n",
  198. "\n",
  199. "early_stopper_kwargs = {\"patience\": args.patience}\n",
  200. "\n",
  201. "start_time = datetime.now()\n",
  202. "\n",
  203. "best_score, best_epoch = trainer.run(\n",
  204. " datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n",
  205. ")\n",
  206. "\n",
  207. "end_time = datetime.now()\n",
  208. "training_time = end_time - start_time\n",
  209. "\n",
  210. "print(f\"best_score={best_score}, best_epoch={best_epoch}\")\n",
  211. "\n",
  212. "plt.plot(val_hooks[0].score_history, label='source')\n",
  213. "plt.title(\"val accuracy\")\n",
  214. "plt.legend()\n",
  215. "plt.savefig(f\"{output_dir}/val_accuracy.png\")\n",
  216. "plt.close('all')\n"
  217. ]
  218. },
  219. {
  220. "cell_type": "code",
  221. "execution_count": null,
  222. "metadata": {},
  223. "outputs": [
  224. {
  225. "data": {
  226. "application/vnd.jupyter.widget-view+json": {
  227. "model_id": "173d54ab994d4abda6e4f0897ad96c49",
  228. "version_major": 2,
  229. "version_minor": 0
  230. },
  231. "text/plain": [
  232. "[1/9] 11%|#1 |it [00:00<?]"
  233. ]
  234. },
  235. "metadata": {},
  236. "output_type": "display_data"
  237. },
  238. {
  239. "name": "stdout",
  240. "output_type": "stream",
  241. "text": [
  242. "Source acc: 0.868794322013855\n"
  243. ]
  244. },
  245. {
  246. "data": {
  247. "application/vnd.jupyter.widget-view+json": {
  248. "model_id": "13b4a1ccc3b34456b68c73357d14bc21",
  249. "version_major": 2,
  250. "version_minor": 0
  251. },
  252. "text/plain": [
  253. "[1/3] 33%|###3 |it [00:00<?]"
  254. ]
  255. },
  256. "metadata": {},
  257. "output_type": "display_data"
  258. },
  259. {
  260. "name": "stdout",
  261. "output_type": "stream",
  262. "text": [
  263. "Target acc: 0.74842768907547\n",
  264. "---------\n"
  265. ]
  266. }
  267. ],
  268. "source": [
  269. "\n",
  270. "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n",
  271. "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n",
  272. "\n",
  273. "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n",
  274. "src_score = trainer.evaluate_best_model(datasets, validator, dc)\n",
  275. "print(\"Source acc:\", src_score)\n",
  276. "\n",
  277. "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  278. "target_score = trainer.evaluate_best_model(datasets, validator, dc)\n",
  279. "print(\"Target acc:\", target_score)\n",
  280. "print(\"---------\")"
  281. ]
  282. },
  283. {
  284. "cell_type": "code",
  285. "execution_count": null,
  286. "metadata": {},
  287. "outputs": [],
  288. "source": [
  289. "C2 = copy.deepcopy(C) "
  290. ]
  291. },
  292. {
  293. "cell_type": "code",
  294. "execution_count": 93,
  295. "metadata": {},
  296. "outputs": [
  297. {
  298. "name": "stdout",
  299. "output_type": "stream",
  300. "text": [
  301. "cuda:0\n"
  302. ]
  303. }
  304. ],
  305. "source": [
  306. "source_domain = 'amazon'\n",
  307. "target_domain = 'webcam'\n",
  308. "G = office31G(pretrained=False).to(device)\n",
  309. "C = office31C(pretrained=False).to(device)\n",
  310. "\n",
  311. "\n",
  312. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
  313. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n",
  314. "\n",
  315. "models = Models({\"G\": G, \"C\": C})\n",
  316. "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  317. "\n",
  318. "\n",
  319. "output_dir = \"tmp\"\n",
  320. "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
  321. "\n",
  322. "sourceAccuracyValidator = AccuracyValidator()\n",
  323. "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
  324. "\n",
  325. "new_trainer = Ignite(\n",
  326. " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
  327. ")\n",
  328. "print(trainer.device)\n",
  329. "\n",
  330. "from pytorch_adapt.frameworks.ignite import (\n",
  331. " CheckpointFnCreator,\n",
  332. " IgniteValHookWrapper,\n",
  333. " checkpoint_utils,\n",
  334. ")\n",
  335. "\n",
  336. "objs = [\n",
  337. " {\n",
  338. " \"engine\": new_trainer.trainer,\n",
  339. " \"validator\": new_trainer.validator,\n",
  340. " \"val_hook0\": val_hooks[0],\n",
  341. " **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n",
  342. " }\n",
  343. " ]\n",
  344. " \n",
  345. "# best_score, best_epoch = trainer.run(\n",
  346. "# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n",
  347. "# )\n",
  348. "\n",
  349. "for to_load in objs:\n",
  350. " checkpoint_fn.load_best_checkpoint(to_load)\n",
  351. "\n"
  352. ]
  353. },
  354. {
  355. "cell_type": "code",
  356. "execution_count": 94,
  357. "metadata": {},
  358. "outputs": [
  359. {
  360. "data": {
  361. "application/vnd.jupyter.widget-view+json": {
  362. "model_id": "32f01ff7ea254739909e4567a133b00a",
  363. "version_major": 2,
  364. "version_minor": 0
  365. },
  366. "text/plain": [
  367. "[1/9] 11%|#1 |it [00:00<?]"
  368. ]
  369. },
  370. "metadata": {},
  371. "output_type": "display_data"
  372. },
  373. {
  374. "name": "stdout",
  375. "output_type": "stream",
  376. "text": [
  377. "Source acc: 0.868794322013855\n"
  378. ]
  379. },
  380. {
  381. "data": {
  382. "application/vnd.jupyter.widget-view+json": {
  383. "model_id": "cef345c05e5e46eb9fc0e1cc40b02435",
  384. "version_major": 2,
  385. "version_minor": 0
  386. },
  387. "text/plain": [
  388. "[1/3] 33%|###3 |it [00:00<?]"
  389. ]
  390. },
  391. "metadata": {},
  392. "output_type": "display_data"
  393. },
  394. {
  395. "name": "stdout",
  396. "output_type": "stream",
  397. "text": [
  398. "Target acc: 0.74842768907547\n",
  399. "---------\n"
  400. ]
  401. }
  402. ],
  403. "source": [
  404. "\n",
  405. "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n",
  406. "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n",
  407. "\n",
  408. "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n",
  409. "src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n",
  410. "print(\"Source acc:\", src_score)\n",
  411. "\n",
  412. "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  413. "target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n",
  414. "print(\"Target acc:\", target_score)\n",
  415. "print(\"---------\")"
  416. ]
  417. },
  418. {
  419. "cell_type": "code",
  420. "execution_count": 89,
  421. "metadata": {},
  422. "outputs": [],
  423. "source": [
  424. "\n",
  425. "datasets = get_office31([source_domain], [target_domain],\n",
  426. " folder=args.data_root,\n",
  427. " return_target_with_labels=True,\n",
  428. " download=args.download)\n",
  429. " \n",
  430. "dc = DataloaderCreator(batch_size=args.batch_size,\n",
  431. " num_workers=args.num_workers,\n",
  432. " train_names=[\"train\"],\n",
  433. " val_names=[\"src_train\", \"target_train\", \"src_val\", \"target_val\",\n",
  434. " \"target_train_with_labels\", \"target_val_with_labels\"])\n",
  435. "\n",
  436. "G = new_trainer.adapter.models[\"G\"]\n",
  437. "C = new_trainer.adapter.models[\"C\"]\n",
  438. "D = Discriminator(in_size=2048, h=1024).to(device)\n",
  439. "\n",
  440. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.001}))\n",
  441. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n",
  442. "# lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.MultiStepLR, {\"milestones\": [2, 5, 10, 20, 40], \"gamma\": hp.gamma}))\n",
  443. "\n",
  444. "models = Models({\"G\": G, \"C\": C, \"D\": D})\n",
  445. "adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n"
  446. ]
  447. },
  448. {
  449. "cell_type": "code",
  450. "execution_count": 90,
  451. "metadata": {},
  452. "outputs": [
  453. {
  454. "name": "stdout",
  455. "output_type": "stream",
  456. "text": [
  457. "cuda:0\n"
  458. ]
  459. },
  460. {
  461. "data": {
  462. "application/vnd.jupyter.widget-view+json": {
  463. "model_id": "bf490d18567444149070191e100f8c45",
  464. "version_major": 2,
  465. "version_minor": 0
  466. },
  467. "text/plain": [
  468. "[1/3] 33%|###3 |it [00:00<?]"
  469. ]
  470. },
  471. "metadata": {},
  472. "output_type": "display_data"
  473. },
  474. {
  475. "data": {
  476. "application/vnd.jupyter.widget-view+json": {
  477. "model_id": "525920fcd19d4178a4bada48932c8fb1",
  478. "version_major": 2,
  479. "version_minor": 0
  480. },
  481. "text/plain": [
  482. "[1/9] 11%|#1 |it [00:00<?]"
  483. ]
  484. },
  485. "metadata": {},
  486. "output_type": "display_data"
  487. },
  488. {
  489. "data": {
  490. "application/vnd.jupyter.widget-view+json": {
  491. "model_id": "bd1158f548e746cdab88d608b22ab65c",
  492. "version_major": 2,
  493. "version_minor": 0
  494. },
  495. "text/plain": [
  496. "[1/9] 11%|#1 |it [00:00<?]"
  497. ]
  498. },
  499. "metadata": {},
  500. "output_type": "display_data"
  501. },
  502. {
  503. "data": {
  504. "application/vnd.jupyter.widget-view+json": {
  505. "model_id": "196fa120037b48fdb4e9a879e7e7c79b",
  506. "version_major": 2,
  507. "version_minor": 0
  508. },
  509. "text/plain": [
  510. "[1/3] 33%|###3 |it [00:00<?]"
  511. ]
  512. },
  513. "metadata": {},
  514. "output_type": "display_data"
  515. },
  516. {
  517. "data": {
  518. "application/vnd.jupyter.widget-view+json": {
  519. "model_id": "6795edb658a84309b1a03bcea6a24643",
  520. "version_major": 2,
  521. "version_minor": 0
  522. },
  523. "text/plain": [
  524. "[1/9] 11%|#1 |it [00:00<?]"
  525. ]
  526. },
  527. "metadata": {},
  528. "output_type": "display_data"
  529. }
  530. ],
  531. "source": [
  532. "\n",
  533. "output_dir = \"tmp\"\n",
  534. "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
  535. "\n",
  536. "sourceAccuracyValidator = AccuracyValidator()\n",
  537. "targetAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  538. "val_hooks = [ScoreHistory(sourceAccuracyValidator), ScoreHistory(targetAccuracyValidator)]\n",
  539. "\n",
  540. "trainer = Ignite(\n",
  541. " adapter, val_hooks=val_hooks, device=device\n",
  542. ")\n",
  543. "print(trainer.device)\n",
  544. "\n",
  545. "best_score, best_epoch = trainer.run(\n",
  546. " datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs, check_initial_score=True\n",
  547. ")\n"
  548. ]
  549. },
  550. {
  551. "cell_type": "code",
  552. "execution_count": 91,
  553. "metadata": {},
  554. "outputs": [
  555. {
  556. "data": {
  557. "text/plain": [
  558. "ScoreHistory(\n",
  559. " validator=AccuracyValidator(required_data=['src_val'])\n",
  560. " latest_score=0.30319148302078247\n",
  561. " best_score=0.868794322013855\n",
  562. " best_epoch=0\n",
  563. ")"
  564. ]
  565. },
  566. "execution_count": 91,
  567. "metadata": {},
  568. "output_type": "execute_result"
  569. }
  570. ],
  571. "source": [
  572. "val_hooks[0]"
  573. ]
  574. },
  575. {
  576. "cell_type": "code",
  577. "execution_count": 92,
  578. "metadata": {},
  579. "outputs": [
  580. {
  581. "data": {
  582. "text/plain": [
  583. "ScoreHistory(\n",
  584. " validator=AccuracyValidator(required_data=['target_val_with_labels'])\n",
  585. " latest_score=0.2515723407268524\n",
  586. " best_score=0.74842768907547\n",
  587. " best_epoch=0\n",
  588. ")"
  589. ]
  590. },
  591. "execution_count": 92,
  592. "metadata": {},
  593. "output_type": "execute_result"
  594. }
  595. ],
  596. "source": [
  597. "val_hooks[1]"
  598. ]
  599. },
  600. {
  601. "cell_type": "code",
  602. "execution_count": 86,
  603. "metadata": {},
  604. "outputs": [],
  605. "source": [
  606. "del trainer"
  607. ]
  608. },
  609. {
  610. "cell_type": "code",
  611. "execution_count": 87,
  612. "metadata": {},
  613. "outputs": [
  614. {
  615. "data": {
  616. "text/plain": [
  617. "21169"
  618. ]
  619. },
  620. "execution_count": 87,
  621. "metadata": {},
  622. "output_type": "execute_result"
  623. }
  624. ],
  625. "source": [
  626. "import gc\n",
  627. "gc.collect()"
  628. ]
  629. },
  630. {
  631. "cell_type": "code",
  632. "execution_count": 88,
  633. "metadata": {},
  634. "outputs": [],
  635. "source": [
  636. "torch.cuda.empty_cache()"
  637. ]
  638. },
  639. {
  640. "cell_type": "code",
  641. "execution_count": 95,
  642. "metadata": {},
  643. "outputs": [],
  644. "source": [
  645. "args.vishook_frequency = 133"
  646. ]
  647. },
  648. {
  649. "cell_type": "code",
  650. "execution_count": 96,
  651. "metadata": {},
  652. "outputs": [
  653. {
  654. "data": {
  655. "text/plain": [
  656. "Namespace(batch_size=64, data_root='./datasets/pytorch-adapt/', download=False, gamma=0.99, hp_tune=False, initial_trial=0, lr=0.0001, max_epochs=1, model_names=['DANN'], num_workers=1, patience=2, results_root='./results/', root='./', source=None, target=None, trials_count=1, vishook_frequency=133)"
  657. ]
  658. },
  659. "execution_count": 96,
  660. "metadata": {},
  661. "output_type": "execute_result"
  662. },
  663. {
  664. "ename": "",
  665. "evalue": "",
  666. "output_type": "error",
  667. "traceback": [
  668. "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
  669. ]
  670. }
  671. ],
  672. "source": [
  673. "args"
  674. ]
  675. },
  676. {
  677. "cell_type": "code",
  678. "execution_count": null,
  679. "metadata": {},
  680. "outputs": [],
  681. "source": []
  682. },
  683. {
  684. "attachments": {},
  685. "cell_type": "markdown",
  686. "metadata": {},
  687. "source": [
  688. "-----"
  689. ]
  690. },
  691. {
  692. "cell_type": "code",
  693. "execution_count": 3,
  694. "metadata": {},
  695. "outputs": [],
  696. "source": [
  697. "path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CDAN/2000/a2d/saved_models\""
  698. ]
  699. },
  700. {
  701. "cell_type": "code",
  702. "execution_count": 5,
  703. "metadata": {},
  704. "outputs": [],
  705. "source": [
  706. "source_domain = 'amazon'\n",
  707. "target_domain = 'dslr'\n",
  708. "G = office31G(pretrained=False).to(device)\n",
  709. "C = office31C(pretrained=False).to(device)\n",
  710. "\n",
  711. "\n",
  712. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
  713. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n",
  714. "\n",
  715. "models = Models({\"G\": G, \"C\": C})\n",
  716. "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  717. "\n",
  718. "\n",
  719. "output_dir = \"tmp\"\n",
  720. "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
  721. "\n",
  722. "sourceAccuracyValidator = AccuracyValidator()\n",
  723. "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
  724. "\n",
  725. "new_trainer = Ignite(\n",
  726. " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
  727. ")\n",
  728. "\n",
  729. "from pytorch_adapt.frameworks.ignite import (\n",
  730. " CheckpointFnCreator,\n",
  731. " IgniteValHookWrapper,\n",
  732. " checkpoint_utils,\n",
  733. ")\n",
  734. "\n",
  735. "objs = [\n",
  736. " {\n",
  737. " \"engine\": new_trainer.trainer,\n",
  738. " \"validator\": new_trainer.validator,\n",
  739. " \"val_hook0\": val_hooks[0],\n",
  740. " **checkpoint_utils.adapter_to_dict(new_trainer.adapter),\n",
  741. " }\n",
  742. " ]\n",
  743. " \n",
  744. "# best_score, best_epoch = trainer.run(\n",
  745. "# datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs\n",
  746. "# )\n",
  747. "\n",
  748. "for to_load in objs:\n",
  749. " checkpoint_fn.load_best_checkpoint(to_load)\n",
  750. "\n"
  751. ]
  752. },
  753. {
  754. "cell_type": "code",
  755. "execution_count": 6,
  756. "metadata": {},
  757. "outputs": [
  758. {
  759. "data": {
  760. "application/vnd.jupyter.widget-view+json": {
  761. "model_id": "926966dd640e4979ade6a45cf0fcdd49",
  762. "version_major": 2,
  763. "version_minor": 0
  764. },
  765. "text/plain": [
  766. "[1/9] 11%|#1 |it [00:00<?]"
  767. ]
  768. },
  769. "metadata": {},
  770. "output_type": "display_data"
  771. },
  772. {
  773. "name": "stdout",
  774. "output_type": "stream",
  775. "text": [
  776. "Source acc: 0.868794322013855\n"
  777. ]
  778. },
  779. {
  780. "data": {
  781. "application/vnd.jupyter.widget-view+json": {
  782. "model_id": "64cd5cfb052c4f52af9af1a63a4c0087",
  783. "version_major": 2,
  784. "version_minor": 0
  785. },
  786. "text/plain": [
  787. "[1/2] 50%|##### |it [00:00<?]"
  788. ]
  789. },
  790. "metadata": {},
  791. "output_type": "display_data"
  792. },
  793. {
  794. "name": "stdout",
  795. "output_type": "stream",
  796. "text": [
  797. "Target acc: 0.7200000286102295\n",
  798. "---------\n"
  799. ]
  800. }
  801. ],
  802. "source": [
  803. "\n",
  804. "datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)\n",
  805. "dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)\n",
  806. "\n",
  807. "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n",
  808. "src_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n",
  809. "print(\"Source acc:\", src_score)\n",
  810. "\n",
  811. "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  812. "target_score = new_trainer.evaluate_best_model(datasets, validator, dc)\n",
  813. "print(\"Target acc:\", target_score)\n",
  814. "print(\"---------\")"
  815. ]
  816. },
  817. {
  818. "cell_type": "code",
  819. "execution_count": 10,
  820. "metadata": {},
  821. "outputs": [],
  822. "source": [
  823. "source_domain = 'amazon'\n",
  824. "target_domain = 'dslr'\n",
  825. "G = new_trainer.adapter.models[\"G\"]\n",
  826. "C = new_trainer.adapter.models[\"C\"]\n",
  827. "\n",
  828. "G.fc = C.net[:6]\n",
  829. "C.net = C.net[6:]\n",
  830. "\n",
  831. "\n",
  832. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
  833. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99})) \n",
  834. "\n",
  835. "models = Models({\"G\": G, \"C\": C})\n",
  836. "adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  837. "\n",
  838. "\n",
  839. "output_dir = \"tmp\"\n",
  840. "checkpoint_fn = CheckpointFnCreator(dirname=f\"{output_dir}/saved_models\", require_empty=False)\n",
  841. "\n",
  842. "sourceAccuracyValidator = AccuracyValidator()\n",
  843. "val_hooks = [ScoreHistory(sourceAccuracyValidator)]\n",
  844. "\n",
  845. "more_new_trainer = Ignite(\n",
  846. " adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device\n",
  847. ")"
  848. ]
  849. },
  850. {
  851. "cell_type": "code",
  852. "execution_count": 13,
  853. "metadata": {},
  854. "outputs": [],
  855. "source": [
  856. "from pytorch_adapt.hooks import FeaturesAndLogitsHook\n",
  857. "\n",
  858. "h1 = FeaturesAndLogitsHook()"
  859. ]
  860. },
  861. {
  862. "cell_type": "code",
  863. "execution_count": 19,
  864. "metadata": {},
  865. "outputs": [
  866. {
  867. "ename": "KeyError",
  868. "evalue": "in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG",
  869. "output_type": "error",
  870. "traceback": [
  871. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  872. "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
  873. "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m h1(datasets)\n",
  874. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  875. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n",
  876. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  877. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:80\u001b[0m, in \u001b[0;36mBaseFeaturesHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 78\u001b[0m func \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_detached \u001b[39mif\u001b[39;00m detach \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode_with_grad\n\u001b[1;32m 79\u001b[0m in_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39min_keys, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 80\u001b[0m func(inputs, outputs, domain, in_keys)\n\u001b[1;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcheck_outputs_requires_grad(outputs)\n\u001b[1;32m 83\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, {}\n",
  878. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:106\u001b[0m, in \u001b[0;36mBaseFeaturesHook.mode_with_grad\u001b[0;34m(self, inputs, outputs, domain, in_keys)\u001b[0m\n\u001b[1;32m 104\u001b[0m output_keys \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mfilter(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_out_keys(), \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m^\u001b[39m\u001b[39m{\u001b[39;00mdomain\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 105\u001b[0m output_vals \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_kwargs(inputs, output_keys)\n\u001b[0;32m--> 106\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 107\u001b[0m outputs, output_keys, output_vals, inputs, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_name, in_keys, domain\n\u001b[1;32m 108\u001b[0m )\n\u001b[1;32m 109\u001b[0m \u001b[39mreturn\u001b[39;00m output_keys, output_vals\n",
  879. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/features.py:133\u001b[0m, in \u001b[0;36mBaseFeaturesHook.add_if_new\u001b[0;34m(self, outputs, full_key, output_vals, inputs, model_name, in_keys, domain)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39madd_if_new\u001b[39m(\n\u001b[1;32m 131\u001b[0m \u001b[39mself\u001b[39m, outputs, full_key, output_vals, inputs, model_name, in_keys, domain\n\u001b[1;32m 132\u001b[0m ):\n\u001b[0;32m--> 133\u001b[0m c_f\u001b[39m.\u001b[39;49madd_if_new(\n\u001b[1;32m 134\u001b[0m outputs,\n\u001b[1;32m 135\u001b[0m full_key,\n\u001b[1;32m 136\u001b[0m output_vals,\n\u001b[1;32m 137\u001b[0m inputs,\n\u001b[1;32m 138\u001b[0m model_name,\n\u001b[1;32m 139\u001b[0m in_keys,\n\u001b[1;32m 140\u001b[0m logger\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlogger,\n\u001b[1;32m 141\u001b[0m )\n",
  880. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/utils/common_functions.py:96\u001b[0m, in \u001b[0;36madd_if_new\u001b[0;34m(d, key, x, kwargs, model_name, in_keys, other_args, logger)\u001b[0m\n\u001b[1;32m 94\u001b[0m condition \u001b[39m=\u001b[39m is_none\n\u001b[1;32m 95\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(condition(y) \u001b[39mfor\u001b[39;00m y \u001b[39min\u001b[39;00m x):\n\u001b[0;32m---> 96\u001b[0m model \u001b[39m=\u001b[39m kwargs[model_name]\n\u001b[1;32m 97\u001b[0m input_vals \u001b[39m=\u001b[39m [kwargs[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m in_keys] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(other_args\u001b[39m.\u001b[39mvalues())\n\u001b[1;32m 98\u001b[0m new_x \u001b[39m=\u001b[39m try_use_model(model, model_name, input_vals)\n",
  881. "\u001b[0;31mKeyError\u001b[0m: in FeaturesAndLogitsHook: __call__\nin FeaturesHook: __call__\nFeaturesHook: Getting src\nFeaturesHook: Getting output: ['src_imgs_features']\nFeaturesHook: Using model G with inputs: src_imgs\nG"
  882. ]
  883. }
  884. ],
  885. "source": [
  886. "h1(datasets)"
  887. ]
  888. },
  889. {
  890. "cell_type": "code",
  891. "execution_count": null,
  892. "metadata": {},
  893. "outputs": [],
  894. "source": []
  895. }
  896. ],
  897. "metadata": {
  898. "kernelspec": {
  899. "display_name": "cdtrans",
  900. "language": "python",
  901. "name": "python3"
  902. },
  903. "language_info": {
  904. "codemirror_mode": {
  905. "name": "ipython",
  906. "version": 3
  907. },
  908. "file_extension": ".py",
  909. "mimetype": "text/x-python",
  910. "name": "python",
  911. "nbconvert_exporter": "python",
  912. "pygments_lexer": "ipython3",
  913. "version": "3.8.15 (default, Nov 24 2022, 15:19:38) \n[GCC 11.2.0]"
  914. },
  915. "orig_nbformat": 4,
  916. "vscode": {
  917. "interpreter": {
  918. "hash": "959b82c3a41427bdf7d14d4ba7335271e0c50cfcddd70501934b27dcc36968ad"
  919. }
  920. }
  921. },
  922. "nbformat": 4,
  923. "nbformat_minor": 2
  924. }