You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

experiment-saba-pytorch.ipynb 861KB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 2,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import torch\n",
  10. "from tqdm import tqdm\n",
  11. "import os\n",
  12. "from pytorch_adapt.containers import Models, Optimizers\n",
  13. "from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, get_office31\n",
  14. "from pytorch_adapt.datasets.office31 import Office31\n",
  15. "from pytorch_adapt.hooks import DANNHook\n",
  16. "from pytorch_adapt.adapters import DANN\n",
  17. "\n",
  18. "from pytorch_adapt.models import Discriminator, mnistC, mnistG, office31C, office31G\n",
  19. "from pytorch_adapt.utils.common_functions import batch_to_device\n",
  20. "from pytorch_adapt.validators import IMValidator, AccuracyValidator\n",
  21. "\n",
  22. "from pprint import pprint\n"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": 18,
  28. "metadata": {},
  29. "outputs": [],
  30. "source": [
  31. "root=\"datasets/pytorch-adapt/\"  # data folder passed to get_office31 (weights live under root/weights)\n",
  32. "batch_size=32\n",
  33. "num_workers=2\n",
  34. "\n",
  35. "datasets = get_office31([\"amazon\"], [\"webcam\"], folder=root, return_target_with_labels=True)  # src domains, target domains; also builds the *_with_labels target splits\n",
  36. "dc = DataloaderCreator(batch_size=batch_size, \n",
  37. " num_workers=num_workers, \n",
  38. " train_names=[\"train\"],\n",
  39. " val_names=[\"src_train\", \"target_train\", \"src_val\", \"target_val\", \"target_train_with_labels\", \"target_val_with_labels\"])\n",
  40. "dataloaders = dc(**datasets)  # dict: split name -> DataLoader"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 3,
  46. "metadata": {},
  47. "outputs": [
  48. {
  49. "name": "stdout",
  50. "output_type": "stream",
  51. "text": [
  52. "dict_keys(['src_train', 'src_val', 'target_train', 'target_val', 'target_train_with_labels', 'target_val_with_labels', 'train'])\n"
  53. ]
  54. }
  55. ],
  56. "source": [
  57. "print(dataloaders.keys())\n",
  58. "# 'train' dataset type is `CombinedSourceAndTargetDataset` and contains both source and unlabeled target data\n",
  59. "# 'src_train' and 'src_val' dataset types are `SourceDataset`\n",
  60. "# All other datasets are of type `TargetDataset`"
  61. ]
  62. },
  63. {
  64. "cell_type": "code",
  65. "execution_count": 4,
  66. "metadata": {},
  67. "outputs": [
  68. {
  69. "name": "stdout",
  70. "output_type": "stream",
  71. "text": [
  72. "src_train 71 2253\n",
  73. "dict_keys(['src_imgs', 'src_domain', 'src_labels', 'src_sample_idx'])\n",
  74. "src_val 18 564\n",
  75. "dict_keys(['src_imgs', 'src_domain', 'src_labels', 'src_sample_idx'])\n",
  76. "target_train 20 636\n",
  77. "dict_keys(['target_imgs', 'target_domain', 'target_sample_idx'])\n",
  78. "target_val 5 159\n",
  79. "dict_keys(['target_imgs', 'target_domain', 'target_sample_idx'])\n",
  80. "target_train_with_labels 20 636\n",
  81. "dict_keys(['target_imgs', 'target_domain', 'target_sample_idx', 'target_labels'])\n",
  82. "target_val_with_labels 5 159\n",
  83. "dict_keys(['target_imgs', 'target_domain', 'target_sample_idx', 'target_labels'])\n",
  84. "train 19 636\n",
  85. "dict_keys(['src_imgs', 'src_domain', 'src_labels', 'src_sample_idx', 'target_imgs', 'target_domain', 'target_sample_idx'])\n"
  86. ]
  87. }
  88. ],
  89. "source": [
  90. "for k in dataloaders.keys():  # sanity check: batches per loader, dataset size, keys of one batch\n",
  91. " print(k, len(dataloaders[k]), len(dataloaders[k].dataset))\n",
  92. " for a in dataloaders[k]:  # peek at the first batch only\n",
  93. " print(a.keys())\n",
  94. " break"
  95. ]
  96. },
  97. {
  98. "cell_type": "code",
  99. "execution_count": 3,
  100. "metadata": {},
  101. "outputs": [],
  102. "source": [
  103. "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # fall back to CPU when no GPU is visible\n",
  104. "weights_root = os.path.join(root, \"weights\")\n",
  105. "trained_domain = \"amazon\"  # domain whose pretrained office31C weights are loaded\n",
  106. "\n",
  107. "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
  108. "C = office31C(domain=trained_domain, pretrained=True, model_dir=weights_root).to(device)\n",
  109. "D = Discriminator(in_size=2048, h=256).to(device)\n",
  110. "models = Models({\"G\": G, \"C\": C, \"D\": D})\n",
  111. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.001}))\n",
  112. "optimizers.create_with(models)\n",
  113. "optimizers = list(optimizers.values())  # passed to DANNHook below as a plain list\n",
  114. "\n",
  115. "hook = DANNHook(optimizers)\n",
  116. "\n",
  117. "from pytorch_adapt.validators import AccuracyValidator, BaseValidator\n",
  118. "from pytorch_adapt.layers import BNMLoss\n",
  119. "\n",
  120. "class CustomTargetValidator(BaseValidator):  # NOTE(review): returns raw BNMLoss of softmax preds; pytorch_adapt's BNMValidator appears to negate the loss and use logits -- confirm intended sign/input\n",
  121. "    def compute_score(self, target_train):\n",
  122. "        return BNMLoss()(target_train[\"preds\"])\n",
  123. "\n",
  124. "src_train_validator = AccuracyValidator(key_map={\"src_train\": \"src_val\"})\n",
  125. "src_val_validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n",
  126. "target_validator = CustomTargetValidator()\n",
  127. "target_train_oracle_validator = AccuracyValidator(key_map={\"target_train\": \"src_val\"})\n",
  128. "target_val_oracle_validator = AccuracyValidator(key_map={\"target_val\": \"src_val\"})\n",
  129. "target_im_validator = IMValidator()\n",
  130. "targen_im_validator = target_im_validator  # alias for the original misspelled name, still used by the training cell\n"
  131. ]
  132. },
  133. {
  134. "cell_type": "code",
  135. "execution_count": 4,
  136. "metadata": {},
  137. "outputs": [
  138. {
  139. "name": "stderr",
  140. "output_type": "stream",
  141. "text": [
  142. "100%|██████████| 19/19 [00:49<00:00, 2.63s/it]\n"
  143. ]
  144. },
  145. {
  146. "name": "stdout",
  147. "output_type": "stream",
  148. "text": [
  149. "{'total_loss': {'src_c_loss': 3.168257236480713,\n",
  150. " 'src_domain_loss': 0.7317013740539551,\n",
  151. " 'target_domain_loss': 0.6301205158233643,\n",
  152. " 'total': 1.5100263357162476}}\n"
  153. ]
  154. },
  155. {
  156. "name": "stderr",
  157. "output_type": "stream",
  158. "text": [
  159. "100%|██████████| 18/18 [00:08<00:00, 2.03it/s]\n",
  160. "100%|██████████| 5/5 [00:02<00:00, 1.69it/s]\n",
  161. "100%|██████████| 20/20 [00:09<00:00, 2.16it/s]\n"
  162. ]
  163. },
  164. {
  165. "name": "stdout",
  166. "output_type": "stream",
  167. "text": [
  168. "Target Evaluation:\n",
  169. "src_val accuracy \t= 0.12411347776651382\n",
  170. "target_train accuracy \t= 0.08805031329393387\n",
  171. "target_train score (IM) \t= 0.4806191921234131\n"
  172. ]
  173. },
  174. {
  175. "name": "stderr",
  176. "output_type": "stream",
  177. "text": [
  178. "100%|██████████| 19/19 [00:51<00:00, 2.69s/it]\n"
  179. ]
  180. },
  181. {
  182. "name": "stdout",
  183. "output_type": "stream",
  184. "text": [
  185. "{'total_loss': {'src_c_loss': 2.804849624633789,\n",
  186. " 'src_domain_loss': 0.6947869062423706,\n",
  187. " 'target_domain_loss': 0.6891149282455444,\n",
  188. " 'total': 1.3962504863739014}}\n"
  189. ]
  190. },
  191. {
  192. "name": "stderr",
  193. "output_type": "stream",
  194. "text": [
  195. "100%|██████████| 18/18 [00:08<00:00, 2.02it/s]\n",
  196. "100%|██████████| 5/5 [00:02<00:00, 1.68it/s]\n",
  197. "100%|██████████| 20/20 [00:09<00:00, 2.15it/s]\n"
  198. ]
  199. },
  200. {
  201. "name": "stdout",
  202. "output_type": "stream",
  203. "text": [
  204. "Target Evaluation:\n",
  205. "src_val accuracy \t= 0.13652482628822327\n",
  206. "target_train accuracy \t= 0.16352201998233795\n",
  207. "target_train score (IM) \t= 0.6105575561523438\n"
  208. ]
  209. },
  210. {
  211. "name": "stderr",
  212. "output_type": "stream",
  213. "text": [
  214. "100%|██████████| 19/19 [00:51<00:00, 2.71s/it]\n"
  215. ]
  216. },
  217. {
  218. "name": "stdout",
  219. "output_type": "stream",
  220. "text": [
  221. "{'total_loss': {'src_c_loss': 3.0151782035827637,\n",
  222. " 'src_domain_loss': 0.7150613069534302,\n",
  223. " 'target_domain_loss': 0.8466196060180664,\n",
  224. " 'total': 1.5256197452545166}}\n"
  225. ]
  226. },
  227. {
  228. "name": "stderr",
  229. "output_type": "stream",
  230. "text": [
  231. "100%|██████████| 18/18 [00:08<00:00, 2.03it/s]\n",
  232. "100%|██████████| 5/5 [00:02<00:00, 1.68it/s]\n",
  233. "100%|██████████| 20/20 [00:09<00:00, 2.14it/s]\n"
  234. ]
  235. },
  236. {
  237. "name": "stdout",
  238. "output_type": "stream",
  239. "text": [
  240. "Target Evaluation:\n",
  241. "src_val accuracy \t= 0.14007091522216797\n",
  242. "target_train accuracy \t= 0.1949685513973236\n",
  243. "target_train score (IM) \t= 1.0187971591949463\n"
  244. ]
  245. },
  246. {
  247. "name": "stderr",
  248. "output_type": "stream",
  249. "text": [
  250. "100%|██████████| 19/19 [00:52<00:00, 2.75s/it]\n"
  251. ]
  252. },
  253. {
  254. "name": "stdout",
  255. "output_type": "stream",
  256. "text": [
  257. "{'total_loss': {'src_c_loss': 2.8740577697753906,\n",
  258. " 'src_domain_loss': 0.6584925651550293,\n",
  259. " 'target_domain_loss': 0.7099870443344116,\n",
  260. " 'total': 1.4141792058944702}}\n"
  261. ]
  262. },
  263. {
  264. "name": "stderr",
  265. "output_type": "stream",
  266. "text": [
  267. "100%|██████████| 18/18 [00:08<00:00, 2.17it/s]\n",
  268. "100%|██████████| 5/5 [00:02<00:00, 1.73it/s]\n",
  269. "100%|██████████| 20/20 [00:09<00:00, 2.15it/s]\n"
  270. ]
  271. },
  272. {
  273. "name": "stdout",
  274. "output_type": "stream",
  275. "text": [
  276. "Target Evaluation:\n",
  277. "src_val accuracy \t= 0.13297872245311737\n",
  278. "target_train accuracy \t= 0.23270440101623535\n",
  279. "target_train score (IM) \t= 0.8761472702026367\n"
  280. ]
  281. },
  282. {
  283. "name": "stderr",
  284. "output_type": "stream",
  285. "text": [
  286. " 89%|████████▉ | 17/19 [00:48<00:05, 2.82s/it]\n"
  287. ]
  288. },
  289. {
  290. "ename": "KeyboardInterrupt",
  291. "evalue": "",
  292. "output_type": "error",
  293. "traceback": [
  294. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  295. "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
  296. "Cell \u001b[0;32mIn[4], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[39mfor\u001b[39;00m data \u001b[39min\u001b[39;00m tqdm(dataloaders[\u001b[39m\"\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m\"\u001b[39m]):\n\u001b[1;32m 28\u001b[0m data \u001b[39m=\u001b[39m batch_to_device(data, device)\n\u001b[0;32m---> 29\u001b[0m _, loss \u001b[39m=\u001b[39m hook({\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodels, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mdata})\n\u001b[1;32m 30\u001b[0m pprint(loss)\n\u001b[1;32m 32\u001b[0m \u001b[39m# eval loop\u001b[39;00m\n",
  297. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  298. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:194\u001b[0m, in \u001b[0;36mBaseWrapperHook.call\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcall\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 193\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[0;32m--> 194\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhook(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
  299. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  300. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n",
  301. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  302. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/optimizer.py:51\u001b[0m, in \u001b[0;36mOptimizerHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcall\u001b[39m(\u001b[39mself\u001b[39m, inputs, losses):\n\u001b[1;32m 50\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[0;32m---> 51\u001b[0m outputs, losses \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhook(inputs, losses)\n\u001b[1;32m 52\u001b[0m combined \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39massert_dicts_are_disjoint(inputs, outputs)\n\u001b[1;32m 53\u001b[0m new_outputs, losses \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreducer(combined, losses)\n",
  303. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  304. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n",
  305. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  306. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:318\u001b[0m, in \u001b[0;36mAssertHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcall\u001b[39m(\u001b[39mself\u001b[39m, inputs, losses):\n\u001b[1;32m 317\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[0;32m--> 318\u001b[0m outputs, losses \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhook(inputs, losses)\n\u001b[1;32m 319\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39massert_fn(outputs)\n\u001b[1;32m 320\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, losses\n",
  307. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  308. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:220\u001b[0m, in \u001b[0;36mOnlyNewOutputsHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcall\u001b[39m(\u001b[39mself\u001b[39m, inputs, losses):\n\u001b[1;32m 219\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[0;32m--> 220\u001b[0m outputs, losses \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mhook(inputs, losses)\n\u001b[1;32m 221\u001b[0m outputs \u001b[39m=\u001b[39m {k: outputs[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m (outputs\u001b[39m.\u001b[39mkeys() \u001b[39m-\u001b[39m inputs\u001b[39m.\u001b[39mkeys())}\n\u001b[1;32m 222\u001b[0m c_f\u001b[39m.\u001b[39massert_dicts_are_disjoint(inputs, outputs)\n",
  309. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  310. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:109\u001b[0m, in \u001b[0;36mChainHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 107\u001b[0m all_losses \u001b[39m=\u001b[39m {\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mall_losses, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mprev_losses}\n\u001b[1;32m 108\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditions[i](all_inputs, all_losses):\n\u001b[0;32m--> 109\u001b[0m x \u001b[39m=\u001b[39m h(all_inputs, all_losses)\n\u001b[1;32m 110\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 111\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malts[i](all_inputs, all_losses)\n",
  311. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/base.py:52\u001b[0m, in \u001b[0;36mBaseHook.__call__\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 51\u001b[0m inputs \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mmap_keys(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkey_map)\n\u001b[0;32m---> 52\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcall(inputs, losses)\n\u001b[1;32m 53\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(x, (\u001b[39mbool\u001b[39m, np\u001b[39m.\u001b[39mbool_)):\n\u001b[1;32m 54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39mreset()\n",
  312. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:249\u001b[0m, in \u001b[0;36mApplyFnHook.call\u001b[0;34m(self, inputs, losses)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[1;32m 248\u001b[0m x \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mextract(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapply_to)\n\u001b[0;32m--> 249\u001b[0m outputs \u001b[39m=\u001b[39m {k: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfn(v) \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapply_to, x)}\n\u001b[1;32m 250\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_loss:\n\u001b[1;32m 251\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, {}\n",
  313. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/hooks/utils.py:249\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[1;32m 248\u001b[0m x \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mextract(inputs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapply_to)\n\u001b[0;32m--> 249\u001b[0m outputs \u001b[39m=\u001b[39m {k: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfn(v) \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapply_to, x)}\n\u001b[1;32m 250\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_loss:\n\u001b[1;32m 251\u001b[0m \u001b[39mreturn\u001b[39;00m outputs, {}\n",
  314. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/modules/module.py:889\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_slow_forward(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 888\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 889\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 890\u001b[0m \u001b[39mfor\u001b[39;00m hook \u001b[39min\u001b[39;00m itertools\u001b[39m.\u001b[39mchain(\n\u001b[1;32m 891\u001b[0m _global_forward_hooks\u001b[39m.\u001b[39mvalues(),\n\u001b[1;32m 892\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks\u001b[39m.\u001b[39mvalues()):\n\u001b[1;32m 893\u001b[0m hook_result \u001b[39m=\u001b[39m hook(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m, result)\n",
  315. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/layers/gradient_reversal.py:31\u001b[0m, in \u001b[0;36mGradientReversal.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[1;32m 30\u001b[0m \u001b[39m\"\"\"\"\"\"\u001b[39;00m\n\u001b[0;32m---> 31\u001b[0m \u001b[39mreturn\u001b[39;00m _GradientReversal\u001b[39m.\u001b[39mapply(x, pml_cf\u001b[39m.\u001b[39;49mto_device(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, x))\n",
  316. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_metric_learning/utils/common_functions.py:492\u001b[0m, in \u001b[0;36mto_device\u001b[0;34m(x, tensor, device, dtype)\u001b[0m\n\u001b[1;32m 490\u001b[0m dv \u001b[39m=\u001b[39m device \u001b[39mif\u001b[39;00m device \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m tensor\u001b[39m.\u001b[39mdevice\n\u001b[1;32m 491\u001b[0m \u001b[39mif\u001b[39;00m x\u001b[39m.\u001b[39mdevice \u001b[39m!=\u001b[39m dv:\n\u001b[0;32m--> 492\u001b[0m x \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39;49mto(dv)\n\u001b[1;32m 493\u001b[0m \u001b[39mif\u001b[39;00m dtype \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 494\u001b[0m x \u001b[39m=\u001b[39m to_dtype(x, dtype\u001b[39m=\u001b[39mdtype)\n",
  317. "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
  318. ]
  319. },
  320. {
  321. "ename": "",
  322. "evalue": "",
  323. "output_type": "error",
  324. "traceback": [
  325. "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
  326. ]
  327. }
  328. ],
  329. "source": [
  330. "torch.cuda.empty_cache()\n",
  331. "\n",
  332. "from tqdm import tqdm\n",
  333. "from pytorch_adapt.utils.common_functions import batch_to_device\n",
  334. "from collections import defaultdict\n",
  335. "\n",
  336. "\n",
  337. "def get_preds_and_labels(split, full_key):\n",
  338. "    \"\"\"Run G then C over dataloaders[full_key] without gradients; return concatenated logits, softmax preds and (labelled loaders only) labels.\"\"\"\n",
  339. "    output = defaultdict(list)\n",
  340. "    with torch.no_grad():\n",
  341. "        for data in tqdm(dataloaders[full_key]):\n",
  342. "            data = batch_to_device(data, device)\n",
  343. "            logits = C(G(data[f\"{split}_imgs\"]))\n",
  344. "            output[\"logits\"].append(logits)\n",
  345. "            curr_preds = torch.softmax(logits, dim=1)\n",
  346. "            output[\"preds\"].append(curr_preds)\n",
  347. "            if full_key in [\"src_train\", \"src_val\", \"target_train_with_labels\", \"target_val_with_labels\"]:\n",
  348. "                output[\"labels\"].append(data[f\"{split}_labels\"])\n",
  349. "    for k, v in output.items():\n",
  350. "        output[k] = torch.cat(v, dim=0)\n",
  351. "    return output\n",
  352. "\n",
  353. "for epoch in range(100):\n",
  354. "    # train loop\n",
  355. "    models.train()\n",
  356. "    for data in tqdm(dataloaders[\"train\"]):\n",
  357. "        data = batch_to_device(data, device)\n",
  358. "        _, loss = hook({**models, **data})\n",
  359. "    pprint(loss)  # losses of the last batch of this epoch\n",
  360. "\n",
  361. "    # eval loop: switch to eval mode so BatchNorm uses running stats and Dropout is disabled\n",
  362. "    models.eval()\n",
  363. "    scores = []\n",
  364. "    # src_outputs = get_preds_and_labels(\"src\", \"src_train\")\n",
  365. "    # score = src_train_validator(src_train=src_outputs)\n",
  366. "    # scores.append(f\"src_train accuracy \\t= {score}\")\n",
  367. "    src_outputs = get_preds_and_labels(\"src\", \"src_val\")\n",
  368. "    score = src_val_validator(src_val=src_outputs)\n",
  369. "    scores.append(f\"src_val accuracy \\t= {score}\")\n",
  370. "\n",
  371. "    # target_outputs = get_preds_and_labels(\"target\", \"target_train_with_labels\")\n",
  372. "    # score = target_train_oracle_validator(target_train=target_outputs)\n",
  373. "    # scores.append(f\"target_train accuracy \\t= {score}\")\n",
  374. "\n",
  375. "    target_outputs = get_preds_and_labels(\"target\", \"target_val_with_labels\")\n",
  376. "    score = target_val_oracle_validator(target_val=target_outputs)\n",
  377. "    scores.append(f\"target_val accuracy \\t= {score}\")  # oracle accuracy on the labelled target *val* split\n",
  378. "\n",
  379. "    target_outputs = get_preds_and_labels(\"target\", \"target_train\")\n",
  380. "    score = targen_im_validator(target_train={\"logits\": target_outputs['logits']})\n",
  381. "    scores.append(f\"target_train score (IM)\\t= {score}\")\n",
  382. "\n",
  383. "    # score = target_validator(target_train=target_outputs)\n",
  384. "    # scores.append(f\"target_train score (BNM) \\t= {score}\")\n",
  385. "\n",
  386. "    print(\"Target Evaluation:\")\n",
  387. "    print(*scores, sep=\"\\n\")"
  388. ]
  389. },
  390. {
  391. "cell_type": "code",
  392. "execution_count": null,
  393. "metadata": {},
  394. "outputs": [],
  395. "source": [
  396. "torch.cuda.empty_cache()  # release unused cached GPU memory held by PyTorch's caching allocator"
  397. ]
  398. },
  399. {
  400. "attachments": {},
  401. "cell_type": "markdown",
  402. "metadata": {},
  403. "source": [
  404. "----"
  405. ]
  406. },
  407. {
  408. "attachments": {},
  409. "cell_type": "markdown",
  410. "metadata": {},
  411. "source": [
  412. "## Dataset Visualization"
  413. ]
  414. },
  415. {
  416. "cell_type": "code",
  417. "execution_count": null,
  418. "metadata": {},
  419. "outputs": [
  420. {
  421. "name": "stdout",
  422. "output_type": "stream",
  423. "text": [
  424. "['amazon']\n"
  425. ]
  426. },
  427. {
  428. "data": {
  429. "image/png": "",
  430. "text/plain": [
  431. "<Figure size 800x400 with 1 Axes>"
  432. ]
  433. },
  434. "metadata": {},
  435. "output_type": "display_data"
  436. },
  437. {
  438. "name": "stdout",
  439. "output_type": "stream",
  440. "text": [
  441. "['dslr']\n"
  442. ]
  443. },
  444. {
  445. "data": {
  446. "image/png": "",
  447. "text/plain": [
  448. "<Figure size 800x400 with 1 Axes>"
  449. ]
  450. },
  451. "metadata": {},
  452. "output_type": "display_data"
  453. },
  454. {
  455. "name": "stdout",
  456. "output_type": "stream",
  457. "text": [
  458. "['webcam']\n"
  459. ]
  460. },
  461. {
  462. "data": {
  463. "image/png": "",
  464. "text/plain": [
  465. "<Figure size 800x400 with 1 Axes>"
  466. ]
  467. },
  468. "metadata": {},
  469. "output_type": "display_data"
  470. }
  471. ],
  472. "source": [
  473. "import matplotlib.pyplot as plt\n",
  474. "import numpy as np\n",
  475. "import torchvision\n",
  476. "\n",
  477. "mean = [0.485, 0.456, 0.406]  # presumably the normalization stats used by the dataset transforms -- confirm\n",
  478. "std = [0.229, 0.224, 0.225]\n",
  479. "viz_rng = np.random.default_rng(0)  # fixed seed so the sampled image grids are reproducible\n",
  480. "inv_normalize = torchvision.transforms.Normalize(\n",
  481. "    mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std]\n",
  482. ")\n",
  483. "\n",
  484. "def imshow(img, figsize=(8, 4)):  # display one normalized CHW image tensor\n",
  485. "    img = inv_normalize(img).clamp(0, 1)  # undo normalization; clamp to imshow's valid float range\n",
  486. "    npimg = img.numpy()\n",
  487. "    plt.figure(figsize=figsize)\n",
  488. "    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib\n",
  489. "    plt.show()\n",
  490. "\n",
  491. "def imshow_many(datasets, src, target):  # show a grid of 16 random src/target images from datasets[\"train\"]\n",
  492. "    d = datasets[\"train\"]\n",
  493. "    for name in [\"src_imgs\", \"target_imgs\"]:\n",
  494. "        domains = src if name == \"src_imgs\" else target\n",
  495. "        if len(domains) == 0:\n",
  496. "            continue\n",
  497. "        print(domains)\n",
  498. "        imgs = [d[i][name] for i in viz_rng.choice(len(d), size=16, replace=False)]\n",
  499. "        imshow(torchvision.utils.make_grid(imgs))\n",
  500. "\n",
  501. "for src, target in [([\"amazon\"], [\"dslr\"]), ([\"webcam\"], [])]:\n",
  502. "    viz_datasets = get_office31(src, target, folder=root)  # separate name: do not clobber the training `datasets`\n",
  503. "    imshow_many(viz_datasets, src, target)"
  504. ]
  505. },
  506. {
  507. "attachments": {},
  508. "cell_type": "markdown",
  509. "metadata": {},
  510. "source": [
  511. "---"
  512. ]
  513. },
  514. {
  515. "attachments": {},
  516. "cell_type": "markdown",
  517. "metadata": {},
  518. "source": [
  519. "## Training DANN / MCD with the Ignite framework"
  520. ]
  521. },
  522. {
  523. "cell_type": "code",
  524. "execution_count": 4,
  525. "metadata": {},
  526. "outputs": [],
  527. "source": [
  528. "import logging\n",
  529. "\n",
  530. "import matplotlib.pyplot as plt\n",
  531. "import pandas as pd\n",
  532. "import seaborn as sns\n",
  533. "import torch\n",
  534. "import umap\n",
  535. "from tqdm import tqdm\n",
  536. "import os\n",
  537. "\n",
  538. "from pytorch_adapt.adapters import DANN, MCD\n",
  539. "from pytorch_adapt.containers import Models, Optimizers, LRSchedulers\n",
  540. "from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, get_office31\n",
  541. "from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n",
  542. "from pytorch_adapt.models import Discriminator, mnistC, mnistG, office31C, office31G\n",
  543. "from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory\n",
  544. "\n",
  545. "from pprint import pprint\n",
  546. "\n",
  547. "\n",
  548. "logging.basicConfig()\n",
  549. "logging.getLogger(\"pytorch-adapt\").setLevel(logging.INFO)"
  550. ]
  551. },
  552. {
  553. "cell_type": "code",
  554. "execution_count": 5,
  555. "metadata": {},
  556. "outputs": [],
  557. "source": [
  558. "class VizHook:\n",
  559. " def __init__(self):\n",
  560. " self.required_data = [\"src_val\", \"target_val\", \"target_val_with_labels\"]\n",
  561. "\n",
  562. " def __call__(self, epoch, src_val, target_val, target_val_with_labels, **kwargs):\n",
  563. "\n",
  564. " accuracy_validator = AccuracyValidator()\n",
  565. " accuracy = accuracy_validator.compute_score(src_val=src_val)\n",
  566. " print(\"src_val accuracy:\", accuracy)\n",
  567. " accuracy_validator = AccuracyValidator()\n",
  568. " accuracy = accuracy_validator.compute_score(src_val=target_val_with_labels)\n",
  569. " print(\"target_val accuracy:\", accuracy)\n",
  570. "\n",
  571. " if epoch % 1 != 0:\n",
  572. " return\n",
  573. "\n",
  574. " features = [src_val[\"features\"], target_val[\"features\"]]\n",
  575. " domain = [src_val[\"domain\"], target_val[\"domain\"]]\n",
  576. " features = torch.cat(features, dim=0).cpu().numpy()\n",
  577. " domain = torch.cat(domain, dim=0).cpu().numpy()\n",
  578. " emb = umap.UMAP().fit_transform(features)\n",
  579. "\n",
  580. " df = pd.DataFrame(emb).assign(domain=domain)\n",
  581. " df[\"domain\"] = df[\"domain\"].replace({0: \"Source\", 1: \"Target\"})\n",
  582. " sns.set_theme(style=\"white\", rc={\"figure.figsize\": (6, 4)})\n",
  583. " sns.scatterplot(data=df, x=0, y=1, hue=\"domain\", s=15)\n",
  584. " plt.savefig(f\"results/vishook/dann/val_{epoch}.png\") \n",
  585. " plt.show()\n",
  586. " plt.close('all')"
  587. ]
  588. },
  589. {
  590. "cell_type": "code",
  591. "execution_count": 6,
  592. "metadata": {},
  593. "outputs": [],
  594. "source": [
  595. "root=\"datasets/pytorch-adapt/\"  # folder passed to get_office31 (pretrained weights are later read from root/weights)\n",
  596. "batch_size=32\n",
  597. "num_workers=2\n",
  598. "# Office31 with amazon as the source domain and webcam as the target domain.\n",
  599. "datasets = get_office31([\"amazon\"], [\"webcam\"], folder=root, return_target_with_labels=True)  # also returns the target_*_with_labels splits used for target accuracy\n",
  600. "dc = DataloaderCreator(batch_size=batch_size, \n",
  601. " num_workers=num_workers, \n",
  602. " train_names=[\"train\"],  # split holding paired src_imgs/target_imgs used for training\n",
  603. " val_names=[\"src_train\", \"target_train\", \"src_val\", \"target_val\", \"target_train_with_labels\", \"target_val_with_labels\"])  # evaluation splits; presumably built without shuffling - confirm in DataloaderCreator\n",
  604. "dataloaders = dc(**datasets)  # presumably a dict of DataLoaders keyed by split name"
  605. ]
  606. },
  607. {
  608. "cell_type": "code",
  609. "execution_count": 4,
  610. "metadata": {},
  611. "outputs": [
  612. {
  613. "name": "stderr",
  614. "output_type": "stream",
  615. "text": [
  616. "2021-05-21 13:49:42,518 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n",
  617. "2021-05-21 13:49:42,519 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n",
  618. "2021-05-21 13:49:42,520 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n"
  619. ]
  620. }
  621. ],
  622. "source": [
  623. "device = torch.device(\"cuda\")  # requires a CUDA GPU; there is no CPU fallback\n",
  624. "weights_root = os.path.join(root, \"weights\")  # model_dir for the pretrained G/C weights\n",
  625. "trained_domain = \"amazon\"  # selects which domain's pretrained classifier weights office31C loads (the source domain here)\n",
  626. "# G: pretrained feature extractor, C: pretrained classifier, D: domain discriminator built without pretrained weights.\n",
  627. "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
  628. "C = office31C(domain=trained_domain, pretrained=True, model_dir=weights_root).to(device)\n",
  629. "D = Discriminator(in_size=2048, h=1024).to(device)  # in_size is assumed to equal G's output feature size (2048) - TODO confirm\n",
  630. "\n",
  631. "models = Models({\"G\": G, \"C\": C, \"D\": D})\n",
  632. "# One (optimizer class, kwargs) pair - presumably instantiated for every model in `models`.\n",
  633. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.0005}))\n",
  634. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))  # ExponentialLR: lr is multiplied by 0.99 on each scheduler step\n",
  635. "\n",
  636. "adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  637. "checkpoint_fn = CheckpointFnCreator(dirname=\"saved_models\", require_empty=False)  # allow writing into an existing saved_models/ directory\n",
  638. "validator = ScoreHistory(IMValidator())  # label-free score that trainer.run uses to report best_score/best_epoch\n",
  639. "tarAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\":\"src_val\"})  # scores the labelled target split by mapping it onto the src_val input\n",
  640. "val_hooks = [ScoreHistory(AccuracyValidator()), ScoreHistory(tarAccuracyValidator), VizHook()]  # extra per-epoch monitoring: source/target accuracy and UMAP plots\n",
  641. "trainer = Ignite(\n",
  642. " adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn\n",
  643. ")"
  644. ]
  645. },
  646. {
  647. "cell_type": "code",
  648. "execution_count": 7,
  649. "metadata": {},
  650. "outputs": [
  651. {
  652. "name": "stderr",
  653. "output_type": "stream",
  654. "text": [
  655. "2021-05-21 14:16:07,148 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n",
  656. "2021-05-21 14:16:07,152 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n"
  657. ]
  658. }
  659. ],
  660. "source": [
  661. "from pytorch_adapt.layers import MultipleModels\n",
  662. "from pytorch_adapt.utils import common_functions \n",
  663. "import copy\n",
  664. "# Same G/C setup as the DANN cell, but MCD uses two classifier heads and no discriminator D.\n",
  665. "device = torch.device(\"cuda\")  # requires a CUDA GPU; there is no CPU fallback\n",
  666. "weights_root = os.path.join(root, \"weights\")\n",
  667. "trained_domain = \"amazon\"\n",
  668. "# NOTE(review): if the DANN cell ran earlier in this kernel, its D (and any other live objects) may still hold GPU memory - a possible contributor to the CUDA OOM below.\n",
  669. "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
  670. "C0 = office31C(domain=trained_domain, pretrained=True, model_dir=weights_root).to(device)\n",
  671. "C1 = common_functions.reinit(copy.deepcopy(C0))  # second head: a copy of C0, presumably with its parameters re-initialised by reinit\n",
  672. "C = MultipleModels(C0, C1)  # MCD inference (mcd_fn) sums the logits returned by both heads\n",
  673. "\n",
  674. "models = Models({\"G\": G, \"C\": C})  # no \"D\" entry, unlike the DANN cell\n",
  675. "\n",
  676. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 0.0005}))\n",
  677. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n",
  678. "\n",
  679. "adapter= MCD(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  680. "# adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  681. "checkpoint_fn = CheckpointFnCreator(dirname=\"saved_models\", require_empty=False)  # NOTE(review): same directory as the DANN run, so checkpoints may mix\n",
  682. "validator = ScoreHistory(IMValidator())\n",
  683. "tarAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\":\"src_val\"})\n",
  684. "val_hooks = [ScoreHistory(AccuracyValidator()), ScoreHistory(tarAccuracyValidator), VizHook()]  # NOTE(review): VizHook saves under results/vishook/dann/, so MCD plots overwrite the DANN ones\n",
  685. "trainer = Ignite(\n",
  686. " adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn\n",
  687. ")"
  688. ]
  689. },
  690. {
  691. "cell_type": "code",
  692. "execution_count": 8,
  693. "metadata": {},
  694. "outputs": [
  695. {
  696. "name": "stderr",
  697. "output_type": "stream",
  698. "text": [
  699. "WARNING:pytorch-adapt:val_hook has no state_dict or load_state_dict so it will not be saved or loaded\n"
  700. ]
  701. },
  702. {
  703. "ename": "RuntimeError",
  704. "evalue": "CUDA error: out of memory",
  705. "output_type": "error",
  706. "traceback": [
  707. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  708. "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
  709. "Cell \u001b[0;32mIn[8], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m logging\u001b[39m.\u001b[39mgetLogger(\u001b[39m\"\u001b[39m\u001b[39mpytorch-adapt\u001b[39m\u001b[39m\"\u001b[39m)\u001b[39m.\u001b[39msetLevel(logging\u001b[39m.\u001b[39mWARNING)\n\u001b[0;32m----> 3\u001b[0m best_score, best_epoch \u001b[39m=\u001b[39m trainer\u001b[39m.\u001b[39;49mrun(\n\u001b[1;32m 4\u001b[0m datasets, dataloader_creator\u001b[39m=\u001b[39;49mdc, max_epochs\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, check_initial_score\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m\n\u001b[1;32m 5\u001b[0m )\n\u001b[1;32m 6\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mbest_score=\u001b[39m\u001b[39m{\u001b[39;00mbest_score\u001b[39m}\u001b[39;00m\u001b[39m, best_epoch=\u001b[39m\u001b[39m{\u001b[39;00mbest_epoch\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
  710. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/ignite.py:233\u001b[0m, in \u001b[0;36mIgnite.run\u001b[0;34m(self, datasets, dataloader_creator, dataloaders, val_interval, early_stopper_kwargs, resume, check_initial_score, **trainer_kwargs)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mload_checkpoint(resume)\n\u001b[1;32m 232\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m i_g\u001b[39m.\u001b[39mis_done(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrainer, max_epochs):\n\u001b[0;32m--> 233\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrainer\u001b[39m.\u001b[39;49mrun(dataloaders[\u001b[39m\"\u001b[39;49m\u001b[39mtrain\u001b[39;49m\u001b[39m\"\u001b[39;49m], \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtrainer_kwargs)\n\u001b[1;32m 235\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mremove_temp_events()\n\u001b[1;32m 237\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mvalidator:\n",
  711. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:704\u001b[0m, in \u001b[0;36mEngine.run\u001b[0;34m(self, data, max_epochs, epoch_length, seed)\u001b[0m\n\u001b[1;32m 701\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mepoch_length should be provided if data is None\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 703\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mdataloader \u001b[39m=\u001b[39m data\n\u001b[0;32m--> 704\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_internal_run()\n",
  712. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:783\u001b[0m, in \u001b[0;36mEngine._internal_run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 781\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataloader_iter \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 782\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39merror(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEngine run is terminating due to exception: \u001b[39m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 783\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_handle_exception(e)\n\u001b[1;32m 785\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataloader_iter \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 786\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\n",
  713. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:466\u001b[0m, in \u001b[0;36mEngine._handle_exception\u001b[0;34m(self, e)\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fire_event(Events\u001b[39m.\u001b[39mEXCEPTION_RAISED, e)\n\u001b[1;32m 465\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 466\u001b[0m \u001b[39mraise\u001b[39;00m e\n",
  714. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:745\u001b[0m, in \u001b[0;36mEngine._internal_run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 743\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 744\u001b[0m start_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n\u001b[0;32m--> 745\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fire_event(Events\u001b[39m.\u001b[39;49mSTARTED)\n\u001b[1;32m 746\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mepoch \u001b[39m<\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmax_epochs \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshould_terminate: \u001b[39m# type: ignore[operator]\u001b[39;00m\n\u001b[1;32m 747\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mepoch \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n",
  715. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:421\u001b[0m, in \u001b[0;36mEngine._fire_event\u001b[0;34m(self, event_name, *event_args, **event_kwargs)\u001b[0m\n\u001b[1;32m 419\u001b[0m kwargs\u001b[39m.\u001b[39mupdate(event_kwargs)\n\u001b[1;32m 420\u001b[0m first, others \u001b[39m=\u001b[39m ((args[\u001b[39m0\u001b[39m],), args[\u001b[39m1\u001b[39m:]) \u001b[39mif\u001b[39;00m (args \u001b[39mand\u001b[39;00m args[\u001b[39m0\u001b[39m] \u001b[39m==\u001b[39m \u001b[39mself\u001b[39m) \u001b[39melse\u001b[39;00m ((), args)\n\u001b[0;32m--> 421\u001b[0m func(\u001b[39m*\u001b[39;49mfirst, \u001b[39m*\u001b[39;49m(event_args \u001b[39m+\u001b[39;49m others), \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
  716. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/utils.py:30\u001b[0m, in \u001b[0;36mget_validation_runner.<locals>.run_validation\u001b[0;34m(engine)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun_validation\u001b[39m(engine):\n\u001b[1;32m 29\u001b[0m epoch \u001b[39m=\u001b[39m engine\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mepoch\n\u001b[0;32m---> 30\u001b[0m collected_data \u001b[39m=\u001b[39m collect_from_dataloaders(collector, dataloaders, required_data)\n\u001b[1;32m 31\u001b[0m score \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[39mif\u001b[39;00m validator:\n",
  717. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/utils.py:53\u001b[0m, in \u001b[0;36mcollect_from_dataloaders\u001b[0;34m(collector, dataloaders, required_data)\u001b[0m\n\u001b[1;32m 51\u001b[0m curr_dataset \u001b[39m=\u001b[39m curr_dataloader\u001b[39m.\u001b[39mdataset\n\u001b[1;32m 52\u001b[0m iterable \u001b[39m=\u001b[39m curr_dataloader\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m---> 53\u001b[0m curr_collected \u001b[39m=\u001b[39m accumulate_collector_output(collector, iterable, k)\n\u001b[1;32m 54\u001b[0m c_f\u001b[39m.\u001b[39mval_collected_data_checks(curr_collected, curr_dataset)\n\u001b[1;32m 55\u001b[0m collected_data[k] \u001b[39m=\u001b[39m curr_collected\n",
  718. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/utils.py:98\u001b[0m, in \u001b[0;36maccumulate_collector_output\u001b[0;34m(collector, iterable, output_name)\u001b[0m\n\u001b[1;32m 96\u001b[0m accumulator \u001b[39m=\u001b[39m DictionaryAccumulator()\n\u001b[1;32m 97\u001b[0m accumulator\u001b[39m.\u001b[39mattach(collector, output_name)\n\u001b[0;32m---> 98\u001b[0m collector\u001b[39m.\u001b[39;49mrun(iterable)\n\u001b[1;32m 99\u001b[0m accumulator\u001b[39m.\u001b[39mdetach(collector)\n\u001b[1;32m 100\u001b[0m output \u001b[39m=\u001b[39m collector\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmetrics[output_name]\n",
  719. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:704\u001b[0m, in \u001b[0;36mEngine.run\u001b[0;34m(self, data, max_epochs, epoch_length, seed)\u001b[0m\n\u001b[1;32m 701\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mepoch_length should be provided if data is None\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 703\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mdataloader \u001b[39m=\u001b[39m data\n\u001b[0;32m--> 704\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_internal_run()\n",
  720. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:783\u001b[0m, in \u001b[0;36mEngine._internal_run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 781\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataloader_iter \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 782\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39merror(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEngine run is terminating due to exception: \u001b[39m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 783\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_handle_exception(e)\n\u001b[1;32m 785\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataloader_iter \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 786\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\n",
  721. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:466\u001b[0m, in \u001b[0;36mEngine._handle_exception\u001b[0;34m(self, e)\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fire_event(Events\u001b[39m.\u001b[39mEXCEPTION_RAISED, e)\n\u001b[1;32m 465\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 466\u001b[0m \u001b[39mraise\u001b[39;00m e\n",
  722. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:753\u001b[0m, in \u001b[0;36mEngine._internal_run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 750\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataloader_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 751\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_setup_engine()\n\u001b[0;32m--> 753\u001b[0m time_taken \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_once_on_dataset()\n\u001b[1;32m 754\u001b[0m \u001b[39m# time is available for handlers but must be update after fire\u001b[39;00m\n\u001b[1;32m 755\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mtimes[Events\u001b[39m.\u001b[39mEPOCH_COMPLETED\u001b[39m.\u001b[39mname] \u001b[39m=\u001b[39m time_taken\n",
  723. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:854\u001b[0m, in \u001b[0;36mEngine._run_once_on_dataset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 852\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 853\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlogger\u001b[39m.\u001b[39merror(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCurrent run is terminating due to exception: \u001b[39m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 854\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_handle_exception(e)\n\u001b[1;32m 856\u001b[0m \u001b[39mreturn\u001b[39;00m time\u001b[39m.\u001b[39mtime() \u001b[39m-\u001b[39m start_time\n",
  724. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:466\u001b[0m, in \u001b[0;36mEngine._handle_exception\u001b[0;34m(self, e)\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fire_event(Events\u001b[39m.\u001b[39mEXCEPTION_RAISED, e)\n\u001b[1;32m 465\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 466\u001b[0m \u001b[39mraise\u001b[39;00m e\n",
  725. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/engine/engine.py:840\u001b[0m, in \u001b[0;36mEngine._run_once_on_dataset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 838\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39miteration \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 839\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fire_event(Events\u001b[39m.\u001b[39mITERATION_STARTED)\n\u001b[0;32m--> 840\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moutput \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_function(\u001b[39mself\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate\u001b[39m.\u001b[39;49mbatch)\n\u001b[1;32m 841\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fire_event(Events\u001b[39m.\u001b[39mITERATION_COMPLETED)\n\u001b[1;32m 843\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshould_terminate \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshould_terminate_single_epoch:\n",
  726. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/ignite.py:332\u001b[0m, in \u001b[0;36mIgnite.get_collector_step.<locals>.collector_step\u001b[0;34m(engine, batch)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcollector_step\u001b[39m(engine, batch):\n\u001b[1;32m 331\u001b[0m batch \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mbatch_to_device(batch, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 332\u001b[0m \u001b[39mreturn\u001b[39;00m f_utils\u001b[39m.\u001b[39;49mcollector_step(inference, batch, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mval_output_dict_fn)\n",
  727. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/utils.py:33\u001b[0m, in \u001b[0;36mcollector_step\u001b[0;34m(inference, batch, output_dict_fn)\u001b[0m\n\u001b[1;32m 31\u001b[0m data \u001b[39m=\u001b[39m extract_data(batch)\n\u001b[1;32m 32\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[0;32m---> 33\u001b[0m f_dict \u001b[39m=\u001b[39m inference(data[\u001b[39m\"\u001b[39;49m\u001b[39mimgs\u001b[39;49m\u001b[39m\"\u001b[39;49m], domain\u001b[39m=\u001b[39;49mdata[\u001b[39m\"\u001b[39;49m\u001b[39mdomain\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n\u001b[1;32m 34\u001b[0m data\u001b[39m.\u001b[39mpop(\u001b[39m\"\u001b[39m\u001b[39mimgs\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39m# we don't want to collect imgs\u001b[39;00m\n\u001b[1;32m 35\u001b[0m f_dict \u001b[39m=\u001b[39m output_dict_fn(f_dict)\n",
  728. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/adapters/base_adapter.py:120\u001b[0m, in \u001b[0;36mBaseAdapter.inference\u001b[0;34m(self, x, domain)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minference\u001b[39m(\n\u001b[1;32m 110\u001b[0m \u001b[39mself\u001b[39m, x: torch\u001b[39m.\u001b[39mTensor, domain: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 111\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tuple[torch\u001b[39m.\u001b[39mTensor, torch\u001b[39m.\u001b[39mTensor]:\n\u001b[1;32m 112\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[39m Arguments:\u001b[39;00m\n\u001b[1;32m 114\u001b[0m \u001b[39m x: The input to the model\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39m Features and logits\u001b[39;00m\n\u001b[1;32m 119\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 120\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_fn(\n\u001b[1;32m 121\u001b[0m x\u001b[39m=\u001b[39;49mx,\n\u001b[1;32m 122\u001b[0m domain\u001b[39m=\u001b[39;49mdomain,\n\u001b[1;32m 123\u001b[0m models\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodels,\n\u001b[1;32m 124\u001b[0m misc\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmisc,\n\u001b[1;32m 125\u001b[0m )\n",
  729. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/inference/inference.py:145\u001b[0m, in \u001b[0;36mmcd_fn\u001b[0;34m(x, models, get_all, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mmcd_fn\u001b[39m(x, models, get_all\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 141\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 142\u001b[0m \u001b[39m Returns:\u001b[39;00m\n\u001b[1;32m 143\u001b[0m \u001b[39m Features and logits, where ```logits = sum(C(features))```.\u001b[39;00m\n\u001b[1;32m 144\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 145\u001b[0m features \u001b[39m=\u001b[39m models[\u001b[39m\"\u001b[39;49m\u001b[39mG\u001b[39;49m\u001b[39m\"\u001b[39;49m](x)\n\u001b[1;32m 146\u001b[0m logits_list \u001b[39m=\u001b[39m models[\u001b[39m\"\u001b[39m\u001b[39mC\u001b[39m\u001b[39m\"\u001b[39m](features)\n\u001b[1;32m 147\u001b[0m logits \u001b[39m=\u001b[39m \u001b[39msum\u001b[39m(logits_list)\n",
  730. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/modules/module.py:889\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_slow_forward(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 888\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 889\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 890\u001b[0m \u001b[39mfor\u001b[39;00m hook \u001b[39min\u001b[39;00m itertools\u001b[39m.\u001b[39mchain(\n\u001b[1;32m 891\u001b[0m _global_forward_hooks\u001b[39m.\u001b[39mvalues(),\n\u001b[1;32m 892\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks\u001b[39m.\u001b[39mvalues()):\n\u001b[1;32m 893\u001b[0m hook_result \u001b[39m=\u001b[39m hook(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m, result)\n",
  731. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:157\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mdevice \u001b[39m!=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msrc_device_obj:\n\u001b[1;32m 153\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mmodule must have its parameters and buffers \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 154\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mon device \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m (device_ids[0]) but found one of \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 155\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mthem on device: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msrc_device_obj, t\u001b[39m.\u001b[39mdevice))\n\u001b[0;32m--> 157\u001b[0m inputs, kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mscatter(inputs, kwargs, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdevice_ids)\n\u001b[1;32m 158\u001b[0m \u001b[39m# for forward function without any inputs, empty list and dict will be created\u001b[39;00m\n\u001b[1;32m 159\u001b[0m \u001b[39m# so the module can be executed on one device which is the first one in device_ids\u001b[39;00m\n\u001b[1;32m 160\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m inputs \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m kwargs:\n",
  732. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:174\u001b[0m, in \u001b[0;36mDataParallel.scatter\u001b[0;34m(self, inputs, kwargs, device_ids)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mscatter\u001b[39m(\u001b[39mself\u001b[39m, inputs, kwargs, device_ids):\n\u001b[0;32m--> 174\u001b[0m \u001b[39mreturn\u001b[39;00m scatter_kwargs(inputs, kwargs, device_ids, dim\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdim)\n",
  733. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:44\u001b[0m, in \u001b[0;36mscatter_kwargs\u001b[0;34m(inputs, kwargs, target_gpus, dim)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mscatter_kwargs\u001b[39m(inputs, kwargs, target_gpus, dim\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m):\n\u001b[1;32m 43\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Scatter with support for kwargs dictionary\"\"\"\u001b[39;00m\n\u001b[0;32m---> 44\u001b[0m inputs \u001b[39m=\u001b[39m scatter(inputs, target_gpus, dim) \u001b[39mif\u001b[39;00m inputs \u001b[39melse\u001b[39;00m []\n\u001b[1;32m 45\u001b[0m kwargs \u001b[39m=\u001b[39m scatter(kwargs, target_gpus, dim) \u001b[39mif\u001b[39;00m kwargs \u001b[39melse\u001b[39;00m []\n\u001b[1;32m 46\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(inputs) \u001b[39m<\u001b[39m \u001b[39mlen\u001b[39m(kwargs):\n",
  734. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:36\u001b[0m, in \u001b[0;36mscatter\u001b[0;34m(inputs, target_gpus, dim)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39m# After scatter_map is called, a scatter_map cell will exist. This cell\u001b[39;00m\n\u001b[1;32m 31\u001b[0m \u001b[39m# has a reference to the actual function scatter_map, which has references\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[39m# to a closure that has a reference to the scatter_map cell (because the\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[39m# fn is recursive). To avoid this reference cycle, we set the function to\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[39m# None, clearing the cell\u001b[39;00m\n\u001b[1;32m 35\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 36\u001b[0m res \u001b[39m=\u001b[39m scatter_map(inputs)\n\u001b[1;32m 37\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m 38\u001b[0m scatter_map \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
  735. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:23\u001b[0m, in \u001b[0;36mscatter.<locals>.scatter_map\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[39mreturn\u001b[39;00m [\u001b[39mtype\u001b[39m(obj)(\u001b[39m*\u001b[39margs) \u001b[39mfor\u001b[39;00m args \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39m*\u001b[39m\u001b[39mmap\u001b[39m(scatter_map, obj))]\n\u001b[1;32m 22\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(obj, \u001b[39mtuple\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(obj) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m---> 23\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39m(\u001b[39mzip\u001b[39;49m(\u001b[39m*\u001b[39;49m\u001b[39mmap\u001b[39;49m(scatter_map, obj)))\n\u001b[1;32m 24\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(obj, \u001b[39mlist\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(obj) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m 25\u001b[0m \u001b[39mreturn\u001b[39;00m [\u001b[39mlist\u001b[39m(i) \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39m*\u001b[39m\u001b[39mmap\u001b[39m(scatter_map, obj))]\n",
  736. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:19\u001b[0m, in \u001b[0;36mscatter.<locals>.scatter_map\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mscatter_map\u001b[39m(obj):\n\u001b[1;32m 18\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(obj, torch\u001b[39m.\u001b[39mTensor):\n\u001b[0;32m---> 19\u001b[0m \u001b[39mreturn\u001b[39;00m Scatter\u001b[39m.\u001b[39;49mapply(target_gpus, \u001b[39mNone\u001b[39;49;00m, dim, obj)\n\u001b[1;32m 20\u001b[0m \u001b[39mif\u001b[39;00m is_namedtuple(obj):\n\u001b[1;32m 21\u001b[0m \u001b[39mreturn\u001b[39;00m [\u001b[39mtype\u001b[39m(obj)(\u001b[39m*\u001b[39margs) \u001b[39mfor\u001b[39;00m args \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39m*\u001b[39m\u001b[39mmap\u001b[39m(scatter_map, obj))]\n",
  737. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:93\u001b[0m, in \u001b[0;36mScatter.forward\u001b[0;34m(ctx, target_gpus, chunk_sizes, dim, input)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mis_available() \u001b[39mand\u001b[39;00m ctx\u001b[39m.\u001b[39minput_device \u001b[39m==\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n\u001b[1;32m 91\u001b[0m \u001b[39m# Perform CPU to GPU copies in a background stream\u001b[39;00m\n\u001b[1;32m 92\u001b[0m streams \u001b[39m=\u001b[39m [_get_stream(device) \u001b[39mfor\u001b[39;00m device \u001b[39min\u001b[39;00m target_gpus]\n\u001b[0;32m---> 93\u001b[0m outputs \u001b[39m=\u001b[39m comm\u001b[39m.\u001b[39;49mscatter(\u001b[39minput\u001b[39;49m, target_gpus, chunk_sizes, ctx\u001b[39m.\u001b[39;49mdim, streams)\n\u001b[1;32m 94\u001b[0m \u001b[39m# Synchronize with the copy stream\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[39mif\u001b[39;00m streams \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
  738. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/parallel/comm.py:189\u001b[0m, in \u001b[0;36mscatter\u001b[0;34m(tensor, devices, chunk_sizes, dim, streams, out)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[39mif\u001b[39;00m out \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 188\u001b[0m devices \u001b[39m=\u001b[39m [_get_device_index(d) \u001b[39mfor\u001b[39;00m d \u001b[39min\u001b[39;00m devices]\n\u001b[0;32m--> 189\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mtuple\u001b[39m(torch\u001b[39m.\u001b[39;49m_C\u001b[39m.\u001b[39;49m_scatter(tensor, devices, chunk_sizes, dim, streams))\n\u001b[1;32m 190\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 191\u001b[0m \u001b[39mif\u001b[39;00m devices \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
  739. "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory"
  740. ]
  741. }
  742. ],
  743. "source": [
  744. "logging.getLogger(\"pytorch-adapt\").setLevel(logging.WARNING)\n",
  745. "\n",
  746. "early_stopper_kwargs = {\"patience\":2}\n",
  747. "\n",
  748. "\n",
  749. "best_score, best_epoch = trainer.run(\n",
  750. " datasets, dataloader_creator=dc, max_epochs=5, check_initial_score=True, early_stopper_kwargs=early_stopper_kwargs\n",
  751. ")\n",
  752. "print(f\"best_score={best_score}, best_epoch={best_epoch}\")"
  753. ]
  754. },
  755. {
  756. "cell_type": "code",
  757. "execution_count": 10,
  758. "metadata": {},
  759. "outputs": [
  760. {
  761. "data": {
  762. "image/png": "",
  763. "text/plain": [
  764. "<Figure size 600x400 with 1 Axes>"
  765. ]
  766. },
  767. "metadata": {},
  768. "output_type": "display_data"
  769. },
  770. {
  771. "data": {
  772. "image/png": "",
  773. "text/plain": [
  774. "<Figure size 600x400 with 1 Axes>"
  775. ]
  776. },
  777. "metadata": {},
  778. "output_type": "display_data"
  779. },
  780. {
  781. "ename": "",
  782. "evalue": "",
  783. "output_type": "error",
  784. "traceback": [
  785. "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
  786. ]
  787. }
  788. ],
  789. "source": [
  790. "plt.plot(validator.score_history[1:])\n",
  791. "plt.title(\"score_history\")\n",
  792. "plt.show()\n",
  793. "\n",
  794. "plt.plot(val_hooks[0].score_history[1:], label='source')\n",
  795. "plt.plot(val_hooks[1].score_history[1:], label='target')\n",
  796. "plt.legend()\n",
  797. "plt.title(\"src, target val accuracy\")\n",
  798. "plt.show()"
  799. ]
  800. },
  801. {
  802. "cell_type": "code",
  803. "execution_count": 6,
  804. "metadata": {},
  805. "outputs": [
  806. {
  807. "name": "stdout",
  808. "output_type": "stream",
  809. "text": [
  810. "validator score history [2.85106766 1.07307923 1.06300926 1.43146741 1.88468337 1.96975768\n",
  811. " 2.14979064 2.62666798 2.61236888 2.4836719 2.59688711 2.05844843\n",
  812. " 2.02083945 2.18156433 2.30989116 2.43664771 2.50495136 2.57028472\n",
  813. " 2.55985087 2.59315497 2.65624613]\n",
  814. "src_val accuracy history [0.90425533 0.20921986 0.40602836 0.52482271 0.49822694 0.5673759\n",
  815. " 0.57624114 0.59042555 0.59751773 0.57978725 0.63297874 0.57446808\n",
  816. " 0.65602839 0.66312057 0.66489363 0.66666669 0.65957445 0.65957445\n",
  817. " 0.67730498 0.67375886 0.66666669]\n"
  818. ]
  819. }
  820. ],
  821. "source": []
  822. },
  823. {
  824. "cell_type": "code",
  825. "execution_count": 7,
  826. "metadata": {},
  827. "outputs": [
  828. {
  829. "name": "stderr",
  830. "output_type": "stream",
  831. "text": [
  832. "INFO:pytorch-adapt:***EVALUATING BEST MODEL***\n",
  833. "INFO:pytorch-adapt:Collecting target_val_with_labels\n",
  834. "INFO:pytorch-adapt:Setting models to eval() mode\n"
  835. ]
  836. },
  837. {
  838. "data": {
  839. "application/vnd.jupyter.widget-view+json": {
  840. "model_id": "4980a70229194e1eb8fee18c5af651c1",
  841. "version_major": 2,
  842. "version_minor": 0
  843. },
  844. "text/plain": [
  845. "[1/5] 20%|## |it [00:00<?]"
  846. ]
  847. },
  848. "metadata": {},
  849. "output_type": "display_data"
  850. },
  851. {
  852. "name": "stdout",
  853. "output_type": "stream",
  854. "text": [
  855. "0.48427674174308777\n"
  856. ]
  857. }
  858. ],
  859. "source": [
  860. "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  861. "score = trainer.evaluate_best_model(datasets, validator, dc)\n",
  862. "print(score)"
  863. ]
  864. },
  865. {
  866. "cell_type": "code",
  867. "execution_count": 8,
  868. "metadata": {},
  869. "outputs": [
  870. {
  871. "name": "stderr",
  872. "output_type": "stream",
  873. "text": [
  874. "INFO:pytorch-adapt:***EVALUATING BEST MODEL***\n",
  875. "INFO:pytorch-adapt:Collecting src_val\n",
  876. "INFO:pytorch-adapt:Setting models to eval() mode\n"
  877. ]
  878. },
  879. {
  880. "data": {
  881. "application/vnd.jupyter.widget-view+json": {
  882. "model_id": "c372abf93f1e4808b286822e23a49261",
  883. "version_major": 2,
  884. "version_minor": 0
  885. },
  886. "text/plain": [
  887. "[1/18] 6%|5 |it [00:00<?]"
  888. ]
  889. },
  890. "metadata": {},
  891. "output_type": "display_data"
  892. },
  893. {
  894. "name": "stdout",
  895. "output_type": "stream",
  896. "text": [
  897. "0.6666666865348816\n"
  898. ]
  899. }
  900. ],
  901. "source": [
  902. "validator = AccuracyValidator(key_map={\"src_val\": \"src_val\"})\n",
  903. "score = trainer.evaluate_best_model(datasets, validator, dc)\n",
  904. "print(score)"
  905. ]
  906. },
  907. {
  908. "cell_type": "code",
  909. "execution_count": null,
  910. "metadata": {},
  911. "outputs": [],
  912. "source": []
  913. },
  914. {
  915. "attachments": {},
  916. "cell_type": "markdown",
  917. "metadata": {},
  918. "source": [
  919. "### LOAD"
  920. ]
  921. },
  922. {
  923. "cell_type": "code",
  924. "execution_count": 6,
  925. "metadata": {},
  926. "outputs": [],
  927. "source": [
  928. "path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CORAL/0/w2d/saved_models/checkpointer_49.pt\"\n",
  929. "base_path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.CORAL/0/w2d/saved_models/\"\n",
  930. "checkpoint = torch.load(path)"
  931. ]
  932. },
  933. {
  934. "cell_type": "code",
  935. "execution_count": 22,
  936. "metadata": {},
  937. "outputs": [],
  938. "source": [
  939. "base_path = \"/media/10TB71/shashemi/Domain-Adaptation/results/DAModels.DANN/0/w2a/saved_models/\""
  940. ]
  941. },
  942. {
  943. "cell_type": "code",
  944. "execution_count": 23,
  945. "metadata": {},
  946. "outputs": [
  947. {
  948. "name": "stderr",
  949. "output_type": "stream",
  950. "text": [
  951. "2021-05-25 17:32:04,308 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n",
  952. "2021-05-25 17:32:04,309 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n",
  953. "2021-05-25 17:32:04,311 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model\n"
  954. ]
  955. }
  956. ],
  957. "source": [
  958. "\n",
  959. "import matplotlib.pyplot as plt\n",
  960. "import torch\n",
  961. "import os\n",
  962. "import gc\n",
  963. "from datetime import datetime\n",
  964. "\n",
  965. "from pytorch_adapt.datasets import DataloaderCreator, get_office31\n",
  966. "from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite\n",
  967. "from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory, DiversityValidator, EntropyValidator, MultipleValidators\n",
  968. "\n",
  969. "from models import get_model\n",
  970. "from utils import DAModels\n",
  971. "\n",
  972. "from pytorch_adapt.frameworks.ignite import (\n",
  973. " CheckpointFnCreator,\n",
  974. " IgniteValHookWrapper,\n",
  975. " checkpoint_utils,\n",
  976. ")\n",
  977. "\n",
  978. "checkpoint_fn = CheckpointFnCreator(dirname=base_path, require_empty=False, n_saved=None)\n",
  979. " \n",
  980. " \n",
  981. "sourceAccuracyValidator = AccuracyValidator()\n",
  982. "validators = {\n",
  983. " \"entropy\": EntropyValidator(),\n",
  984. " \"diversity\": DiversityValidator(),\n",
  985. " # \"accuracy\": sourceAccuracyValidator,\n",
  986. "}\n",
  987. "validator = ScoreHistory(MultipleValidators(validators))\n",
  988. "\n",
  989. "\n",
  990. "targetAccuracyValidator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  991. "\n",
  992. "\n",
  993. "val_hooks = [ScoreHistory(sourceAccuracyValidator), \n",
  994. " ScoreHistory(targetAccuracyValidator)]\n",
  995. "\n",
  996. "source_domain =\"webcam\"\n",
  997. "target_domain = \"amazon\"\n",
  998. "\n",
  999. "G = office31G(pretrained=True, model_dir=weights_root).to(device)\n",
  1000. "C = office31C(domain=source_domain, pretrained=True,\n",
  1001. " model_dir=weights_root).to(device)\n",
  1002. "D = Discriminator(in_size=2048, h=1024).to(device)\n",
  1003. "\n",
  1004. "optimizers = Optimizers((torch.optim.Adam, {\"lr\": 1e-4}))\n",
  1005. "lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {\"gamma\": 0.99}))\n",
  1006. "\n",
  1007. "\n",
  1008. "models = Models({\"G\": G, \"C\": C, \"D\": D})\n",
  1009. "adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)\n",
  1010. "\n",
  1011. "\n",
  1012. "trainer = Ignite(\n",
  1013. " adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn\n",
  1014. ")\n"
  1015. ]
  1016. },
  1017. {
  1018. "cell_type": "code",
  1019. "execution_count": 37,
  1020. "metadata": {},
  1021. "outputs": [
  1022. {
  1023. "ename": "RuntimeError",
  1024. "evalue": "Error(s) in loading state_dict for DataParallel:\n\tMissing key(s) in state_dict: \"module.conv1.weight\", \"module.bn1.weight\", \"module.bn1.bias\", \"module.bn1.running_mean\", \"module.bn1.running_var\", \"module.layer1.0.conv1.weight\", \"module.layer1.0.bn1.weight\", \"module.layer1.0.bn1.bias\", \"module.layer1.0.bn1.running_mean\", \"module.layer1.0.bn1.running_var\", \"module.layer1.0.conv2.weight\", \"module.layer1.0.bn2.weight\", \"module.layer1.0.bn2.bias\", \"module.layer1.0.bn2.running_mean\", \"module.layer1.0.bn2.running_var\", \"module.layer1.0.conv3.weight\", \"module.layer1.0.bn3.weight\", \"module.layer1.0.bn3.bias\", \"module.layer1.0.bn3.running_mean\", \"module.layer1.0.bn3.running_var\", \"module.layer1.0.downsample.0.weight\", \"module.layer1.0.downsample.1.weight\", \"module.layer1.0.downsample.1.bias\", \"module.layer1.0.downsample.1.running_mean\", \"module.layer1.0.downsample.1.running_var\", \"module.layer1.1.conv1.weight\", \"module.layer1.1.bn1.weight\", \"module.layer1.1.bn1.bias\", \"module.layer1.1.bn1.running_mean\", \"module.layer1.1.bn1.running_var\", \"module.layer1.1.conv2.weight\", \"module.layer1.1.bn2.weight\", \"module.layer1.1.bn2.bias\", \"module.layer1.1.bn2.running_mean\", \"module.layer1.1.bn2.running_var\", \"module.layer1.1.conv3.weight\", \"module.layer1.1.bn3.weight\", \"module.layer1.1.bn3.bias\", \"module.layer1.1.bn3.running_mean\", \"module.layer1.1.bn3.running_var\", \"module.layer1.2.conv1.weight\", \"module.layer1.2.bn1.weight\", \"module.layer1.2.bn1.bias\", \"module.layer1.2.bn1.running_mean\", \"module.layer1.2.bn1.running_var\", \"module.layer1.2.conv2.weight\", \"module.layer1.2.bn2.weight\", \"module.layer1.2.bn2.bias\", \"module.layer1.2.bn2.running_mean\", \"module.layer1.2.bn2.running_var\", \"module.layer1.2.conv3.weight\", \"module.layer1.2.bn3.weight\", \"module.layer1.2.bn3.bias\", \"module.layer1.2.bn3.running_mean\", \"module.layer1.2.bn3.running_var\", 
\"module.layer2.0.conv1.weight\", \"module.layer2.0.bn1.weight\", \"module.layer2.0.bn1.bias\", \"module.layer2.0.bn1.running_mean\", \"module.layer2.0.bn1.running_var\", \"module.layer2.0.conv2.weight\", \"module.layer2.0.bn2.weight\", \"module.layer2.0.bn2.bias\", \"module.layer2.0.bn2.running_mean\", \"module.layer2.0.bn2.running_var\", \"module.layer2.0.conv3.weight\", \"module.layer2.0.bn3.weight\", \"module.layer2.0.bn3.bias\", \"module.layer2.0.bn3.running_mean\", \"module.layer2.0.bn3.running_var\", \"module.layer2.0.downsample.0.weight\", \"module.layer2.0.downsample.1.weight\", \"module.layer2.0.downsample.1.bias\", \"module.layer2.0.downsample.1.running_mean\", \"module.layer2.0.downsample.1.running_var\", \"module.layer2.1.conv1.weight\", \"module.layer2.1.bn1.weight\", \"module.layer2.1.bn1.bias\", \"module.layer2.1.bn1.running_mean\", \"module.layer2.1.bn1.running_var\", \"module.layer2.1.conv2.weight\", \"module.layer2.1.bn2.weight\", \"module.layer2.1.bn2.bias\", \"module.layer2.1.bn2.running_mean\", \"module.layer2.1.bn2.running_var\", \"module.layer2.1.conv3.weight\", \"module.layer2.1.bn3.weight\", \"module.layer2.1.bn3.bias\", \"module.layer2.1.bn3.running_mean\", \"module.layer2.1.bn3.running_var\", \"module.layer2.2.conv1.weight\", \"module.layer2.2.bn1.weight\", \"module.layer2.2.bn1.bias\", \"module.layer2.2.bn1.running_mean\", \"module.layer2.2.bn1.running_var\", \"module.layer2.2.conv2.weight\", \"module.layer2.2.bn2.weight\", \"module.layer2.2.bn2.bias\", \"module.layer2.2.bn2.running_mean\", \"module.layer2.2.bn2.running_var\", \"module.layer2.2.conv3.weight\", \"module.layer2.2.bn3.weight\", \"module.layer2.2.bn3.bias\", \"module.layer2.2.bn3.running_mean\", \"module.layer2.2.bn3.running_var\", \"module.layer2.3.conv1.weight\", \"module.layer2.3.bn1.weight\", \"module.layer2.3.bn1.bias\", \"module.layer2.3.bn1.running_mean\", \"module.layer2.3.bn1.running_var\", \"module.layer2.3.conv2.weight\", \"module.layer2.3.bn2.weight\", 
\"module.layer2.3.bn2.bias\", \"module.layer2.3.bn2.running_mean\", \"module.layer2.3.bn2.running_var\", \"module.layer2.3.conv3.weight\", \"module.layer2.3.bn3.weight\", \"module.layer2.3.bn3.bias\", \"module.layer2.3.bn3.running_mean\", \"module.layer2.3.bn3.running_var\", \"module.layer3.0.conv1.weight\", \"module.layer3.0.bn1.weight\", \"module.layer3.0.bn1.bias\", \"module.layer3.0.bn1.running_mean\", \"module.layer3.0.bn1.running_var\", \"module.layer3.0.conv2.weight\", \"module.layer3.0.bn2.weight\", \"module.layer3.0.bn2.bias\", \"module.layer3.0.bn2.running_mean\", \"module.layer3.0.bn2.running_var\", \"module.layer3.0.conv3.weight\", \"module.layer3.0.bn3.weight\", \"module.layer3.0.bn3.bias\", \"module.layer3.0.bn3.running_mean\", \"module.layer3.0.bn3.running_var\", \"module.layer3.0.downsample.0.weight\", \"module.layer3.0.downsample.1.weight\", \"module.layer3.0.downsample.1.bias\", \"module.layer3.0.downsample.1.running_mean\", \"module.layer3.0.downsample.1.running_var\", \"module.layer3.1.conv1.weight\", \"module.layer3.1.bn1.weight\", \"module.layer3.1.bn1.bias\", \"module.layer3.1.bn1.running_mean\", \"module.layer3.1.bn1.running_var\", \"module.layer3.1.conv2.weight\", \"module.layer3.1.bn2.weight\", \"module.layer3.1.bn2.bias\", \"module.layer3.1.bn2.running_mean\", \"module.layer3.1.bn2.running_var\", \"module.layer3.1.conv3.weight\", \"module.layer3.1.bn3.weight\", \"module.layer3.1.bn3.bias\", \"module.layer3.1.bn3.running_mean\", \"module.layer3.1.bn3.running_var\", \"module.layer3.2.conv1.weight\", \"module.layer3.2.bn1.weight\", \"module.layer3.2.bn1.bias\", \"module.layer3.2.bn1.running_mean\", \"module.layer3.2.bn1.running_var\", \"module.layer3.2.conv2.weight\", \"module.layer3.2.bn2.weight\", \"module.layer3.2.bn2.bias\", \"module.layer3.2.bn2.running_mean\", \"module.layer3.2.bn2.running_var\", \"module.layer3.2.conv3.weight\", \"module.layer3.2.bn3.weight\", \"module.layer3.2.bn3.bias\", \"module.layer3.2.bn3.running_mean\", 
\"module.layer3.2.bn3.running_var\", \"module.layer3.3.conv1.weight\", \"module.layer3.3.bn1.weight\", \"module.layer3.3.bn1.bias\", \"module.layer3.3.bn1.running_mean\", \"module.layer3.3.bn1.running_var\", \"module.layer3.3.conv2.weight\", \"module.layer3.3.bn2.weight\", \"module.layer3.3.bn2.bias\", \"module.layer3.3.bn2.running_mean\", \"module.layer3.3.bn2.running_var\", \"module.layer3.3.conv3.weight\", \"module.layer3.3.bn3.weight\", \"module.layer3.3.bn3.bias\", \"module.layer3.3.bn3.running_mean\", \"module.layer3.3.bn3.running_var\", \"module.layer3.4.conv1.weight\", \"module.layer3.4.bn1.weight\", \"module.layer3.4.bn1.bias\", \"module.layer3.4.bn1.running_mean\", \"module.layer3.4.bn1.running_var\", \"module.layer3.4.conv2.weight\", \"module.layer3.4.bn2.weight\", \"module.layer3.4.bn2.bias\", \"module.layer3.4.bn2.running_mean\", \"module.layer3.4.bn2.running_var\", \"module.layer3.4.conv3.weight\", \"module.layer3.4.bn3.weight\", \"module.layer3.4.bn3.bias\", \"module.layer3.4.bn3.running_mean\", \"module.layer3.4.bn3.running_var\", \"module.layer3.5.conv1.weight\", \"module.layer3.5.bn1.weight\", \"module.layer3.5.bn1.bias\", \"module.layer3.5.bn1.running_mean\", \"module.layer3.5.bn1.running_var\", \"module.layer3.5.conv2.weight\", \"module.layer3.5.bn2.weight\", \"module.layer3.5.bn2.bias\", \"module.layer3.5.bn2.running_mean\", \"module.layer3.5.bn2.running_var\", \"module.layer3.5.conv3.weight\", \"module.layer3.5.bn3.weight\", \"module.layer3.5.bn3.bias\", \"module.layer3.5.bn3.running_mean\", \"module.layer3.5.bn3.running_var\", \"module.layer4.0.conv1.weight\", \"module.layer4.0.bn1.weight\", \"module.layer4.0.bn1.bias\", \"module.layer4.0.bn1.running_mean\", \"module.layer4.0.bn1.running_var\", \"module.layer4.0.conv2.weight\", \"module.layer4.0.bn2.weight\", \"module.layer4.0.bn2.bias\", \"module.layer4.0.bn2.running_mean\", \"module.layer4.0.bn2.running_var\", \"module.layer4.0.conv3.weight\", \"module.layer4.0.bn3.weight\", 
\"module.layer4.0.bn3.bias\", \"module.layer4.0.bn3.running_mean\", \"module.layer4.0.bn3.running_var\", \"module.layer4.0.downsample.0.weight\", \"module.layer4.0.downsample.1.weight\", \"module.layer4.0.downsample.1.bias\", \"module.layer4.0.downsample.1.running_mean\", \"module.layer4.0.downsample.1.running_var\", \"module.layer4.1.conv1.weight\", \"module.layer4.1.bn1.weight\", \"module.layer4.1.bn1.bias\", \"module.layer4.1.bn1.running_mean\", \"module.layer4.1.bn1.running_var\", \"module.layer4.1.conv2.weight\", \"module.layer4.1.bn2.weight\", \"module.layer4.1.bn2.bias\", \"module.layer4.1.bn2.running_mean\", \"module.layer4.1.bn2.running_var\", \"module.layer4.1.conv3.weight\", \"module.layer4.1.bn3.weight\", \"module.layer4.1.bn3.bias\", \"module.layer4.1.bn3.running_mean\", \"module.layer4.1.bn3.running_var\", \"module.layer4.2.conv1.weight\", \"module.layer4.2.bn1.weight\", \"module.layer4.2.bn1.bias\", \"module.layer4.2.bn1.running_mean\", \"module.layer4.2.bn1.running_var\", \"module.layer4.2.conv2.weight\", \"module.layer4.2.bn2.weight\", \"module.layer4.2.bn2.bias\", \"module.layer4.2.bn2.running_mean\", \"module.layer4.2.bn2.running_var\", \"module.layer4.2.conv3.weight\", \"module.layer4.2.bn3.weight\", \"module.layer4.2.bn3.bias\", \"module.layer4.2.bn3.running_mean\", \"module.layer4.2.bn3.running_var\". 
\n\tUnexpected key(s) in state_dict: \"conv1.weight\", \"bn1.weight\", \"bn1.bias\", \"bn1.running_mean\", \"bn1.running_var\", \"bn1.num_batches_tracked\", \"layer1.0.conv1.weight\", \"layer1.0.bn1.weight\", \"layer1.0.bn1.bias\", \"layer1.0.bn1.running_mean\", \"layer1.0.bn1.running_var\", \"layer1.0.bn1.num_batches_tracked\", \"layer1.0.conv2.weight\", \"layer1.0.bn2.weight\", \"layer1.0.bn2.bias\", \"layer1.0.bn2.running_mean\", \"layer1.0.bn2.running_var\", \"layer1.0.bn2.num_batches_tracked\", \"layer1.0.conv3.weight\", \"layer1.0.bn3.weight\", \"layer1.0.bn3.bias\", \"layer1.0.bn3.running_mean\", \"layer1.0.bn3.running_var\", \"layer1.0.bn3.num_batches_tracked\", \"layer1.0.downsample.0.weight\", \"layer1.0.downsample.1.weight\", \"layer1.0.downsample.1.bias\", \"layer1.0.downsample.1.running_mean\", \"layer1.0.downsample.1.running_var\", \"layer1.0.downsample.1.num_batches_tracked\", \"layer1.1.conv1.weight\", \"layer1.1.bn1.weight\", \"layer1.1.bn1.bias\", \"layer1.1.bn1.running_mean\", \"layer1.1.bn1.running_var\", \"layer1.1.bn1.num_batches_tracked\", \"layer1.1.conv2.weight\", \"layer1.1.bn2.weight\", \"layer1.1.bn2.bias\", \"layer1.1.bn2.running_mean\", \"layer1.1.bn2.running_var\", \"layer1.1.bn2.num_batches_tracked\", \"layer1.1.conv3.weight\", \"layer1.1.bn3.weight\", \"layer1.1.bn3.bias\", \"layer1.1.bn3.running_mean\", \"layer1.1.bn3.running_var\", \"layer1.1.bn3.num_batches_tracked\", \"layer1.2.conv1.weight\", \"layer1.2.bn1.weight\", \"layer1.2.bn1.bias\", \"layer1.2.bn1.running_mean\", \"layer1.2.bn1.running_var\", \"layer1.2.bn1.num_batches_tracked\", \"layer1.2.conv2.weight\", \"layer1.2.bn2.weight\", \"layer1.2.bn2.bias\", \"layer1.2.bn2.running_mean\", \"layer1.2.bn2.running_var\", \"layer1.2.bn2.num_batches_tracked\", \"layer1.2.conv3.weight\", \"layer1.2.bn3.weight\", \"layer1.2.bn3.bias\", \"layer1.2.bn3.running_mean\", \"layer1.2.bn3.running_var\", \"layer1.2.bn3.num_batches_tracked\", \"layer2.0.conv1.weight\", 
\"layer2.0.bn1.weight\", \"layer2.0.bn1.bias\", \"layer2.0.bn1.running_mean\", \"layer2.0.bn1.running_var\", \"layer2.0.bn1.num_batches_tracked\", \"layer2.0.conv2.weight\", \"layer2.0.bn2.weight\", \"layer2.0.bn2.bias\", \"layer2.0.bn2.running_mean\", \"layer2.0.bn2.running_var\", \"layer2.0.bn2.num_batches_tracked\", \"layer2.0.conv3.weight\", \"layer2.0.bn3.weight\", \"layer2.0.bn3.bias\", \"layer2.0.bn3.running_mean\", \"layer2.0.bn3.running_var\", \"layer2.0.bn3.num_batches_tracked\", \"layer2.0.downsample.0.weight\", \"layer2.0.downsample.1.weight\", \"layer2.0.downsample.1.bias\", \"layer2.0.downsample.1.running_mean\", \"layer2.0.downsample.1.running_var\", \"layer2.0.downsample.1.num_batches_tracked\", \"layer2.1.conv1.weight\", \"layer2.1.bn1.weight\", \"layer2.1.bn1.bias\", \"layer2.1.bn1.running_mean\", \"layer2.1.bn1.running_var\", \"layer2.1.bn1.num_batches_tracked\", \"layer2.1.conv2.weight\", \"layer2.1.bn2.weight\", \"layer2.1.bn2.bias\", \"layer2.1.bn2.running_mean\", \"layer2.1.bn2.running_var\", \"layer2.1.bn2.num_batches_tracked\", \"layer2.1.conv3.weight\", \"layer2.1.bn3.weight\", \"layer2.1.bn3.bias\", \"layer2.1.bn3.running_mean\", \"layer2.1.bn3.running_var\", \"layer2.1.bn3.num_batches_tracked\", \"layer2.2.conv1.weight\", \"layer2.2.bn1.weight\", \"layer2.2.bn1.bias\", \"layer2.2.bn1.running_mean\", \"layer2.2.bn1.running_var\", \"layer2.2.bn1.num_batches_tracked\", \"layer2.2.conv2.weight\", \"layer2.2.bn2.weight\", \"layer2.2.bn2.bias\", \"layer2.2.bn2.running_mean\", \"layer2.2.bn2.running_var\", \"layer2.2.bn2.num_batches_tracked\", \"layer2.2.conv3.weight\", \"layer2.2.bn3.weight\", \"layer2.2.bn3.bias\", \"layer2.2.bn3.running_mean\", \"layer2.2.bn3.running_var\", \"layer2.2.bn3.num_batches_tracked\", \"layer2.3.conv1.weight\", \"layer2.3.bn1.weight\", \"layer2.3.bn1.bias\", \"layer2.3.bn1.running_mean\", \"layer2.3.bn1.running_var\", \"layer2.3.bn1.num_batches_tracked\", \"layer2.3.conv2.weight\", \"layer2.3.bn2.weight\", 
\"layer2.3.bn2.bias\", \"layer2.3.bn2.running_mean\", \"layer2.3.bn2.running_var\", \"layer2.3.bn2.num_batches_tracked\", \"layer2.3.conv3.weight\", \"layer2.3.bn3.weight\", \"layer2.3.bn3.bias\", \"layer2.3.bn3.running_mean\", \"layer2.3.bn3.running_var\", \"layer2.3.bn3.num_batches_tracked\", \"layer3.0.conv1.weight\", \"layer3.0.bn1.weight\", \"layer3.0.bn1.bias\", \"layer3.0.bn1.running_mean\", \"layer3.0.bn1.running_var\", \"layer3.0.bn1.num_batches_tracked\", \"layer3.0.conv2.weight\", \"layer3.0.bn2.weight\", \"layer3.0.bn2.bias\", \"layer3.0.bn2.running_mean\", \"layer3.0.bn2.running_var\", \"layer3.0.bn2.num_batches_tracked\", \"layer3.0.conv3.weight\", \"layer3.0.bn3.weight\", \"layer3.0.bn3.bias\", \"layer3.0.bn3.running_mean\", \"layer3.0.bn3.running_var\", \"layer3.0.bn3.num_batches_tracked\", \"layer3.0.downsample.0.weight\", \"layer3.0.downsample.1.weight\", \"layer3.0.downsample.1.bias\", \"layer3.0.downsample.1.running_mean\", \"layer3.0.downsample.1.running_var\", \"layer3.0.downsample.1.num_batches_tracked\", \"layer3.1.conv1.weight\", \"layer3.1.bn1.weight\", \"layer3.1.bn1.bias\", \"layer3.1.bn1.running_mean\", \"layer3.1.bn1.running_var\", \"layer3.1.bn1.num_batches_tracked\", \"layer3.1.conv2.weight\", \"layer3.1.bn2.weight\", \"layer3.1.bn2.bias\", \"layer3.1.bn2.running_mean\", \"layer3.1.bn2.running_var\", \"layer3.1.bn2.num_batches_tracked\", \"layer3.1.conv3.weight\", \"layer3.1.bn3.weight\", \"layer3.1.bn3.bias\", \"layer3.1.bn3.running_mean\", \"layer3.1.bn3.running_var\", \"layer3.1.bn3.num_batches_tracked\", \"layer3.2.conv1.weight\", \"layer3.2.bn1.weight\", \"layer3.2.bn1.bias\", \"layer3.2.bn1.running_mean\", \"layer3.2.bn1.running_var\", \"layer3.2.bn1.num_batches_tracked\", \"layer3.2.conv2.weight\", \"layer3.2.bn2.weight\", \"layer3.2.bn2.bias\", \"layer3.2.bn2.running_mean\", \"layer3.2.bn2.running_var\", \"layer3.2.bn2.num_batches_tracked\", \"layer3.2.conv3.weight\", \"layer3.2.bn3.weight\", \"layer3.2.bn3.bias\", 
\"layer3.2.bn3.running_mean\", \"layer3.2.bn3.running_var\", \"layer3.2.bn3.num_batches_tracked\", \"layer3.3.conv1.weight\", \"layer3.3.bn1.weight\", \"layer3.3.bn1.bias\", \"layer3.3.bn1.running_mean\", \"layer3.3.bn1.running_var\", \"layer3.3.bn1.num_batches_tracked\", \"layer3.3.conv2.weight\", \"layer3.3.bn2.weight\", \"layer3.3.bn2.bias\", \"layer3.3.bn2.running_mean\", \"layer3.3.bn2.running_var\", \"layer3.3.bn2.num_batches_tracked\", \"layer3.3.conv3.weight\", \"layer3.3.bn3.weight\", \"layer3.3.bn3.bias\", \"layer3.3.bn3.running_mean\", \"layer3.3.bn3.running_var\", \"layer3.3.bn3.num_batches_tracked\", \"layer3.4.conv1.weight\", \"layer3.4.bn1.weight\", \"layer3.4.bn1.bias\", \"layer3.4.bn1.running_mean\", \"layer3.4.bn1.running_var\", \"layer3.4.bn1.num_batches_tracked\", \"layer3.4.conv2.weight\", \"layer3.4.bn2.weight\", \"layer3.4.bn2.bias\", \"layer3.4.bn2.running_mean\", \"layer3.4.bn2.running_var\", \"layer3.4.bn2.num_batches_tracked\", \"layer3.4.conv3.weight\", \"layer3.4.bn3.weight\", \"layer3.4.bn3.bias\", \"layer3.4.bn3.running_mean\", \"layer3.4.bn3.running_var\", \"layer3.4.bn3.num_batches_tracked\", \"layer3.5.conv1.weight\", \"layer3.5.bn1.weight\", \"layer3.5.bn1.bias\", \"layer3.5.bn1.running_mean\", \"layer3.5.bn1.running_var\", \"layer3.5.bn1.num_batches_tracked\", \"layer3.5.conv2.weight\", \"layer3.5.bn2.weight\", \"layer3.5.bn2.bias\", \"layer3.5.bn2.running_mean\", \"layer3.5.bn2.running_var\", \"layer3.5.bn2.num_batches_tracked\", \"layer3.5.conv3.weight\", \"layer3.5.bn3.weight\", \"layer3.5.bn3.bias\", \"layer3.5.bn3.running_mean\", \"layer3.5.bn3.running_var\", \"layer3.5.bn3.num_batches_tracked\", \"layer4.0.conv1.weight\", \"layer4.0.bn1.weight\", \"layer4.0.bn1.bias\", \"layer4.0.bn1.running_mean\", \"layer4.0.bn1.running_var\", \"layer4.0.bn1.num_batches_tracked\", \"layer4.0.conv2.weight\", \"layer4.0.bn2.weight\", \"layer4.0.bn2.bias\", \"layer4.0.bn2.running_mean\", \"layer4.0.bn2.running_var\", 
\"layer4.0.bn2.num_batches_tracked\", \"layer4.0.conv3.weight\", \"layer4.0.bn3.weight\", \"layer4.0.bn3.bias\", \"layer4.0.bn3.running_mean\", \"layer4.0.bn3.running_var\", \"layer4.0.bn3.num_batches_tracked\", \"layer4.0.downsample.0.weight\", \"layer4.0.downsample.1.weight\", \"layer4.0.downsample.1.bias\", \"layer4.0.downsample.1.running_mean\", \"layer4.0.downsample.1.running_var\", \"layer4.0.downsample.1.num_batches_tracked\", \"layer4.1.conv1.weight\", \"layer4.1.bn1.weight\", \"layer4.1.bn1.bias\", \"layer4.1.bn1.running_mean\", \"layer4.1.bn1.running_var\", \"layer4.1.bn1.num_batches_tracked\", \"layer4.1.conv2.weight\", \"layer4.1.bn2.weight\", \"layer4.1.bn2.bias\", \"layer4.1.bn2.running_mean\", \"layer4.1.bn2.running_var\", \"layer4.1.bn2.num_batches_tracked\", \"layer4.1.conv3.weight\", \"layer4.1.bn3.weight\", \"layer4.1.bn3.bias\", \"layer4.1.bn3.running_mean\", \"layer4.1.bn3.running_var\", \"layer4.1.bn3.num_batches_tracked\", \"layer4.2.conv1.weight\", \"layer4.2.bn1.weight\", \"layer4.2.bn1.bias\", \"layer4.2.bn1.running_mean\", \"layer4.2.bn1.running_var\", \"layer4.2.bn1.num_batches_tracked\", \"layer4.2.conv2.weight\", \"layer4.2.bn2.weight\", \"layer4.2.bn2.bias\", \"layer4.2.bn2.running_mean\", \"layer4.2.bn2.running_var\", \"layer4.2.bn2.num_batches_tracked\", \"layer4.2.conv3.weight\", \"layer4.2.bn3.weight\", \"layer4.2.bn3.bias\", \"layer4.2.bn3.running_mean\", \"layer4.2.bn3.running_var\", \"layer4.2.bn3.num_batches_tracked\". ",
  1025. "output_type": "error",
  1026. "traceback": [
  1027. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1028. "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
  1029. "Cell \u001b[0;32mIn[37], line 23\u001b[0m\n\u001b[1;32m 12\u001b[0m objs \u001b[39m=\u001b[39m [\n\u001b[1;32m 13\u001b[0m {\n\u001b[1;32m 14\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mengine\u001b[39m\u001b[39m\"\u001b[39m: trainer\u001b[39m.\u001b[39mtrainer,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 19\u001b[0m }\n\u001b[1;32m 20\u001b[0m ]\n\u001b[1;32m 22\u001b[0m \u001b[39mfor\u001b[39;00m to_load \u001b[39min\u001b[39;00m objs:\n\u001b[0;32m---> 23\u001b[0m checkpoint_fn\u001b[39m.\u001b[39;49mload_best_checkpoint(to_load)\n",
  1030. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/checkpoint_utils.py:97\u001b[0m, in \u001b[0;36mCheckpointFnCreator.load_best_checkpoint\u001b[0;34m(self, to_load)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mload_best_checkpoint\u001b[39m(\u001b[39mself\u001b[39m, to_load):\n\u001b[1;32m 96\u001b[0m last_checkpoint \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_best_checkpoint()\n\u001b[0;32m---> 97\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mload_objects(to_load, last_checkpoint)\n",
  1031. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/checkpoint_utils.py:93\u001b[0m, in \u001b[0;36mCheckpointFnCreator.load_objects\u001b[0;34m(self, to_load, checkpoint, global_step)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobjs\u001b[39m.\u001b[39mreload_objects(\n\u001b[1;32m 90\u001b[0m to_load, name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mcheckpoint\u001b[39m\u001b[39m\"\u001b[39m, global_step\u001b[39m=\u001b[39mglobal_step\n\u001b[1;32m 91\u001b[0m )\n\u001b[1;32m 92\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 93\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobjs\u001b[39m.\u001b[39;49mload_objects(to_load, \u001b[39mstr\u001b[39;49m(checkpoint))\n",
  1032. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/handlers/checkpoint.py:618\u001b[0m, in \u001b[0;36mCheckpoint.load_objects\u001b[0;34m(to_load, checkpoint, **kwargs)\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[39mif\u001b[39;00m k \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m checkpoint_obj:\n\u001b[1;32m 617\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mObject labeled by \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m from `to_load` is not found in the checkpoint\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 618\u001b[0m _load_object(obj, checkpoint_obj[k])\n",
  1033. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/handlers/checkpoint.py:605\u001b[0m, in \u001b[0;36mCheckpoint.load_objects.<locals>._load_object\u001b[0;34m(obj, chkpt_obj)\u001b[0m\n\u001b[1;32m 603\u001b[0m obj\u001b[39m.\u001b[39mload_state_dict(chkpt_obj, strict\u001b[39m=\u001b[39mis_state_dict_strict)\n\u001b[1;32m 604\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 605\u001b[0m obj\u001b[39m.\u001b[39;49mload_state_dict(chkpt_obj)\n",
  1034. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/containers/base_container.py:158\u001b[0m, in \u001b[0;36mBaseContainer.load_state_dict\u001b[0;34m(self, state_dict)\u001b[0m\n\u001b[1;32m 156\u001b[0m c_f\u001b[39m.\u001b[39massert_state_dict_keys(state_dict, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkeys())\n\u001b[1;32m 157\u001b[0m \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m state_dict\u001b[39m.\u001b[39mitems():\n\u001b[0;32m--> 158\u001b[0m \u001b[39mself\u001b[39;49m[k]\u001b[39m.\u001b[39;49mload_state_dict(v)\n",
  1035. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/modules/module.py:1223\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1218\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 1219\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1220\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(k) \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 1222\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 1223\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1224\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 1225\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n",
  1036. "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for DataParallel:\n\tMissing key(s) in state_dict: \"module.conv1.weight\", \"module.bn1.weight\", \"module.bn1.bias\", \"module.bn1.running_mean\", \"module.bn1.running_var\", \"module.layer1.0.conv1.weight\", \"module.layer1.0.bn1.weight\", \"module.layer1.0.bn1.bias\", \"module.layer1.0.bn1.running_mean\", \"module.layer1.0.bn1.running_var\", \"module.layer1.0.conv2.weight\", \"module.layer1.0.bn2.weight\", \"module.layer1.0.bn2.bias\", \"module.layer1.0.bn2.running_mean\", \"module.layer1.0.bn2.running_var\", \"module.layer1.0.conv3.weight\", \"module.layer1.0.bn3.weight\", \"module.layer1.0.bn3.bias\", \"module.layer1.0.bn3.running_mean\", \"module.layer1.0.bn3.running_var\", \"module.layer1.0.downsample.0.weight\", \"module.layer1.0.downsample.1.weight\", \"module.layer1.0.downsample.1.bias\", \"module.layer1.0.downsample.1.running_mean\", \"module.layer1.0.downsample.1.running_var\", \"module.layer1.1.conv1.weight\", \"module.layer1.1.bn1.weight\", \"module.layer1.1.bn1.bias\", \"module.layer1.1.bn1.running_mean\", \"module.layer1.1.bn1.running_var\", \"module.layer1.1.conv2.weight\", \"module.layer1.1.bn2.weight\", \"module.layer1.1.bn2.bias\", \"module.layer1.1.bn2.running_mean\", \"module.layer1.1.bn2.running_var\", \"module.layer1.1.conv3.weight\", \"module.layer1.1.bn3.weight\", \"module.layer1.1.bn3.bias\", \"module.layer1.1.bn3.running_mean\", \"module.layer1.1.bn3.running_var\", \"module.layer1.2.conv1.weight\", \"module.layer1.2.bn1.weight\", \"module.layer1.2.bn1.bias\", \"module.layer1.2.bn1.running_mean\", \"module.layer1.2.bn1.running_var\", \"module.layer1.2.conv2.weight\", \"module.layer1.2.bn2.weight\", \"module.layer1.2.bn2.bias\", \"module.layer1.2.bn2.running_mean\", \"module.layer1.2.bn2.running_var\", \"module.layer1.2.conv3.weight\", \"module.layer1.2.bn3.weight\", \"module.layer1.2.bn3.bias\", \"module.layer1.2.bn3.running_mean\", 
\"module.layer1.2.bn3.running_var\", \"module.layer2.0.conv1.weight\", \"module.layer2.0.bn1.weight\", \"module.layer2.0.bn1.bias\", \"module.layer2.0.bn1.running_mean\", \"module.layer2.0.bn1.running_var\", \"module.layer2.0.conv2.weight\", \"module.layer2.0.bn2.weight\", \"module.layer2.0.bn2.bias\", \"module.layer2.0.bn2.running_mean\", \"module.layer2.0.bn2.running_var\", \"module.layer2.0.conv3.weight\", \"module.layer2.0.bn3.weight\", \"module.layer2.0.bn3.bias\", \"module.layer2.0.bn3.running_mean\", \"module.layer2.0.bn3.running_var\", \"module.layer2.0.downsample.0.weight\", \"module.layer2.0.downsample.1.weight\", \"module.layer2.0.downsample.1.bias\", \"module.layer2.0.downsample.1.running_mean\", \"module.layer2.0.downsample.1.running_var\", \"module.layer2.1.conv1.weight\", \"module.layer2.1.bn1.weight\", \"module.layer2.1.bn1.bias\", \"module.layer2.1.bn1.running_mean\", \"module.layer2.1.bn1.running_var\", \"module.layer2.1.conv2.weight\", \"module.layer2.1.bn2.weight\", \"module.layer2.1.bn2.bias\", \"module.layer2.1.bn2.running_mean\", \"module.layer2.1.bn2.running_var\", \"module.layer2.1.conv3.weight\", \"module.layer2.1.bn3.weight\", \"module.layer2.1.bn3.bias\", \"module.layer2.1.bn3.running_mean\", \"module.layer2.1.bn3.running_var\", \"module.layer2.2.conv1.weight\", \"module.layer2.2.bn1.weight\", \"module.layer2.2.bn1.bias\", \"module.layer2.2.bn1.running_mean\", \"module.layer2.2.bn1.running_var\", \"module.layer2.2.conv2.weight\", \"module.layer2.2.bn2.weight\", \"module.layer2.2.bn2.bias\", \"module.layer2.2.bn2.running_mean\", \"module.layer2.2.bn2.running_var\", \"module.layer2.2.conv3.weight\", \"module.layer2.2.bn3.weight\", \"module.layer2.2.bn3.bias\", \"module.layer2.2.bn3.running_mean\", \"module.layer2.2.bn3.running_var\", \"module.layer2.3.conv1.weight\", \"module.layer2.3.bn1.weight\", \"module.layer2.3.bn1.bias\", \"module.layer2.3.bn1.running_mean\", \"module.layer2.3.bn1.running_var\", \"module.layer2.3.conv2.weight\", 
\"module.layer2.3.bn2.weight\", \"module.layer2.3.bn2.bias\", \"module.layer2.3.bn2.running_mean\", \"module.layer2.3.bn2.running_var\", \"module.layer2.3.conv3.weight\", \"module.layer2.3.bn3.weight\", \"module.layer2.3.bn3.bias\", \"module.layer2.3.bn3.running_mean\", \"module.layer2.3.bn3.running_var\", \"module.layer3.0.conv1.weight\", \"module.layer3.0.bn1.weight\", \"module.layer3.0.bn1.bias\", \"module.layer3.0.bn1.running_mean\", \"module.layer3.0.bn1.running_var\", \"module.layer3.0.conv2.weight\", \"module.layer3.0.bn2.weight\", \"module.layer3.0.bn2.bias\", \"module.layer3.0.bn2.running_mean\", \"module.layer3.0.bn2.running_var\", \"module.layer3.0.conv3.weight\", \"module.layer3.0.bn3.weight\", \"module.layer3.0.bn3.bias\", \"module.layer3.0.bn3.running_mean\", \"module.layer3.0.bn3.running_var\", \"module.layer3.0.downsample.0.weight\", \"module.layer3.0.downsample.1.weight\", \"module.layer3.0.downsample.1.bias\", \"module.layer3.0.downsample.1.running_mean\", \"module.layer3.0.downsample.1.running_var\", \"module.layer3.1.conv1.weight\", \"module.layer3.1.bn1.weight\", \"module.layer3.1.bn1.bias\", \"module.layer3.1.bn1.running_mean\", \"module.layer3.1.bn1.running_var\", \"module.layer3.1.conv2.weight\", \"module.layer3.1.bn2.weight\", \"module.layer3.1.bn2.bias\", \"module.layer3.1.bn2.running_mean\", \"module.layer3.1.bn2.running_var\", \"module.layer3.1.conv3.weight\", \"module.layer3.1.bn3.weight\", \"module.layer3.1.bn3.bias\", \"module.layer3.1.bn3.running_mean\", \"module.layer3.1.bn3.running_var\", \"module.layer3.2.conv1.weight\", \"module.layer3.2.bn1.weight\", \"module.layer3.2.bn1.bias\", \"module.layer3.2.bn1.running_mean\", \"module.layer3.2.bn1.running_var\", \"module.layer3.2.conv2.weight\", \"module.layer3.2.bn2.weight\", \"module.layer3.2.bn2.bias\", \"module.layer3.2.bn2.running_mean\", \"module.layer3.2.bn2.running_var\", \"module.layer3.2.conv3.weight\", \"module.layer3.2.bn3.weight\", \"module.layer3.2.bn3.bias\", 
\"module.layer3.2.bn3.running_mean\", \"module.layer3.2.bn3.running_var\", \"module.layer3.3.conv1.weight\", \"module.layer3.3.bn1.weight\", \"module.layer3.3.bn1.bias\", \"module.layer3.3.bn1.running_mean\", \"module.layer3.3.bn1.running_var\", \"module.layer3.3.conv2.weight\", \"module.layer3.3.bn2.weight\", \"module.layer3.3.bn2.bias\", \"module.layer3.3.bn2.running_mean\", \"module.layer3.3.bn2.running_var\", \"module.layer3.3.conv3.weight\", \"module.layer3.3.bn3.weight\", \"module.layer3.3.bn3.bias\", \"module.layer3.3.bn3.running_mean\", \"module.layer3.3.bn3.running_var\", \"module.layer3.4.conv1.weight\", \"module.layer3.4.bn1.weight\", \"module.layer3.4.bn1.bias\", \"module.layer3.4.bn1.running_mean\", \"module.layer3.4.bn1.running_var\", \"module.layer3.4.conv2.weight\", \"module.layer3.4.bn2.weight\", \"module.layer3.4.bn2.bias\", \"module.layer3.4.bn2.running_mean\", \"module.layer3.4.bn2.running_var\", \"module.layer3.4.conv3.weight\", \"module.layer3.4.bn3.weight\", \"module.layer3.4.bn3.bias\", \"module.layer3.4.bn3.running_mean\", \"module.layer3.4.bn3.running_var\", \"module.layer3.5.conv1.weight\", \"module.layer3.5.bn1.weight\", \"module.layer3.5.bn1.bias\", \"module.layer3.5.bn1.running_mean\", \"module.layer3.5.bn1.running_var\", \"module.layer3.5.conv2.weight\", \"module.layer3.5.bn2.weight\", \"module.layer3.5.bn2.bias\", \"module.layer3.5.bn2.running_mean\", \"module.layer3.5.bn2.running_var\", \"module.layer3.5.conv3.weight\", \"module.layer3.5.bn3.weight\", \"module.layer3.5.bn3.bias\", \"module.layer3.5.bn3.running_mean\", \"module.layer3.5.bn3.running_var\", \"module.layer4.0.conv1.weight\", \"module.layer4.0.bn1.weight\", \"module.layer4.0.bn1.bias\", \"module.layer4.0.bn1.running_mean\", \"module.layer4.0.bn1.running_var\", \"module.layer4.0.conv2.weight\", \"module.layer4.0.bn2.weight\", \"module.layer4.0.bn2.bias\", \"module.layer4.0.bn2.running_mean\", \"module.layer4.0.bn2.running_var\", \"module.layer4.0.conv3.weight\", 
\"module.layer4.0.bn3.weight\", \"module.layer4.0.bn3.bias\", \"module.layer4.0.bn3.running_mean\", \"module.layer4.0.bn3.running_var\", \"module.layer4.0.downsample.0.weight\", \"module.layer4.0.downsample.1.weight\", \"module.layer4.0.downsample.1.bias\", \"module.layer4.0.downsample.1.running_mean\", \"module.layer4.0.downsample.1.running_var\", \"module.layer4.1.conv1.weight\", \"module.layer4.1.bn1.weight\", \"module.layer4.1.bn1.bias\", \"module.layer4.1.bn1.running_mean\", \"module.layer4.1.bn1.running_var\", \"module.layer4.1.conv2.weight\", \"module.layer4.1.bn2.weight\", \"module.layer4.1.bn2.bias\", \"module.layer4.1.bn2.running_mean\", \"module.layer4.1.bn2.running_var\", \"module.layer4.1.conv3.weight\", \"module.layer4.1.bn3.weight\", \"module.layer4.1.bn3.bias\", \"module.layer4.1.bn3.running_mean\", \"module.layer4.1.bn3.running_var\", \"module.layer4.2.conv1.weight\", \"module.layer4.2.bn1.weight\", \"module.layer4.2.bn1.bias\", \"module.layer4.2.bn1.running_mean\", \"module.layer4.2.bn1.running_var\", \"module.layer4.2.conv2.weight\", \"module.layer4.2.bn2.weight\", \"module.layer4.2.bn2.bias\", \"module.layer4.2.bn2.running_mean\", \"module.layer4.2.bn2.running_var\", \"module.layer4.2.conv3.weight\", \"module.layer4.2.bn3.weight\", \"module.layer4.2.bn3.bias\", \"module.layer4.2.bn3.running_mean\", \"module.layer4.2.bn3.running_var\". 
\n\tUnexpected key(s) in state_dict: \"conv1.weight\", \"bn1.weight\", \"bn1.bias\", \"bn1.running_mean\", \"bn1.running_var\", \"bn1.num_batches_tracked\", \"layer1.0.conv1.weight\", \"layer1.0.bn1.weight\", \"layer1.0.bn1.bias\", \"layer1.0.bn1.running_mean\", \"layer1.0.bn1.running_var\", \"layer1.0.bn1.num_batches_tracked\", \"layer1.0.conv2.weight\", \"layer1.0.bn2.weight\", \"layer1.0.bn2.bias\", \"layer1.0.bn2.running_mean\", \"layer1.0.bn2.running_var\", \"layer1.0.bn2.num_batches_tracked\", \"layer1.0.conv3.weight\", \"layer1.0.bn3.weight\", \"layer1.0.bn3.bias\", \"layer1.0.bn3.running_mean\", \"layer1.0.bn3.running_var\", \"layer1.0.bn3.num_batches_tracked\", \"layer1.0.downsample.0.weight\", \"layer1.0.downsample.1.weight\", \"layer1.0.downsample.1.bias\", \"layer1.0.downsample.1.running_mean\", \"layer1.0.downsample.1.running_var\", \"layer1.0.downsample.1.num_batches_tracked\", \"layer1.1.conv1.weight\", \"layer1.1.bn1.weight\", \"layer1.1.bn1.bias\", \"layer1.1.bn1.running_mean\", \"layer1.1.bn1.running_var\", \"layer1.1.bn1.num_batches_tracked\", \"layer1.1.conv2.weight\", \"layer1.1.bn2.weight\", \"layer1.1.bn2.bias\", \"layer1.1.bn2.running_mean\", \"layer1.1.bn2.running_var\", \"layer1.1.bn2.num_batches_tracked\", \"layer1.1.conv3.weight\", \"layer1.1.bn3.weight\", \"layer1.1.bn3.bias\", \"layer1.1.bn3.running_mean\", \"layer1.1.bn3.running_var\", \"layer1.1.bn3.num_batches_tracked\", \"layer1.2.conv1.weight\", \"layer1.2.bn1.weight\", \"layer1.2.bn1.bias\", \"layer1.2.bn1.running_mean\", \"layer1.2.bn1.running_var\", \"layer1.2.bn1.num_batches_tracked\", \"layer1.2.conv2.weight\", \"layer1.2.bn2.weight\", \"layer1.2.bn2.bias\", \"layer1.2.bn2.running_mean\", \"layer1.2.bn2.running_var\", \"layer1.2.bn2.num_batches_tracked\", \"layer1.2.conv3.weight\", \"layer1.2.bn3.weight\", \"layer1.2.bn3.bias\", \"layer1.2.bn3.running_mean\", \"layer1.2.bn3.running_var\", \"layer1.2.bn3.num_batches_tracked\", \"layer2.0.conv1.weight\", 
\"layer2.0.bn1.weight\", \"layer2.0.bn1.bias\", \"layer2.0.bn1.running_mean\", \"layer2.0.bn1.running_var\", \"layer2.0.bn1.num_batches_tracked\", \"layer2.0.conv2.weight\", \"layer2.0.bn2.weight\", \"layer2.0.bn2.bias\", \"layer2.0.bn2.running_mean\", \"layer2.0.bn2.running_var\", \"layer2.0.bn2.num_batches_tracked\", \"layer2.0.conv3.weight\", \"layer2.0.bn3.weight\", \"layer2.0.bn3.bias\", \"layer2.0.bn3.running_mean\", \"layer2.0.bn3.running_var\", \"layer2.0.bn3.num_batches_tracked\", \"layer2.0.downsample.0.weight\", \"layer2.0.downsample.1.weight\", \"layer2.0.downsample.1.bias\", \"layer2.0.downsample.1.running_mean\", \"layer2.0.downsample.1.running_var\", \"layer2.0.downsample.1.num_batches_tracked\", \"layer2.1.conv1.weight\", \"layer2.1.bn1.weight\", \"layer2.1.bn1.bias\", \"layer2.1.bn1.running_mean\", \"layer2.1.bn1.running_var\", \"layer2.1.bn1.num_batches_tracked\", \"layer2.1.conv2.weight\", \"layer2.1.bn2.weight\", \"layer2.1.bn2.bias\", \"layer2.1.bn2.running_mean\", \"layer2.1.bn2.running_var\", \"layer2.1.bn2.num_batches_tracked\", \"layer2.1.conv3.weight\", \"layer2.1.bn3.weight\", \"layer2.1.bn3.bias\", \"layer2.1.bn3.running_mean\", \"layer2.1.bn3.running_var\", \"layer2.1.bn3.num_batches_tracked\", \"layer2.2.conv1.weight\", \"layer2.2.bn1.weight\", \"layer2.2.bn1.bias\", \"layer2.2.bn1.running_mean\", \"layer2.2.bn1.running_var\", \"layer2.2.bn1.num_batches_tracked\", \"layer2.2.conv2.weight\", \"layer2.2.bn2.weight\", \"layer2.2.bn2.bias\", \"layer2.2.bn2.running_mean\", \"layer2.2.bn2.running_var\", \"layer2.2.bn2.num_batches_tracked\", \"layer2.2.conv3.weight\", \"layer2.2.bn3.weight\", \"layer2.2.bn3.bias\", \"layer2.2.bn3.running_mean\", \"layer2.2.bn3.running_var\", \"layer2.2.bn3.num_batches_tracked\", \"layer2.3.conv1.weight\", \"layer2.3.bn1.weight\", \"layer2.3.bn1.bias\", \"layer2.3.bn1.running_mean\", \"layer2.3.bn1.running_var\", \"layer2.3.bn1.num_batches_tracked\", \"layer2.3.conv2.weight\", \"layer2.3.bn2.weight\", 
\"layer2.3.bn2.bias\", \"layer2.3.bn2.running_mean\", \"layer2.3.bn2.running_var\", \"layer2.3.bn2.num_batches_tracked\", \"layer2.3.conv3.weight\", \"layer2.3.bn3.weight\", \"layer2.3.bn3.bias\", \"layer2.3.bn3.running_mean\", \"layer2.3.bn3.running_var\", \"layer2.3.bn3.num_batches_tracked\", \"layer3.0.conv1.weight\", \"layer3.0.bn1.weight\", \"layer3.0.bn1.bias\", \"layer3.0.bn1.running_mean\", \"layer3.0.bn1.running_var\", \"layer3.0.bn1.num_batches_tracked\", \"layer3.0.conv2.weight\", \"layer3.0.bn2.weight\", \"layer3.0.bn2.bias\", \"layer3.0.bn2.running_mean\", \"layer3.0.bn2.running_var\", \"layer3.0.bn2.num_batches_tracked\", \"layer3.0.conv3.weight\", \"layer3.0.bn3.weight\", \"layer3.0.bn3.bias\", \"layer3.0.bn3.running_mean\", \"layer3.0.bn3.running_var\", \"layer3.0.bn3.num_batches_tracked\", \"layer3.0.downsample.0.weight\", \"layer3.0.downsample.1.weight\", \"layer3.0.downsample.1.bias\", \"layer3.0.downsample.1.running_mean\", \"layer3.0.downsample.1.running_var\", \"layer3.0.downsample.1.num_batches_tracked\", \"layer3.1.conv1.weight\", \"layer3.1.bn1.weight\", \"layer3.1.bn1.bias\", \"layer3.1.bn1.running_mean\", \"layer3.1.bn1.running_var\", \"layer3.1.bn1.num_batches_tracked\", \"layer3.1.conv2.weight\", \"layer3.1.bn2.weight\", \"layer3.1.bn2.bias\", \"layer3.1.bn2.running_mean\", \"layer3.1.bn2.running_var\", \"layer3.1.bn2.num_batches_tracked\", \"layer3.1.conv3.weight\", \"layer3.1.bn3.weight\", \"layer3.1.bn3.bias\", \"layer3.1.bn3.running_mean\", \"layer3.1.bn3.running_var\", \"layer3.1.bn3.num_batches_tracked\", \"layer3.2.conv1.weight\", \"layer3.2.bn1.weight\", \"layer3.2.bn1.bias\", \"layer3.2.bn1.running_mean\", \"layer3.2.bn1.running_var\", \"layer3.2.bn1.num_batches_tracked\", \"layer3.2.conv2.weight\", \"layer3.2.bn2.weight\", \"layer3.2.bn2.bias\", \"layer3.2.bn2.running_mean\", \"layer3.2.bn2.running_var\", \"layer3.2.bn2.num_batches_tracked\", \"layer3.2.conv3.weight\", \"layer3.2.bn3.weight\", \"layer3.2.bn3.bias\", 
\"layer3.2.bn3.running_mean\", \"layer3.2.bn3.running_var\", \"layer3.2.bn3.num_batches_tracked\", \"layer3.3.conv1.weight\", \"layer3.3.bn1.weight\", \"layer3.3.bn1.bias\", \"layer3.3.bn1.running_mean\", \"layer3.3.bn1.running_var\", \"layer3.3.bn1.num_batches_tracked\", \"layer3.3.conv2.weight\", \"layer3.3.bn2.weight\", \"layer3.3.bn2.bias\", \"layer3.3.bn2.running_mean\", \"layer3.3.bn2.running_var\", \"layer3.3.bn2.num_batches_tracked\", \"layer3.3.conv3.weight\", \"layer3.3.bn3.weight\", \"layer3.3.bn3.bias\", \"layer3.3.bn3.running_mean\", \"layer3.3.bn3.running_var\", \"layer3.3.bn3.num_batches_tracked\", \"layer3.4.conv1.weight\", \"layer3.4.bn1.weight\", \"layer3.4.bn1.bias\", \"layer3.4.bn1.running_mean\", \"layer3.4.bn1.running_var\", \"layer3.4.bn1.num_batches_tracked\", \"layer3.4.conv2.weight\", \"layer3.4.bn2.weight\", \"layer3.4.bn2.bias\", \"layer3.4.bn2.running_mean\", \"layer3.4.bn2.running_var\", \"layer3.4.bn2.num_batches_tracked\", \"layer3.4.conv3.weight\", \"layer3.4.bn3.weight\", \"layer3.4.bn3.bias\", \"layer3.4.bn3.running_mean\", \"layer3.4.bn3.running_var\", \"layer3.4.bn3.num_batches_tracked\", \"layer3.5.conv1.weight\", \"layer3.5.bn1.weight\", \"layer3.5.bn1.bias\", \"layer3.5.bn1.running_mean\", \"layer3.5.bn1.running_var\", \"layer3.5.bn1.num_batches_tracked\", \"layer3.5.conv2.weight\", \"layer3.5.bn2.weight\", \"layer3.5.bn2.bias\", \"layer3.5.bn2.running_mean\", \"layer3.5.bn2.running_var\", \"layer3.5.bn2.num_batches_tracked\", \"layer3.5.conv3.weight\", \"layer3.5.bn3.weight\", \"layer3.5.bn3.bias\", \"layer3.5.bn3.running_mean\", \"layer3.5.bn3.running_var\", \"layer3.5.bn3.num_batches_tracked\", \"layer4.0.conv1.weight\", \"layer4.0.bn1.weight\", \"layer4.0.bn1.bias\", \"layer4.0.bn1.running_mean\", \"layer4.0.bn1.running_var\", \"layer4.0.bn1.num_batches_tracked\", \"layer4.0.conv2.weight\", \"layer4.0.bn2.weight\", \"layer4.0.bn2.bias\", \"layer4.0.bn2.running_mean\", \"layer4.0.bn2.running_var\", 
\"layer4.0.bn2.num_batches_tracked\", \"layer4.0.conv3.weight\", \"layer4.0.bn3.weight\", \"layer4.0.bn3.bias\", \"layer4.0.bn3.running_mean\", \"layer4.0.bn3.running_var\", \"layer4.0.bn3.num_batches_tracked\", \"layer4.0.downsample.0.weight\", \"layer4.0.downsample.1.weight\", \"layer4.0.downsample.1.bias\", \"layer4.0.downsample.1.running_mean\", \"layer4.0.downsample.1.running_var\", \"layer4.0.downsample.1.num_batches_tracked\", \"layer4.1.conv1.weight\", \"layer4.1.bn1.weight\", \"layer4.1.bn1.bias\", \"layer4.1.bn1.running_mean\", \"layer4.1.bn1.running_var\", \"layer4.1.bn1.num_batches_tracked\", \"layer4.1.conv2.weight\", \"layer4.1.bn2.weight\", \"layer4.1.bn2.bias\", \"layer4.1.bn2.running_mean\", \"layer4.1.bn2.running_var\", \"layer4.1.bn2.num_batches_tracked\", \"layer4.1.conv3.weight\", \"layer4.1.bn3.weight\", \"layer4.1.bn3.bias\", \"layer4.1.bn3.running_mean\", \"layer4.1.bn3.running_var\", \"layer4.1.bn3.num_batches_tracked\", \"layer4.2.conv1.weight\", \"layer4.2.bn1.weight\", \"layer4.2.bn1.bias\", \"layer4.2.bn1.running_mean\", \"layer4.2.bn1.running_var\", \"layer4.2.bn1.num_batches_tracked\", \"layer4.2.conv2.weight\", \"layer4.2.bn2.weight\", \"layer4.2.bn2.bias\", \"layer4.2.bn2.running_mean\", \"layer4.2.bn2.running_var\", \"layer4.2.bn2.num_batches_tracked\", \"layer4.2.conv3.weight\", \"layer4.2.bn3.weight\", \"layer4.2.bn3.bias\", \"layer4.2.bn3.running_mean\", \"layer4.2.bn3.running_var\", \"layer4.2.bn3.num_batches_tracked\". "
  1037. ]
  1038. }
  1039. ],
  1040. "source": [
  1041. "\n",
  1042. "load_partial = False\n",
  1043. "\n",
  1044. "if load_partial:\n",
  1045. " objs = [\n",
  1046. " {\n",
  1047. " \"engine\": trainer.trainer,\n",
  1048. " **checkpoint_utils.adapter_to_dict(trainer.adapter),\n",
  1049. " },\n",
  1050. " {\"validator\": validator},\n",
  1051. " ]\n",
  1052. "else:\n",
  1053. " objs = [\n",
  1054. " {\n",
  1055. " \"engine\": trainer.trainer,\n",
  1056. " \"validator\": trainer.validator,\n",
  1057. " \"models\": trainer.adapter.models\n",
  1058. " # \"adapter\": trainer.adapter,\n",
  1059. " # **checkpoint_utils.adapter_to_dict(trainer.adapter),\n",
  1060. " }\n",
  1061. " ]\n",
  1062. "\n",
  1063. "for to_load in objs:\n",
  1064. " checkpoint_fn.load_best_checkpoint(to_load)\n",
  1065. "\n",
  1066. "# trainer.adapter.models\n"
  1067. ]
  1068. },
  1069. {
  1070. "cell_type": "code",
  1071. "execution_count": 21,
  1072. "metadata": {},
  1073. "outputs": [
  1074. {
  1075. "data": {
  1076. "text/plain": [
  1077. "{'models': G: DataParallel(\n",
  1078. " (module): ResNet(\n",
  1079. " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
  1080. " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1081. " (act1): ReLU(inplace=True)\n",
  1082. " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
  1083. " (layer1): Sequential(\n",
  1084. " (0): Bottleneck(\n",
  1085. " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1086. " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1087. " (act1): ReLU(inplace=True)\n",
  1088. " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1089. " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1090. " (act2): ReLU(inplace=True)\n",
  1091. " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1092. " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1093. " (act3): ReLU(inplace=True)\n",
  1094. " (downsample): Sequential(\n",
  1095. " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1096. " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1097. " )\n",
  1098. " )\n",
  1099. " (1): Bottleneck(\n",
  1100. " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1101. " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1102. " (act1): ReLU(inplace=True)\n",
  1103. " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1104. " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1105. " (act2): ReLU(inplace=True)\n",
  1106. " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1107. " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1108. " (act3): ReLU(inplace=True)\n",
  1109. " )\n",
  1110. " (2): Bottleneck(\n",
  1111. " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1112. " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1113. " (act1): ReLU(inplace=True)\n",
  1114. " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1115. " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1116. " (act2): ReLU(inplace=True)\n",
  1117. " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1118. " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1119. " (act3): ReLU(inplace=True)\n",
  1120. " )\n",
  1121. " )\n",
  1122. " (layer2): Sequential(\n",
  1123. " (0): Bottleneck(\n",
  1124. " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1125. " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1126. " (act1): ReLU(inplace=True)\n",
  1127. " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
  1128. " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1129. " (act2): ReLU(inplace=True)\n",
  1130. " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1131. " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1132. " (act3): ReLU(inplace=True)\n",
  1133. " (downsample): Sequential(\n",
  1134. " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
  1135. " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1136. " )\n",
  1137. " )\n",
  1138. " (1): Bottleneck(\n",
  1139. " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1140. " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1141. " (act1): ReLU(inplace=True)\n",
  1142. " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1143. " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1144. " (act2): ReLU(inplace=True)\n",
  1145. " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1146. " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1147. " (act3): ReLU(inplace=True)\n",
  1148. " )\n",
  1149. " (2): Bottleneck(\n",
  1150. " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1151. " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1152. " (act1): ReLU(inplace=True)\n",
  1153. " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1154. " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1155. " (act2): ReLU(inplace=True)\n",
  1156. " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1157. " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1158. " (act3): ReLU(inplace=True)\n",
  1159. " )\n",
  1160. " (3): Bottleneck(\n",
  1161. " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1162. " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1163. " (act1): ReLU(inplace=True)\n",
  1164. " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1165. " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1166. " (act2): ReLU(inplace=True)\n",
  1167. " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1168. " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1169. " (act3): ReLU(inplace=True)\n",
  1170. " )\n",
  1171. " )\n",
  1172. " (layer3): Sequential(\n",
  1173. " (0): Bottleneck(\n",
  1174. " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1175. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1176. " (act1): ReLU(inplace=True)\n",
  1177. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
  1178. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1179. " (act2): ReLU(inplace=True)\n",
  1180. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1181. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1182. " (act3): ReLU(inplace=True)\n",
  1183. " (downsample): Sequential(\n",
  1184. " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
  1185. " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1186. " )\n",
  1187. " )\n",
  1188. " (1): Bottleneck(\n",
  1189. " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1190. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1191. " (act1): ReLU(inplace=True)\n",
  1192. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1193. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1194. " (act2): ReLU(inplace=True)\n",
  1195. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1196. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1197. " (act3): ReLU(inplace=True)\n",
  1198. " )\n",
  1199. " (2): Bottleneck(\n",
  1200. " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1201. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1202. " (act1): ReLU(inplace=True)\n",
  1203. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1204. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1205. " (act2): ReLU(inplace=True)\n",
  1206. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1207. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1208. " (act3): ReLU(inplace=True)\n",
  1209. " )\n",
  1210. " (3): Bottleneck(\n",
  1211. " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1212. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1213. " (act1): ReLU(inplace=True)\n",
  1214. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1215. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1216. " (act2): ReLU(inplace=True)\n",
  1217. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1218. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1219. " (act3): ReLU(inplace=True)\n",
  1220. " )\n",
  1221. " (4): Bottleneck(\n",
  1222. " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1223. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1224. " (act1): ReLU(inplace=True)\n",
  1225. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1226. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1227. " (act2): ReLU(inplace=True)\n",
  1228. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1229. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1230. " (act3): ReLU(inplace=True)\n",
  1231. " )\n",
  1232. " (5): Bottleneck(\n",
  1233. " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1234. " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1235. " (act1): ReLU(inplace=True)\n",
  1236. " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1237. " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1238. " (act2): ReLU(inplace=True)\n",
  1239. " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1240. " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1241. " (act3): ReLU(inplace=True)\n",
  1242. " )\n",
  1243. " )\n",
  1244. " (layer4): Sequential(\n",
  1245. " (0): Bottleneck(\n",
  1246. " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1247. " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1248. " (act1): ReLU(inplace=True)\n",
  1249. " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
  1250. " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1251. " (act2): ReLU(inplace=True)\n",
  1252. " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1253. " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1254. " (act3): ReLU(inplace=True)\n",
  1255. " (downsample): Sequential(\n",
  1256. " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
  1257. " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1258. " )\n",
  1259. " )\n",
  1260. " (1): Bottleneck(\n",
  1261. " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1262. " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1263. " (act1): ReLU(inplace=True)\n",
  1264. " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1265. " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1266. " (act2): ReLU(inplace=True)\n",
  1267. " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1268. " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1269. " (act3): ReLU(inplace=True)\n",
  1270. " )\n",
  1271. " (2): Bottleneck(\n",
  1272. " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1273. " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1274. " (act1): ReLU(inplace=True)\n",
  1275. " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
  1276. " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1277. " (act2): ReLU(inplace=True)\n",
  1278. " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
  1279. " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  1280. " (act3): ReLU(inplace=True)\n",
  1281. " )\n",
  1282. " )\n",
  1283. " (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=True)\n",
  1284. " (fc): Identity()\n",
  1285. " )\n",
  1286. " )\n",
  1287. " C: DataParallel(\n",
  1288. " (module): Classifier(\n",
  1289. " (net): Sequential(\n",
  1290. " (0): Linear(in_features=2048, out_features=256, bias=True)\n",
  1291. " (1): ReLU()\n",
  1292. " (2): Dropout(p=0.5, inplace=False)\n",
  1293. " (3): Linear(in_features=256, out_features=128, bias=True)\n",
  1294. " (4): ReLU()\n",
  1295. " (5): Dropout(p=0.5, inplace=False)\n",
  1296. " (6): Linear(in_features=128, out_features=31, bias=True)\n",
  1297. " )\n",
  1298. " )\n",
  1299. " ),\n",
  1300. " 'optimizers': G: Adam (\n",
  1301. " Parameter Group 0\n",
  1302. " amsgrad: False\n",
  1303. " betas: (0.9, 0.999)\n",
  1304. " eps: 1e-08\n",
  1305. " initial_lr: 0.0001\n",
  1306. " lr: 0.0001\n",
  1307. " weight_decay: 0\n",
  1308. " )\n",
  1309. " C: Adam (\n",
  1310. " Parameter Group 0\n",
  1311. " amsgrad: False\n",
  1312. " betas: (0.9, 0.999)\n",
  1313. " eps: 1e-08\n",
  1314. " initial_lr: 0.0001\n",
  1315. " lr: 0.0001\n",
  1316. " weight_decay: 0\n",
  1317. " ),\n",
  1318. " 'lr_schedulers': G: <torch.optim.lr_scheduler.ExponentialLR object at 0x7f5960bcdd90>\n",
  1319. " C: <torch.optim.lr_scheduler.ExponentialLR object at 0x7f5960bcdee0>,\n",
  1320. " 'misc': }"
  1321. ]
  1322. },
  1323. "execution_count": 21,
  1324. "metadata": {},
  1325. "output_type": "execute_result"
  1326. }
  1327. ],
  1328. "source": [
  1329. "checkpoint_utils.adapter_to_dict(trainer.adapter)"
  1330. ]
  1331. },
  1332. {
  1333. "cell_type": "code",
  1334. "execution_count": 19,
  1335. "metadata": {},
  1336. "outputs": [
  1337. {
  1338. "ename": "RuntimeError",
  1339. "evalue": "Error(s) in loading state_dict for DataParallel:\n\tMissing key(s) in state_dict: \"module.conv1.weight\", \"module.bn1.weight\", \"module.bn1.bias\", \"module.bn1.running_mean\", \"module.bn1.running_var\", \"module.layer1.0.conv1.weight\", \"module.layer1.0.bn1.weight\", \"module.layer1.0.bn1.bias\", \"module.layer1.0.bn1.running_mean\", \"module.layer1.0.bn1.running_var\", \"module.layer1.0.conv2.weight\", \"module.layer1.0.bn2.weight\", \"module.layer1.0.bn2.bias\", \"module.layer1.0.bn2.running_mean\", \"module.layer1.0.bn2.running_var\", \"module.layer1.0.conv3.weight\", \"module.layer1.0.bn3.weight\", \"module.layer1.0.bn3.bias\", \"module.layer1.0.bn3.running_mean\", \"module.layer1.0.bn3.running_var\", \"module.layer1.0.downsample.0.weight\", \"module.layer1.0.downsample.1.weight\", \"module.layer1.0.downsample.1.bias\", \"module.layer1.0.downsample.1.running_mean\", \"module.layer1.0.downsample.1.running_var\", \"module.layer1.1.conv1.weight\", \"module.layer1.1.bn1.weight\", \"module.layer1.1.bn1.bias\", \"module.layer1.1.bn1.running_mean\", \"module.layer1.1.bn1.running_var\", \"module.layer1.1.conv2.weight\", \"module.layer1.1.bn2.weight\", \"module.layer1.1.bn2.bias\", \"module.layer1.1.bn2.running_mean\", \"module.layer1.1.bn2.running_var\", \"module.layer1.1.conv3.weight\", \"module.layer1.1.bn3.weight\", \"module.layer1.1.bn3.bias\", \"module.layer1.1.bn3.running_mean\", \"module.layer1.1.bn3.running_var\", \"module.layer1.2.conv1.weight\", \"module.layer1.2.bn1.weight\", \"module.layer1.2.bn1.bias\", \"module.layer1.2.bn1.running_mean\", \"module.layer1.2.bn1.running_var\", \"module.layer1.2.conv2.weight\", \"module.layer1.2.bn2.weight\", \"module.layer1.2.bn2.bias\", \"module.layer1.2.bn2.running_mean\", \"module.layer1.2.bn2.running_var\", \"module.layer1.2.conv3.weight\", \"module.layer1.2.bn3.weight\", \"module.layer1.2.bn3.bias\", \"module.layer1.2.bn3.running_mean\", \"module.layer1.2.bn3.running_var\", 
\"module.layer2.0.conv1.weight\", \"module.layer2.0.bn1.weight\", \"module.layer2.0.bn1.bias\", \"module.layer2.0.bn1.running_mean\", \"module.layer2.0.bn1.running_var\", \"module.layer2.0.conv2.weight\", \"module.layer2.0.bn2.weight\", \"module.layer2.0.bn2.bias\", \"module.layer2.0.bn2.running_mean\", \"module.layer2.0.bn2.running_var\", \"module.layer2.0.conv3.weight\", \"module.layer2.0.bn3.weight\", \"module.layer2.0.bn3.bias\", \"module.layer2.0.bn3.running_mean\", \"module.layer2.0.bn3.running_var\", \"module.layer2.0.downsample.0.weight\", \"module.layer2.0.downsample.1.weight\", \"module.layer2.0.downsample.1.bias\", \"module.layer2.0.downsample.1.running_mean\", \"module.layer2.0.downsample.1.running_var\", \"module.layer2.1.conv1.weight\", \"module.layer2.1.bn1.weight\", \"module.layer2.1.bn1.bias\", \"module.layer2.1.bn1.running_mean\", \"module.layer2.1.bn1.running_var\", \"module.layer2.1.conv2.weight\", \"module.layer2.1.bn2.weight\", \"module.layer2.1.bn2.bias\", \"module.layer2.1.bn2.running_mean\", \"module.layer2.1.bn2.running_var\", \"module.layer2.1.conv3.weight\", \"module.layer2.1.bn3.weight\", \"module.layer2.1.bn3.bias\", \"module.layer2.1.bn3.running_mean\", \"module.layer2.1.bn3.running_var\", \"module.layer2.2.conv1.weight\", \"module.layer2.2.bn1.weight\", \"module.layer2.2.bn1.bias\", \"module.layer2.2.bn1.running_mean\", \"module.layer2.2.bn1.running_var\", \"module.layer2.2.conv2.weight\", \"module.layer2.2.bn2.weight\", \"module.layer2.2.bn2.bias\", \"module.layer2.2.bn2.running_mean\", \"module.layer2.2.bn2.running_var\", \"module.layer2.2.conv3.weight\", \"module.layer2.2.bn3.weight\", \"module.layer2.2.bn3.bias\", \"module.layer2.2.bn3.running_mean\", \"module.layer2.2.bn3.running_var\", \"module.layer2.3.conv1.weight\", \"module.layer2.3.bn1.weight\", \"module.layer2.3.bn1.bias\", \"module.layer2.3.bn1.running_mean\", \"module.layer2.3.bn1.running_var\", \"module.layer2.3.conv2.weight\", \"module.layer2.3.bn2.weight\", 
\"module.layer2.3.bn2.bias\", \"module.layer2.3.bn2.running_mean\", \"module.layer2.3.bn2.running_var\", \"module.layer2.3.conv3.weight\", \"module.layer2.3.bn3.weight\", \"module.layer2.3.bn3.bias\", \"module.layer2.3.bn3.running_mean\", \"module.layer2.3.bn3.running_var\", \"module.layer3.0.conv1.weight\", \"module.layer3.0.bn1.weight\", \"module.layer3.0.bn1.bias\", \"module.layer3.0.bn1.running_mean\", \"module.layer3.0.bn1.running_var\", \"module.layer3.0.conv2.weight\", \"module.layer3.0.bn2.weight\", \"module.layer3.0.bn2.bias\", \"module.layer3.0.bn2.running_mean\", \"module.layer3.0.bn2.running_var\", \"module.layer3.0.conv3.weight\", \"module.layer3.0.bn3.weight\", \"module.layer3.0.bn3.bias\", \"module.layer3.0.bn3.running_mean\", \"module.layer3.0.bn3.running_var\", \"module.layer3.0.downsample.0.weight\", \"module.layer3.0.downsample.1.weight\", \"module.layer3.0.downsample.1.bias\", \"module.layer3.0.downsample.1.running_mean\", \"module.layer3.0.downsample.1.running_var\", \"module.layer3.1.conv1.weight\", \"module.layer3.1.bn1.weight\", \"module.layer3.1.bn1.bias\", \"module.layer3.1.bn1.running_mean\", \"module.layer3.1.bn1.running_var\", \"module.layer3.1.conv2.weight\", \"module.layer3.1.bn2.weight\", \"module.layer3.1.bn2.bias\", \"module.layer3.1.bn2.running_mean\", \"module.layer3.1.bn2.running_var\", \"module.layer3.1.conv3.weight\", \"module.layer3.1.bn3.weight\", \"module.layer3.1.bn3.bias\", \"module.layer3.1.bn3.running_mean\", \"module.layer3.1.bn3.running_var\", \"module.layer3.2.conv1.weight\", \"module.layer3.2.bn1.weight\", \"module.layer3.2.bn1.bias\", \"module.layer3.2.bn1.running_mean\", \"module.layer3.2.bn1.running_var\", \"module.layer3.2.conv2.weight\", \"module.layer3.2.bn2.weight\", \"module.layer3.2.bn2.bias\", \"module.layer3.2.bn2.running_mean\", \"module.layer3.2.bn2.running_var\", \"module.layer3.2.conv3.weight\", \"module.layer3.2.bn3.weight\", \"module.layer3.2.bn3.bias\", \"module.layer3.2.bn3.running_mean\", 
\"module.layer3.2.bn3.running_var\", \"module.layer3.3.conv1.weight\", \"module.layer3.3.bn1.weight\", \"module.layer3.3.bn1.bias\", \"module.layer3.3.bn1.running_mean\", \"module.layer3.3.bn1.running_var\", \"module.layer3.3.conv2.weight\", \"module.layer3.3.bn2.weight\", \"module.layer3.3.bn2.bias\", \"module.layer3.3.bn2.running_mean\", \"module.layer3.3.bn2.running_var\", \"module.layer3.3.conv3.weight\", \"module.layer3.3.bn3.weight\", \"module.layer3.3.bn3.bias\", \"module.layer3.3.bn3.running_mean\", \"module.layer3.3.bn3.running_var\", \"module.layer3.4.conv1.weight\", \"module.layer3.4.bn1.weight\", \"module.layer3.4.bn1.bias\", \"module.layer3.4.bn1.running_mean\", \"module.layer3.4.bn1.running_var\", \"module.layer3.4.conv2.weight\", \"module.layer3.4.bn2.weight\", \"module.layer3.4.bn2.bias\", \"module.layer3.4.bn2.running_mean\", \"module.layer3.4.bn2.running_var\", \"module.layer3.4.conv3.weight\", \"module.layer3.4.bn3.weight\", \"module.layer3.4.bn3.bias\", \"module.layer3.4.bn3.running_mean\", \"module.layer3.4.bn3.running_var\", \"module.layer3.5.conv1.weight\", \"module.layer3.5.bn1.weight\", \"module.layer3.5.bn1.bias\", \"module.layer3.5.bn1.running_mean\", \"module.layer3.5.bn1.running_var\", \"module.layer3.5.conv2.weight\", \"module.layer3.5.bn2.weight\", \"module.layer3.5.bn2.bias\", \"module.layer3.5.bn2.running_mean\", \"module.layer3.5.bn2.running_var\", \"module.layer3.5.conv3.weight\", \"module.layer3.5.bn3.weight\", \"module.layer3.5.bn3.bias\", \"module.layer3.5.bn3.running_mean\", \"module.layer3.5.bn3.running_var\", \"module.layer4.0.conv1.weight\", \"module.layer4.0.bn1.weight\", \"module.layer4.0.bn1.bias\", \"module.layer4.0.bn1.running_mean\", \"module.layer4.0.bn1.running_var\", \"module.layer4.0.conv2.weight\", \"module.layer4.0.bn2.weight\", \"module.layer4.0.bn2.bias\", \"module.layer4.0.bn2.running_mean\", \"module.layer4.0.bn2.running_var\", \"module.layer4.0.conv3.weight\", \"module.layer4.0.bn3.weight\", 
\"module.layer4.0.bn3.bias\", \"module.layer4.0.bn3.running_mean\", \"module.layer4.0.bn3.running_var\", \"module.layer4.0.downsample.0.weight\", \"module.layer4.0.downsample.1.weight\", \"module.layer4.0.downsample.1.bias\", \"module.layer4.0.downsample.1.running_mean\", \"module.layer4.0.downsample.1.running_var\", \"module.layer4.1.conv1.weight\", \"module.layer4.1.bn1.weight\", \"module.layer4.1.bn1.bias\", \"module.layer4.1.bn1.running_mean\", \"module.layer4.1.bn1.running_var\", \"module.layer4.1.conv2.weight\", \"module.layer4.1.bn2.weight\", \"module.layer4.1.bn2.bias\", \"module.layer4.1.bn2.running_mean\", \"module.layer4.1.bn2.running_var\", \"module.layer4.1.conv3.weight\", \"module.layer4.1.bn3.weight\", \"module.layer4.1.bn3.bias\", \"module.layer4.1.bn3.running_mean\", \"module.layer4.1.bn3.running_var\", \"module.layer4.2.conv1.weight\", \"module.layer4.2.bn1.weight\", \"module.layer4.2.bn1.bias\", \"module.layer4.2.bn1.running_mean\", \"module.layer4.2.bn1.running_var\", \"module.layer4.2.conv2.weight\", \"module.layer4.2.bn2.weight\", \"module.layer4.2.bn2.bias\", \"module.layer4.2.bn2.running_mean\", \"module.layer4.2.bn2.running_var\", \"module.layer4.2.conv3.weight\", \"module.layer4.2.bn3.weight\", \"module.layer4.2.bn3.bias\", \"module.layer4.2.bn3.running_mean\", \"module.layer4.2.bn3.running_var\". 
\n\tUnexpected key(s) in state_dict: \"conv1.weight\", \"bn1.weight\", \"bn1.bias\", \"bn1.running_mean\", \"bn1.running_var\", \"bn1.num_batches_tracked\", \"layer1.0.conv1.weight\", \"layer1.0.bn1.weight\", \"layer1.0.bn1.bias\", \"layer1.0.bn1.running_mean\", \"layer1.0.bn1.running_var\", \"layer1.0.bn1.num_batches_tracked\", \"layer1.0.conv2.weight\", \"layer1.0.bn2.weight\", \"layer1.0.bn2.bias\", \"layer1.0.bn2.running_mean\", \"layer1.0.bn2.running_var\", \"layer1.0.bn2.num_batches_tracked\", \"layer1.0.conv3.weight\", \"layer1.0.bn3.weight\", \"layer1.0.bn3.bias\", \"layer1.0.bn3.running_mean\", \"layer1.0.bn3.running_var\", \"layer1.0.bn3.num_batches_tracked\", \"layer1.0.downsample.0.weight\", \"layer1.0.downsample.1.weight\", \"layer1.0.downsample.1.bias\", \"layer1.0.downsample.1.running_mean\", \"layer1.0.downsample.1.running_var\", \"layer1.0.downsample.1.num_batches_tracked\", \"layer1.1.conv1.weight\", \"layer1.1.bn1.weight\", \"layer1.1.bn1.bias\", \"layer1.1.bn1.running_mean\", \"layer1.1.bn1.running_var\", \"layer1.1.bn1.num_batches_tracked\", \"layer1.1.conv2.weight\", \"layer1.1.bn2.weight\", \"layer1.1.bn2.bias\", \"layer1.1.bn2.running_mean\", \"layer1.1.bn2.running_var\", \"layer1.1.bn2.num_batches_tracked\", \"layer1.1.conv3.weight\", \"layer1.1.bn3.weight\", \"layer1.1.bn3.bias\", \"layer1.1.bn3.running_mean\", \"layer1.1.bn3.running_var\", \"layer1.1.bn3.num_batches_tracked\", \"layer1.2.conv1.weight\", \"layer1.2.bn1.weight\", \"layer1.2.bn1.bias\", \"layer1.2.bn1.running_mean\", \"layer1.2.bn1.running_var\", \"layer1.2.bn1.num_batches_tracked\", \"layer1.2.conv2.weight\", \"layer1.2.bn2.weight\", \"layer1.2.bn2.bias\", \"layer1.2.bn2.running_mean\", \"layer1.2.bn2.running_var\", \"layer1.2.bn2.num_batches_tracked\", \"layer1.2.conv3.weight\", \"layer1.2.bn3.weight\", \"layer1.2.bn3.bias\", \"layer1.2.bn3.running_mean\", \"layer1.2.bn3.running_var\", \"layer1.2.bn3.num_batches_tracked\", \"layer2.0.conv1.weight\", 
\"layer2.0.bn1.weight\", \"layer2.0.bn1.bias\", \"layer2.0.bn1.running_mean\", \"layer2.0.bn1.running_var\", \"layer2.0.bn1.num_batches_tracked\", \"layer2.0.conv2.weight\", \"layer2.0.bn2.weight\", \"layer2.0.bn2.bias\", \"layer2.0.bn2.running_mean\", \"layer2.0.bn2.running_var\", \"layer2.0.bn2.num_batches_tracked\", \"layer2.0.conv3.weight\", \"layer2.0.bn3.weight\", \"layer2.0.bn3.bias\", \"layer2.0.bn3.running_mean\", \"layer2.0.bn3.running_var\", \"layer2.0.bn3.num_batches_tracked\", \"layer2.0.downsample.0.weight\", \"layer2.0.downsample.1.weight\", \"layer2.0.downsample.1.bias\", \"layer2.0.downsample.1.running_mean\", \"layer2.0.downsample.1.running_var\", \"layer2.0.downsample.1.num_batches_tracked\", \"layer2.1.conv1.weight\", \"layer2.1.bn1.weight\", \"layer2.1.bn1.bias\", \"layer2.1.bn1.running_mean\", \"layer2.1.bn1.running_var\", \"layer2.1.bn1.num_batches_tracked\", \"layer2.1.conv2.weight\", \"layer2.1.bn2.weight\", \"layer2.1.bn2.bias\", \"layer2.1.bn2.running_mean\", \"layer2.1.bn2.running_var\", \"layer2.1.bn2.num_batches_tracked\", \"layer2.1.conv3.weight\", \"layer2.1.bn3.weight\", \"layer2.1.bn3.bias\", \"layer2.1.bn3.running_mean\", \"layer2.1.bn3.running_var\", \"layer2.1.bn3.num_batches_tracked\", \"layer2.2.conv1.weight\", \"layer2.2.bn1.weight\", \"layer2.2.bn1.bias\", \"layer2.2.bn1.running_mean\", \"layer2.2.bn1.running_var\", \"layer2.2.bn1.num_batches_tracked\", \"layer2.2.conv2.weight\", \"layer2.2.bn2.weight\", \"layer2.2.bn2.bias\", \"layer2.2.bn2.running_mean\", \"layer2.2.bn2.running_var\", \"layer2.2.bn2.num_batches_tracked\", \"layer2.2.conv3.weight\", \"layer2.2.bn3.weight\", \"layer2.2.bn3.bias\", \"layer2.2.bn3.running_mean\", \"layer2.2.bn3.running_var\", \"layer2.2.bn3.num_batches_tracked\", \"layer2.3.conv1.weight\", \"layer2.3.bn1.weight\", \"layer2.3.bn1.bias\", \"layer2.3.bn1.running_mean\", \"layer2.3.bn1.running_var\", \"layer2.3.bn1.num_batches_tracked\", \"layer2.3.conv2.weight\", \"layer2.3.bn2.weight\", 
\"layer2.3.bn2.bias\", \"layer2.3.bn2.running_mean\", \"layer2.3.bn2.running_var\", \"layer2.3.bn2.num_batches_tracked\", \"layer2.3.conv3.weight\", \"layer2.3.bn3.weight\", \"layer2.3.bn3.bias\", \"layer2.3.bn3.running_mean\", \"layer2.3.bn3.running_var\", \"layer2.3.bn3.num_batches_tracked\", \"layer3.0.conv1.weight\", \"layer3.0.bn1.weight\", \"layer3.0.bn1.bias\", \"layer3.0.bn1.running_mean\", \"layer3.0.bn1.running_var\", \"layer3.0.bn1.num_batches_tracked\", \"layer3.0.conv2.weight\", \"layer3.0.bn2.weight\", \"layer3.0.bn2.bias\", \"layer3.0.bn2.running_mean\", \"layer3.0.bn2.running_var\", \"layer3.0.bn2.num_batches_tracked\", \"layer3.0.conv3.weight\", \"layer3.0.bn3.weight\", \"layer3.0.bn3.bias\", \"layer3.0.bn3.running_mean\", \"layer3.0.bn3.running_var\", \"layer3.0.bn3.num_batches_tracked\", \"layer3.0.downsample.0.weight\", \"layer3.0.downsample.1.weight\", \"layer3.0.downsample.1.bias\", \"layer3.0.downsample.1.running_mean\", \"layer3.0.downsample.1.running_var\", \"layer3.0.downsample.1.num_batches_tracked\", \"layer3.1.conv1.weight\", \"layer3.1.bn1.weight\", \"layer3.1.bn1.bias\", \"layer3.1.bn1.running_mean\", \"layer3.1.bn1.running_var\", \"layer3.1.bn1.num_batches_tracked\", \"layer3.1.conv2.weight\", \"layer3.1.bn2.weight\", \"layer3.1.bn2.bias\", \"layer3.1.bn2.running_mean\", \"layer3.1.bn2.running_var\", \"layer3.1.bn2.num_batches_tracked\", \"layer3.1.conv3.weight\", \"layer3.1.bn3.weight\", \"layer3.1.bn3.bias\", \"layer3.1.bn3.running_mean\", \"layer3.1.bn3.running_var\", \"layer3.1.bn3.num_batches_tracked\", \"layer3.2.conv1.weight\", \"layer3.2.bn1.weight\", \"layer3.2.bn1.bias\", \"layer3.2.bn1.running_mean\", \"layer3.2.bn1.running_var\", \"layer3.2.bn1.num_batches_tracked\", \"layer3.2.conv2.weight\", \"layer3.2.bn2.weight\", \"layer3.2.bn2.bias\", \"layer3.2.bn2.running_mean\", \"layer3.2.bn2.running_var\", \"layer3.2.bn2.num_batches_tracked\", \"layer3.2.conv3.weight\", \"layer3.2.bn3.weight\", \"layer3.2.bn3.bias\", 
\"layer3.2.bn3.running_mean\", \"layer3.2.bn3.running_var\", \"layer3.2.bn3.num_batches_tracked\", \"layer3.3.conv1.weight\", \"layer3.3.bn1.weight\", \"layer3.3.bn1.bias\", \"layer3.3.bn1.running_mean\", \"layer3.3.bn1.running_var\", \"layer3.3.bn1.num_batches_tracked\", \"layer3.3.conv2.weight\", \"layer3.3.bn2.weight\", \"layer3.3.bn2.bias\", \"layer3.3.bn2.running_mean\", \"layer3.3.bn2.running_var\", \"layer3.3.bn2.num_batches_tracked\", \"layer3.3.conv3.weight\", \"layer3.3.bn3.weight\", \"layer3.3.bn3.bias\", \"layer3.3.bn3.running_mean\", \"layer3.3.bn3.running_var\", \"layer3.3.bn3.num_batches_tracked\", \"layer3.4.conv1.weight\", \"layer3.4.bn1.weight\", \"layer3.4.bn1.bias\", \"layer3.4.bn1.running_mean\", \"layer3.4.bn1.running_var\", \"layer3.4.bn1.num_batches_tracked\", \"layer3.4.conv2.weight\", \"layer3.4.bn2.weight\", \"layer3.4.bn2.bias\", \"layer3.4.bn2.running_mean\", \"layer3.4.bn2.running_var\", \"layer3.4.bn2.num_batches_tracked\", \"layer3.4.conv3.weight\", \"layer3.4.bn3.weight\", \"layer3.4.bn3.bias\", \"layer3.4.bn3.running_mean\", \"layer3.4.bn3.running_var\", \"layer3.4.bn3.num_batches_tracked\", \"layer3.5.conv1.weight\", \"layer3.5.bn1.weight\", \"layer3.5.bn1.bias\", \"layer3.5.bn1.running_mean\", \"layer3.5.bn1.running_var\", \"layer3.5.bn1.num_batches_tracked\", \"layer3.5.conv2.weight\", \"layer3.5.bn2.weight\", \"layer3.5.bn2.bias\", \"layer3.5.bn2.running_mean\", \"layer3.5.bn2.running_var\", \"layer3.5.bn2.num_batches_tracked\", \"layer3.5.conv3.weight\", \"layer3.5.bn3.weight\", \"layer3.5.bn3.bias\", \"layer3.5.bn3.running_mean\", \"layer3.5.bn3.running_var\", \"layer3.5.bn3.num_batches_tracked\", \"layer4.0.conv1.weight\", \"layer4.0.bn1.weight\", \"layer4.0.bn1.bias\", \"layer4.0.bn1.running_mean\", \"layer4.0.bn1.running_var\", \"layer4.0.bn1.num_batches_tracked\", \"layer4.0.conv2.weight\", \"layer4.0.bn2.weight\", \"layer4.0.bn2.bias\", \"layer4.0.bn2.running_mean\", \"layer4.0.bn2.running_var\", 
\"layer4.0.bn2.num_batches_tracked\", \"layer4.0.conv3.weight\", \"layer4.0.bn3.weight\", \"layer4.0.bn3.bias\", \"layer4.0.bn3.running_mean\", \"layer4.0.bn3.running_var\", \"layer4.0.bn3.num_batches_tracked\", \"layer4.0.downsample.0.weight\", \"layer4.0.downsample.1.weight\", \"layer4.0.downsample.1.bias\", \"layer4.0.downsample.1.running_mean\", \"layer4.0.downsample.1.running_var\", \"layer4.0.downsample.1.num_batches_tracked\", \"layer4.1.conv1.weight\", \"layer4.1.bn1.weight\", \"layer4.1.bn1.bias\", \"layer4.1.bn1.running_mean\", \"layer4.1.bn1.running_var\", \"layer4.1.bn1.num_batches_tracked\", \"layer4.1.conv2.weight\", \"layer4.1.bn2.weight\", \"layer4.1.bn2.bias\", \"layer4.1.bn2.running_mean\", \"layer4.1.bn2.running_var\", \"layer4.1.bn2.num_batches_tracked\", \"layer4.1.conv3.weight\", \"layer4.1.bn3.weight\", \"layer4.1.bn3.bias\", \"layer4.1.bn3.running_mean\", \"layer4.1.bn3.running_var\", \"layer4.1.bn3.num_batches_tracked\", \"layer4.2.conv1.weight\", \"layer4.2.bn1.weight\", \"layer4.2.bn1.bias\", \"layer4.2.bn1.running_mean\", \"layer4.2.bn1.running_var\", \"layer4.2.bn1.num_batches_tracked\", \"layer4.2.conv2.weight\", \"layer4.2.bn2.weight\", \"layer4.2.bn2.bias\", \"layer4.2.bn2.running_mean\", \"layer4.2.bn2.running_var\", \"layer4.2.bn2.num_batches_tracked\", \"layer4.2.conv3.weight\", \"layer4.2.bn3.weight\", \"layer4.2.bn3.bias\", \"layer4.2.bn3.running_mean\", \"layer4.2.bn3.running_var\", \"layer4.2.bn3.num_batches_tracked\". ",
  1340. "output_type": "error",
  1341. "traceback": [
  1342. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1343. "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
  1344. "Cell \u001b[0;32mIn[19], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m validator \u001b[39m=\u001b[39m AccuracyValidator(key_map\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mtarget_val_with_labels\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msrc_val\u001b[39m\u001b[39m\"\u001b[39m})\n\u001b[0;32m----> 2\u001b[0m score \u001b[39m=\u001b[39m trainer\u001b[39m.\u001b[39;49mevaluate_best_model(datasets, validator, dc)\n\u001b[1;32m 3\u001b[0m \u001b[39mprint\u001b[39m(score)\n",
  1345. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/ignite.py:323\u001b[0m, in \u001b[0;36mIgnite.evaluate_best_model\u001b[0;34m(self, datasets, validator, dataloader_creator)\u001b[0m\n\u001b[1;32m 321\u001b[0m dataloader_creator \u001b[39m=\u001b[39m c_f\u001b[39m.\u001b[39mdefault(dataloader_creator, DataloaderCreator, {})\n\u001b[1;32m 322\u001b[0m dataloaders \u001b[39m=\u001b[39m dataloader_creator(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mdatasets)\n\u001b[0;32m--> 323\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcheckpoint_fn\u001b[39m.\u001b[39;49mload_best_checkpoint({\u001b[39m\"\u001b[39;49m\u001b[39mmodels\u001b[39;49m\u001b[39m\"\u001b[39;49m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madapter\u001b[39m.\u001b[39;49mmodels})\n\u001b[1;32m 324\u001b[0m collected_data \u001b[39m=\u001b[39m i_g\u001b[39m.\u001b[39mcollect_from_dataloaders(\n\u001b[1;32m 325\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcollector, dataloaders, validator\u001b[39m.\u001b[39mrequired_data\n\u001b[1;32m 326\u001b[0m )\n\u001b[1;32m 327\u001b[0m \u001b[39mreturn\u001b[39;00m val_utils\u001b[39m.\u001b[39mcall_val_hook(validator, collected_data)\n",
  1346. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/checkpoint_utils.py:97\u001b[0m, in \u001b[0;36mCheckpointFnCreator.load_best_checkpoint\u001b[0;34m(self, to_load)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mload_best_checkpoint\u001b[39m(\u001b[39mself\u001b[39m, to_load):\n\u001b[1;32m 96\u001b[0m last_checkpoint \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_best_checkpoint()\n\u001b[0;32m---> 97\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mload_objects(to_load, last_checkpoint)\n",
  1347. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/frameworks/ignite/checkpoint_utils.py:93\u001b[0m, in \u001b[0;36mCheckpointFnCreator.load_objects\u001b[0;34m(self, to_load, checkpoint, global_step)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobjs\u001b[39m.\u001b[39mreload_objects(\n\u001b[1;32m 90\u001b[0m to_load, name\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mcheckpoint\u001b[39m\u001b[39m\"\u001b[39m, global_step\u001b[39m=\u001b[39mglobal_step\n\u001b[1;32m 91\u001b[0m )\n\u001b[1;32m 92\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 93\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mobjs\u001b[39m.\u001b[39;49mload_objects(to_load, \u001b[39mstr\u001b[39;49m(checkpoint))\n",
  1348. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/handlers/checkpoint.py:618\u001b[0m, in \u001b[0;36mCheckpoint.load_objects\u001b[0;34m(to_load, checkpoint, **kwargs)\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[39mif\u001b[39;00m k \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m checkpoint_obj:\n\u001b[1;32m 617\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mObject labeled by \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m from `to_load` is not found in the checkpoint\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 618\u001b[0m _load_object(obj, checkpoint_obj[k])\n",
  1349. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/ignite/handlers/checkpoint.py:605\u001b[0m, in \u001b[0;36mCheckpoint.load_objects.<locals>._load_object\u001b[0;34m(obj, chkpt_obj)\u001b[0m\n\u001b[1;32m 603\u001b[0m obj\u001b[39m.\u001b[39mload_state_dict(chkpt_obj, strict\u001b[39m=\u001b[39mis_state_dict_strict)\n\u001b[1;32m 604\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 605\u001b[0m obj\u001b[39m.\u001b[39;49mload_state_dict(chkpt_obj)\n",
  1350. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/pytorch_adapt/containers/base_container.py:158\u001b[0m, in \u001b[0;36mBaseContainer.load_state_dict\u001b[0;34m(self, state_dict)\u001b[0m\n\u001b[1;32m 156\u001b[0m c_f\u001b[39m.\u001b[39massert_state_dict_keys(state_dict, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkeys())\n\u001b[1;32m 157\u001b[0m \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m state_dict\u001b[39m.\u001b[39mitems():\n\u001b[0;32m--> 158\u001b[0m \u001b[39mself\u001b[39;49m[k]\u001b[39m.\u001b[39;49mload_state_dict(v)\n",
  1351. "File \u001b[0;32m/media/10TB71/shashemi/miniconda3/envs/cdtrans/lib/python3.8/site-packages/torch/nn/modules/module.py:1223\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m 1218\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 1219\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1220\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(k) \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 1222\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 1223\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1224\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 1225\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n",
  1352. "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for DataParallel:\n\tMissing key(s) in state_dict: \"module.conv1.weight\", \"module.bn1.weight\", \"module.bn1.bias\", \"module.bn1.running_mean\", \"module.bn1.running_var\", \"module.layer1.0.conv1.weight\", \"module.layer1.0.bn1.weight\", \"module.layer1.0.bn1.bias\", \"module.layer1.0.bn1.running_mean\", \"module.layer1.0.bn1.running_var\", \"module.layer1.0.conv2.weight\", \"module.layer1.0.bn2.weight\", \"module.layer1.0.bn2.bias\", \"module.layer1.0.bn2.running_mean\", \"module.layer1.0.bn2.running_var\", \"module.layer1.0.conv3.weight\", \"module.layer1.0.bn3.weight\", \"module.layer1.0.bn3.bias\", \"module.layer1.0.bn3.running_mean\", \"module.layer1.0.bn3.running_var\", \"module.layer1.0.downsample.0.weight\", \"module.layer1.0.downsample.1.weight\", \"module.layer1.0.downsample.1.bias\", \"module.layer1.0.downsample.1.running_mean\", \"module.layer1.0.downsample.1.running_var\", \"module.layer1.1.conv1.weight\", \"module.layer1.1.bn1.weight\", \"module.layer1.1.bn1.bias\", \"module.layer1.1.bn1.running_mean\", \"module.layer1.1.bn1.running_var\", \"module.layer1.1.conv2.weight\", \"module.layer1.1.bn2.weight\", \"module.layer1.1.bn2.bias\", \"module.layer1.1.bn2.running_mean\", \"module.layer1.1.bn2.running_var\", \"module.layer1.1.conv3.weight\", \"module.layer1.1.bn3.weight\", \"module.layer1.1.bn3.bias\", \"module.layer1.1.bn3.running_mean\", \"module.layer1.1.bn3.running_var\", \"module.layer1.2.conv1.weight\", \"module.layer1.2.bn1.weight\", \"module.layer1.2.bn1.bias\", \"module.layer1.2.bn1.running_mean\", \"module.layer1.2.bn1.running_var\", \"module.layer1.2.conv2.weight\", \"module.layer1.2.bn2.weight\", \"module.layer1.2.bn2.bias\", \"module.layer1.2.bn2.running_mean\", \"module.layer1.2.bn2.running_var\", \"module.layer1.2.conv3.weight\", \"module.layer1.2.bn3.weight\", \"module.layer1.2.bn3.bias\", \"module.layer1.2.bn3.running_mean\", 
\"module.layer1.2.bn3.running_var\", \"module.layer2.0.conv1.weight\", \"module.layer2.0.bn1.weight\", \"module.layer2.0.bn1.bias\", \"module.layer2.0.bn1.running_mean\", \"module.layer2.0.bn1.running_var\", \"module.layer2.0.conv2.weight\", \"module.layer2.0.bn2.weight\", \"module.layer2.0.bn2.bias\", \"module.layer2.0.bn2.running_mean\", \"module.layer2.0.bn2.running_var\", \"module.layer2.0.conv3.weight\", \"module.layer2.0.bn3.weight\", \"module.layer2.0.bn3.bias\", \"module.layer2.0.bn3.running_mean\", \"module.layer2.0.bn3.running_var\", \"module.layer2.0.downsample.0.weight\", \"module.layer2.0.downsample.1.weight\", \"module.layer2.0.downsample.1.bias\", \"module.layer2.0.downsample.1.running_mean\", \"module.layer2.0.downsample.1.running_var\", \"module.layer2.1.conv1.weight\", \"module.layer2.1.bn1.weight\", \"module.layer2.1.bn1.bias\", \"module.layer2.1.bn1.running_mean\", \"module.layer2.1.bn1.running_var\", \"module.layer2.1.conv2.weight\", \"module.layer2.1.bn2.weight\", \"module.layer2.1.bn2.bias\", \"module.layer2.1.bn2.running_mean\", \"module.layer2.1.bn2.running_var\", \"module.layer2.1.conv3.weight\", \"module.layer2.1.bn3.weight\", \"module.layer2.1.bn3.bias\", \"module.layer2.1.bn3.running_mean\", \"module.layer2.1.bn3.running_var\", \"module.layer2.2.conv1.weight\", \"module.layer2.2.bn1.weight\", \"module.layer2.2.bn1.bias\", \"module.layer2.2.bn1.running_mean\", \"module.layer2.2.bn1.running_var\", \"module.layer2.2.conv2.weight\", \"module.layer2.2.bn2.weight\", \"module.layer2.2.bn2.bias\", \"module.layer2.2.bn2.running_mean\", \"module.layer2.2.bn2.running_var\", \"module.layer2.2.conv3.weight\", \"module.layer2.2.bn3.weight\", \"module.layer2.2.bn3.bias\", \"module.layer2.2.bn3.running_mean\", \"module.layer2.2.bn3.running_var\", \"module.layer2.3.conv1.weight\", \"module.layer2.3.bn1.weight\", \"module.layer2.3.bn1.bias\", \"module.layer2.3.bn1.running_mean\", \"module.layer2.3.bn1.running_var\", \"module.layer2.3.conv2.weight\", 
\"module.layer2.3.bn2.weight\", \"module.layer2.3.bn2.bias\", \"module.layer2.3.bn2.running_mean\", \"module.layer2.3.bn2.running_var\", \"module.layer2.3.conv3.weight\", \"module.layer2.3.bn3.weight\", \"module.layer2.3.bn3.bias\", \"module.layer2.3.bn3.running_mean\", \"module.layer2.3.bn3.running_var\", \"module.layer3.0.conv1.weight\", \"module.layer3.0.bn1.weight\", \"module.layer3.0.bn1.bias\", \"module.layer3.0.bn1.running_mean\", \"module.layer3.0.bn1.running_var\", \"module.layer3.0.conv2.weight\", \"module.layer3.0.bn2.weight\", \"module.layer3.0.bn2.bias\", \"module.layer3.0.bn2.running_mean\", \"module.layer3.0.bn2.running_var\", \"module.layer3.0.conv3.weight\", \"module.layer3.0.bn3.weight\", \"module.layer3.0.bn3.bias\", \"module.layer3.0.bn3.running_mean\", \"module.layer3.0.bn3.running_var\", \"module.layer3.0.downsample.0.weight\", \"module.layer3.0.downsample.1.weight\", \"module.layer3.0.downsample.1.bias\", \"module.layer3.0.downsample.1.running_mean\", \"module.layer3.0.downsample.1.running_var\", \"module.layer3.1.conv1.weight\", \"module.layer3.1.bn1.weight\", \"module.layer3.1.bn1.bias\", \"module.layer3.1.bn1.running_mean\", \"module.layer3.1.bn1.running_var\", \"module.layer3.1.conv2.weight\", \"module.layer3.1.bn2.weight\", \"module.layer3.1.bn2.bias\", \"module.layer3.1.bn2.running_mean\", \"module.layer3.1.bn2.running_var\", \"module.layer3.1.conv3.weight\", \"module.layer3.1.bn3.weight\", \"module.layer3.1.bn3.bias\", \"module.layer3.1.bn3.running_mean\", \"module.layer3.1.bn3.running_var\", \"module.layer3.2.conv1.weight\", \"module.layer3.2.bn1.weight\", \"module.layer3.2.bn1.bias\", \"module.layer3.2.bn1.running_mean\", \"module.layer3.2.bn1.running_var\", \"module.layer3.2.conv2.weight\", \"module.layer3.2.bn2.weight\", \"module.layer3.2.bn2.bias\", \"module.layer3.2.bn2.running_mean\", \"module.layer3.2.bn2.running_var\", \"module.layer3.2.conv3.weight\", \"module.layer3.2.bn3.weight\", \"module.layer3.2.bn3.bias\", 
\"module.layer3.2.bn3.running_mean\", \"module.layer3.2.bn3.running_var\", \"module.layer3.3.conv1.weight\", \"module.layer3.3.bn1.weight\", \"module.layer3.3.bn1.bias\", \"module.layer3.3.bn1.running_mean\", \"module.layer3.3.bn1.running_var\", \"module.layer3.3.conv2.weight\", \"module.layer3.3.bn2.weight\", \"module.layer3.3.bn2.bias\", \"module.layer3.3.bn2.running_mean\", \"module.layer3.3.bn2.running_var\", \"module.layer3.3.conv3.weight\", \"module.layer3.3.bn3.weight\", \"module.layer3.3.bn3.bias\", \"module.layer3.3.bn3.running_mean\", \"module.layer3.3.bn3.running_var\", \"module.layer3.4.conv1.weight\", \"module.layer3.4.bn1.weight\", \"module.layer3.4.bn1.bias\", \"module.layer3.4.bn1.running_mean\", \"module.layer3.4.bn1.running_var\", \"module.layer3.4.conv2.weight\", \"module.layer3.4.bn2.weight\", \"module.layer3.4.bn2.bias\", \"module.layer3.4.bn2.running_mean\", \"module.layer3.4.bn2.running_var\", \"module.layer3.4.conv3.weight\", \"module.layer3.4.bn3.weight\", \"module.layer3.4.bn3.bias\", \"module.layer3.4.bn3.running_mean\", \"module.layer3.4.bn3.running_var\", \"module.layer3.5.conv1.weight\", \"module.layer3.5.bn1.weight\", \"module.layer3.5.bn1.bias\", \"module.layer3.5.bn1.running_mean\", \"module.layer3.5.bn1.running_var\", \"module.layer3.5.conv2.weight\", \"module.layer3.5.bn2.weight\", \"module.layer3.5.bn2.bias\", \"module.layer3.5.bn2.running_mean\", \"module.layer3.5.bn2.running_var\", \"module.layer3.5.conv3.weight\", \"module.layer3.5.bn3.weight\", \"module.layer3.5.bn3.bias\", \"module.layer3.5.bn3.running_mean\", \"module.layer3.5.bn3.running_var\", \"module.layer4.0.conv1.weight\", \"module.layer4.0.bn1.weight\", \"module.layer4.0.bn1.bias\", \"module.layer4.0.bn1.running_mean\", \"module.layer4.0.bn1.running_var\", \"module.layer4.0.conv2.weight\", \"module.layer4.0.bn2.weight\", \"module.layer4.0.bn2.bias\", \"module.layer4.0.bn2.running_mean\", \"module.layer4.0.bn2.running_var\", \"module.layer4.0.conv3.weight\", 
\"module.layer4.0.bn3.weight\", \"module.layer4.0.bn3.bias\", \"module.layer4.0.bn3.running_mean\", \"module.layer4.0.bn3.running_var\", \"module.layer4.0.downsample.0.weight\", \"module.layer4.0.downsample.1.weight\", \"module.layer4.0.downsample.1.bias\", \"module.layer4.0.downsample.1.running_mean\", \"module.layer4.0.downsample.1.running_var\", \"module.layer4.1.conv1.weight\", \"module.layer4.1.bn1.weight\", \"module.layer4.1.bn1.bias\", \"module.layer4.1.bn1.running_mean\", \"module.layer4.1.bn1.running_var\", \"module.layer4.1.conv2.weight\", \"module.layer4.1.bn2.weight\", \"module.layer4.1.bn2.bias\", \"module.layer4.1.bn2.running_mean\", \"module.layer4.1.bn2.running_var\", \"module.layer4.1.conv3.weight\", \"module.layer4.1.bn3.weight\", \"module.layer4.1.bn3.bias\", \"module.layer4.1.bn3.running_mean\", \"module.layer4.1.bn3.running_var\", \"module.layer4.2.conv1.weight\", \"module.layer4.2.bn1.weight\", \"module.layer4.2.bn1.bias\", \"module.layer4.2.bn1.running_mean\", \"module.layer4.2.bn1.running_var\", \"module.layer4.2.conv2.weight\", \"module.layer4.2.bn2.weight\", \"module.layer4.2.bn2.bias\", \"module.layer4.2.bn2.running_mean\", \"module.layer4.2.bn2.running_var\", \"module.layer4.2.conv3.weight\", \"module.layer4.2.bn3.weight\", \"module.layer4.2.bn3.bias\", \"module.layer4.2.bn3.running_mean\", \"module.layer4.2.bn3.running_var\". 
\n\tUnexpected key(s) in state_dict: \"conv1.weight\", \"bn1.weight\", \"bn1.bias\", \"bn1.running_mean\", \"bn1.running_var\", \"bn1.num_batches_tracked\", \"layer1.0.conv1.weight\", \"layer1.0.bn1.weight\", \"layer1.0.bn1.bias\", \"layer1.0.bn1.running_mean\", \"layer1.0.bn1.running_var\", \"layer1.0.bn1.num_batches_tracked\", \"layer1.0.conv2.weight\", \"layer1.0.bn2.weight\", \"layer1.0.bn2.bias\", \"layer1.0.bn2.running_mean\", \"layer1.0.bn2.running_var\", \"layer1.0.bn2.num_batches_tracked\", \"layer1.0.conv3.weight\", \"layer1.0.bn3.weight\", \"layer1.0.bn3.bias\", \"layer1.0.bn3.running_mean\", \"layer1.0.bn3.running_var\", \"layer1.0.bn3.num_batches_tracked\", \"layer1.0.downsample.0.weight\", \"layer1.0.downsample.1.weight\", \"layer1.0.downsample.1.bias\", \"layer1.0.downsample.1.running_mean\", \"layer1.0.downsample.1.running_var\", \"layer1.0.downsample.1.num_batches_tracked\", \"layer1.1.conv1.weight\", \"layer1.1.bn1.weight\", \"layer1.1.bn1.bias\", \"layer1.1.bn1.running_mean\", \"layer1.1.bn1.running_var\", \"layer1.1.bn1.num_batches_tracked\", \"layer1.1.conv2.weight\", \"layer1.1.bn2.weight\", \"layer1.1.bn2.bias\", \"layer1.1.bn2.running_mean\", \"layer1.1.bn2.running_var\", \"layer1.1.bn2.num_batches_tracked\", \"layer1.1.conv3.weight\", \"layer1.1.bn3.weight\", \"layer1.1.bn3.bias\", \"layer1.1.bn3.running_mean\", \"layer1.1.bn3.running_var\", \"layer1.1.bn3.num_batches_tracked\", \"layer1.2.conv1.weight\", \"layer1.2.bn1.weight\", \"layer1.2.bn1.bias\", \"layer1.2.bn1.running_mean\", \"layer1.2.bn1.running_var\", \"layer1.2.bn1.num_batches_tracked\", \"layer1.2.conv2.weight\", \"layer1.2.bn2.weight\", \"layer1.2.bn2.bias\", \"layer1.2.bn2.running_mean\", \"layer1.2.bn2.running_var\", \"layer1.2.bn2.num_batches_tracked\", \"layer1.2.conv3.weight\", \"layer1.2.bn3.weight\", \"layer1.2.bn3.bias\", \"layer1.2.bn3.running_mean\", \"layer1.2.bn3.running_var\", \"layer1.2.bn3.num_batches_tracked\", \"layer2.0.conv1.weight\", 
\"layer2.0.bn1.weight\", \"layer2.0.bn1.bias\", \"layer2.0.bn1.running_mean\", \"layer2.0.bn1.running_var\", \"layer2.0.bn1.num_batches_tracked\", \"layer2.0.conv2.weight\", \"layer2.0.bn2.weight\", \"layer2.0.bn2.bias\", \"layer2.0.bn2.running_mean\", \"layer2.0.bn2.running_var\", \"layer2.0.bn2.num_batches_tracked\", \"layer2.0.conv3.weight\", \"layer2.0.bn3.weight\", \"layer2.0.bn3.bias\", \"layer2.0.bn3.running_mean\", \"layer2.0.bn3.running_var\", \"layer2.0.bn3.num_batches_tracked\", \"layer2.0.downsample.0.weight\", \"layer2.0.downsample.1.weight\", \"layer2.0.downsample.1.bias\", \"layer2.0.downsample.1.running_mean\", \"layer2.0.downsample.1.running_var\", \"layer2.0.downsample.1.num_batches_tracked\", \"layer2.1.conv1.weight\", \"layer2.1.bn1.weight\", \"layer2.1.bn1.bias\", \"layer2.1.bn1.running_mean\", \"layer2.1.bn1.running_var\", \"layer2.1.bn1.num_batches_tracked\", \"layer2.1.conv2.weight\", \"layer2.1.bn2.weight\", \"layer2.1.bn2.bias\", \"layer2.1.bn2.running_mean\", \"layer2.1.bn2.running_var\", \"layer2.1.bn2.num_batches_tracked\", \"layer2.1.conv3.weight\", \"layer2.1.bn3.weight\", \"layer2.1.bn3.bias\", \"layer2.1.bn3.running_mean\", \"layer2.1.bn3.running_var\", \"layer2.1.bn3.num_batches_tracked\", \"layer2.2.conv1.weight\", \"layer2.2.bn1.weight\", \"layer2.2.bn1.bias\", \"layer2.2.bn1.running_mean\", \"layer2.2.bn1.running_var\", \"layer2.2.bn1.num_batches_tracked\", \"layer2.2.conv2.weight\", \"layer2.2.bn2.weight\", \"layer2.2.bn2.bias\", \"layer2.2.bn2.running_mean\", \"layer2.2.bn2.running_var\", \"layer2.2.bn2.num_batches_tracked\", \"layer2.2.conv3.weight\", \"layer2.2.bn3.weight\", \"layer2.2.bn3.bias\", \"layer2.2.bn3.running_mean\", \"layer2.2.bn3.running_var\", \"layer2.2.bn3.num_batches_tracked\", \"layer2.3.conv1.weight\", \"layer2.3.bn1.weight\", \"layer2.3.bn1.bias\", \"layer2.3.bn1.running_mean\", \"layer2.3.bn1.running_var\", \"layer2.3.bn1.num_batches_tracked\", \"layer2.3.conv2.weight\", \"layer2.3.bn2.weight\", 
\"layer2.3.bn2.bias\", \"layer2.3.bn2.running_mean\", \"layer2.3.bn2.running_var\", \"layer2.3.bn2.num_batches_tracked\", \"layer2.3.conv3.weight\", \"layer2.3.bn3.weight\", \"layer2.3.bn3.bias\", \"layer2.3.bn3.running_mean\", \"layer2.3.bn3.running_var\", \"layer2.3.bn3.num_batches_tracked\", \"layer3.0.conv1.weight\", \"layer3.0.bn1.weight\", \"layer3.0.bn1.bias\", \"layer3.0.bn1.running_mean\", \"layer3.0.bn1.running_var\", \"layer3.0.bn1.num_batches_tracked\", \"layer3.0.conv2.weight\", \"layer3.0.bn2.weight\", \"layer3.0.bn2.bias\", \"layer3.0.bn2.running_mean\", \"layer3.0.bn2.running_var\", \"layer3.0.bn2.num_batches_tracked\", \"layer3.0.conv3.weight\", \"layer3.0.bn3.weight\", \"layer3.0.bn3.bias\", \"layer3.0.bn3.running_mean\", \"layer3.0.bn3.running_var\", \"layer3.0.bn3.num_batches_tracked\", \"layer3.0.downsample.0.weight\", \"layer3.0.downsample.1.weight\", \"layer3.0.downsample.1.bias\", \"layer3.0.downsample.1.running_mean\", \"layer3.0.downsample.1.running_var\", \"layer3.0.downsample.1.num_batches_tracked\", \"layer3.1.conv1.weight\", \"layer3.1.bn1.weight\", \"layer3.1.bn1.bias\", \"layer3.1.bn1.running_mean\", \"layer3.1.bn1.running_var\", \"layer3.1.bn1.num_batches_tracked\", \"layer3.1.conv2.weight\", \"layer3.1.bn2.weight\", \"layer3.1.bn2.bias\", \"layer3.1.bn2.running_mean\", \"layer3.1.bn2.running_var\", \"layer3.1.bn2.num_batches_tracked\", \"layer3.1.conv3.weight\", \"layer3.1.bn3.weight\", \"layer3.1.bn3.bias\", \"layer3.1.bn3.running_mean\", \"layer3.1.bn3.running_var\", \"layer3.1.bn3.num_batches_tracked\", \"layer3.2.conv1.weight\", \"layer3.2.bn1.weight\", \"layer3.2.bn1.bias\", \"layer3.2.bn1.running_mean\", \"layer3.2.bn1.running_var\", \"layer3.2.bn1.num_batches_tracked\", \"layer3.2.conv2.weight\", \"layer3.2.bn2.weight\", \"layer3.2.bn2.bias\", \"layer3.2.bn2.running_mean\", \"layer3.2.bn2.running_var\", \"layer3.2.bn2.num_batches_tracked\", \"layer3.2.conv3.weight\", \"layer3.2.bn3.weight\", \"layer3.2.bn3.bias\", 
\"layer3.2.bn3.running_mean\", \"layer3.2.bn3.running_var\", \"layer3.2.bn3.num_batches_tracked\", \"layer3.3.conv1.weight\", \"layer3.3.bn1.weight\", \"layer3.3.bn1.bias\", \"layer3.3.bn1.running_mean\", \"layer3.3.bn1.running_var\", \"layer3.3.bn1.num_batches_tracked\", \"layer3.3.conv2.weight\", \"layer3.3.bn2.weight\", \"layer3.3.bn2.bias\", \"layer3.3.bn2.running_mean\", \"layer3.3.bn2.running_var\", \"layer3.3.bn2.num_batches_tracked\", \"layer3.3.conv3.weight\", \"layer3.3.bn3.weight\", \"layer3.3.bn3.bias\", \"layer3.3.bn3.running_mean\", \"layer3.3.bn3.running_var\", \"layer3.3.bn3.num_batches_tracked\", \"layer3.4.conv1.weight\", \"layer3.4.bn1.weight\", \"layer3.4.bn1.bias\", \"layer3.4.bn1.running_mean\", \"layer3.4.bn1.running_var\", \"layer3.4.bn1.num_batches_tracked\", \"layer3.4.conv2.weight\", \"layer3.4.bn2.weight\", \"layer3.4.bn2.bias\", \"layer3.4.bn2.running_mean\", \"layer3.4.bn2.running_var\", \"layer3.4.bn2.num_batches_tracked\", \"layer3.4.conv3.weight\", \"layer3.4.bn3.weight\", \"layer3.4.bn3.bias\", \"layer3.4.bn3.running_mean\", \"layer3.4.bn3.running_var\", \"layer3.4.bn3.num_batches_tracked\", \"layer3.5.conv1.weight\", \"layer3.5.bn1.weight\", \"layer3.5.bn1.bias\", \"layer3.5.bn1.running_mean\", \"layer3.5.bn1.running_var\", \"layer3.5.bn1.num_batches_tracked\", \"layer3.5.conv2.weight\", \"layer3.5.bn2.weight\", \"layer3.5.bn2.bias\", \"layer3.5.bn2.running_mean\", \"layer3.5.bn2.running_var\", \"layer3.5.bn2.num_batches_tracked\", \"layer3.5.conv3.weight\", \"layer3.5.bn3.weight\", \"layer3.5.bn3.bias\", \"layer3.5.bn3.running_mean\", \"layer3.5.bn3.running_var\", \"layer3.5.bn3.num_batches_tracked\", \"layer4.0.conv1.weight\", \"layer4.0.bn1.weight\", \"layer4.0.bn1.bias\", \"layer4.0.bn1.running_mean\", \"layer4.0.bn1.running_var\", \"layer4.0.bn1.num_batches_tracked\", \"layer4.0.conv2.weight\", \"layer4.0.bn2.weight\", \"layer4.0.bn2.bias\", \"layer4.0.bn2.running_mean\", \"layer4.0.bn2.running_var\", 
\"layer4.0.bn2.num_batches_tracked\", \"layer4.0.conv3.weight\", \"layer4.0.bn3.weight\", \"layer4.0.bn3.bias\", \"layer4.0.bn3.running_mean\", \"layer4.0.bn3.running_var\", \"layer4.0.bn3.num_batches_tracked\", \"layer4.0.downsample.0.weight\", \"layer4.0.downsample.1.weight\", \"layer4.0.downsample.1.bias\", \"layer4.0.downsample.1.running_mean\", \"layer4.0.downsample.1.running_var\", \"layer4.0.downsample.1.num_batches_tracked\", \"layer4.1.conv1.weight\", \"layer4.1.bn1.weight\", \"layer4.1.bn1.bias\", \"layer4.1.bn1.running_mean\", \"layer4.1.bn1.running_var\", \"layer4.1.bn1.num_batches_tracked\", \"layer4.1.conv2.weight\", \"layer4.1.bn2.weight\", \"layer4.1.bn2.bias\", \"layer4.1.bn2.running_mean\", \"layer4.1.bn2.running_var\", \"layer4.1.bn2.num_batches_tracked\", \"layer4.1.conv3.weight\", \"layer4.1.bn3.weight\", \"layer4.1.bn3.bias\", \"layer4.1.bn3.running_mean\", \"layer4.1.bn3.running_var\", \"layer4.1.bn3.num_batches_tracked\", \"layer4.2.conv1.weight\", \"layer4.2.bn1.weight\", \"layer4.2.bn1.bias\", \"layer4.2.bn1.running_mean\", \"layer4.2.bn1.running_var\", \"layer4.2.bn1.num_batches_tracked\", \"layer4.2.conv2.weight\", \"layer4.2.bn2.weight\", \"layer4.2.bn2.bias\", \"layer4.2.bn2.running_mean\", \"layer4.2.bn2.running_var\", \"layer4.2.bn2.num_batches_tracked\", \"layer4.2.conv3.weight\", \"layer4.2.bn3.weight\", \"layer4.2.bn3.bias\", \"layer4.2.bn3.running_mean\", \"layer4.2.bn3.running_var\", \"layer4.2.bn3.num_batches_tracked\". "
  1353. ]
  1354. }
  1355. ],
  1356. "source": [
  1357. "validator = AccuracyValidator(key_map={\"target_val_with_labels\": \"src_val\"})\n",
  1358. "score = trainer.evaluate_best_model(datasets, validator, dc)\n",
  1359. "print(score)"
  1360. ]
  1361. },
  1362. {
  1363. "cell_type": "code",
  1364. "execution_count": 32,
  1365. "metadata": {},
  1366. "outputs": [
  1367. {
  1368. "ename": "TypeError",
  1369. "evalue": "__call__() missing 2 required positional arguments: 'engine' and 'to_save'",
  1370. "output_type": "error",
  1371. "traceback": [
  1372. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1373. "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
  1374. "Cell \u001b[0;32mIn[32], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m checkpoint_fn\u001b[39m.\u001b[39;49mobjs()\n",
  1375. "\u001b[0;31mTypeError\u001b[0m: __call__() missing 2 required positional arguments: 'engine' and 'to_save'"
  1376. ]
  1377. }
  1378. ],
  1379. "source": []
  1380. },
  1381. {
  1382. "cell_type": "code",
  1383. "execution_count": null,
  1384. "metadata": {},
  1385. "outputs": [],
  1386. "source": []
  1387. }
  1388. ],
  1389. "metadata": {
  1390. "kernelspec": {
  1391. "display_name": "cdtrans",
  1392. "language": "python",
  1393. "name": "python3"
  1394. },
  1395. "language_info": {
  1396. "codemirror_mode": {
  1397. "name": "ipython",
  1398. "version": 3
  1399. },
  1400. "file_extension": ".py",
  1401. "mimetype": "text/x-python",
  1402. "name": "python",
  1403. "nbconvert_exporter": "python",
  1404. "pygments_lexer": "ipython3",
  1405. "version": "3.8.15"
  1406. },
  1407. "orig_nbformat": 4,
  1408. "vscode": {
  1409. "interpreter": {
  1410. "hash": "959b82c3a41427bdf7d14d4ba7335271e0c50cfcddd70501934b27dcc36968ad"
  1411. }
  1412. }
  1413. },
  1414. "nbformat": 4,
  1415. "nbformat_minor": 2
  1416. }