
00_bert_ah.ipynb

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {
  6. "pycharm": {
  7. "name": "#%% md\n"
  8. },
  9. "tags": []
  10. },
  11. "source": [
  12. "# Intro"
  13. ]
  14. },
  15. {
  16. "cell_type": "code",
  17. "execution_count": 1,
  18. "metadata": {
  19. "tags": []
  20. },
  21. "outputs": [],
  22. "source": [
  23. "from abc import abstractmethod, ABC\n",
  24. "from os import PathLike\n",
  25. "from typing import Dict, Union, Optional, Iterable\n",
  26. "\n",
  27. "\n",
  28. "class base_peft(ABC):\n",
  29. " def __init__(self, base_model_name: Union[str, PathLike[str]], mask_token_id: int):\n",
  30. " self.base_model_name = base_model_name\n",
  31. " self.mask_token_id = mask_token_id\n",
  32. "\n",
  33. " def activate_task_for_training\n",
  34. "\n",
  35. " @abstractmethod\n",
  36. " def finetune_task(self, peft_name: str, train_dataset, validation_dataset):\n",
  37. " pass"
  38. ]
  39. },
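{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch, not part of the original pipeline: `DummyPeft` below is a hypothetical subclass that only illustrates how the abstract `base_peft` interface is meant to be filled in (the implementation actually used later is `BertAdapterModelWrapper` from `_models`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: DummyPeft is hypothetical and only shows the expected shape\n",
"# of a concrete base_peft subclass.\n",
"class DummyPeft(base_peft):\n",
"    def activate_task_for_training(self, peft_name: str):  # assumed signature\n",
"        print(f'activating {peft_name} for training')\n",
"\n",
"    def finetune_task(self, peft_name: str, train_dataset, validation_dataset):\n",
"        print(f'fine-tuning {peft_name} on {len(train_dataset)} train examples')\n",
"\n",
"\n",
"peft = DummyPeft('bert-base-uncased', mask_token_id=103)  # 103 is BERT's [MASK] id\n",
"peft.finetune_task('glue:cola', train_dataset=[], validation_dataset=[])"
]
},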
  40. {
  41. "cell_type": "code",
  42. "execution_count": 1,
  43. "metadata": {
  44. "ExecuteTime": {
  45. "end_time": "2023-08-15T13:16:40.910406Z",
  46. "start_time": "2023-08-15T13:16:40.860981Z"
  47. },
  48. "tags": []
  49. },
  50. "outputs": [
  51. {
  52. "name": "stdout",
  53. "output_type": "stream",
  54. "text": [
  55. "/home/mohalisad/Developer/ProgressivePrompts\n"
  56. ]
  57. }
  58. ],
  59. "source": [
  60. "cd /home/mohalisad/Developer/ProgressivePrompts"
  61. ]
  62. },
  63. {
  64. "cell_type": "code",
  65. "execution_count": 2,
  66. "metadata": {
  67. "ExecuteTime": {
  68. "end_time": "2023-08-15T13:16:42.467311Z",
  69. "start_time": "2023-08-15T13:16:42.313951Z"
  70. },
  71. "pycharm": {
  72. "is_executing": true,
  73. "name": "#%%\n"
  74. },
  75. "scrolled": true,
  76. "tags": []
  77. },
  78. "outputs": [
  79. {
  80. "name": "stdout",
  81. "output_type": "stream",
  82. "text": [
  83. "Python version is: 3.9.17\n",
  84. "Torch version is: 1.13.1+cu117\n",
  85. "Nvidia device is: NVIDIA GeForce RTX 4090\n",
  86. "Transformers version is: 4.26.1\n",
  87. "Adapterhub version is: 3.2.1\n"
  88. ]
  89. }
  90. ],
  91. "source": [
  92. "from utils import print_system_info\n",
  93. "print_system_info()"
  94. ]
  95. },
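{
"cell_type": "markdown",
"metadata": {},
"source": [
"`utils.print_system_info` lives outside this notebook; the cell below is a hedged sketch of an equivalent check (the helper name `show_environment` and the adapter-version lookup are assumptions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of an environment report similar to utils.print_system_info.\n",
"import sys\n",
"\n",
"import torch\n",
"import transformers\n",
"\n",
"\n",
"def show_environment():  # hypothetical helper, kept separate from the project's utils\n",
"    print(f'Python version is: {sys.version.split()[0]}')\n",
"    print(f'Torch version is: {torch.__version__}')\n",
"    if torch.cuda.is_available():\n",
"        print(f'Nvidia device is: {torch.cuda.get_device_name(0)}')\n",
"    print(f'Transformers version is: {transformers.__version__}')\n",
"    try:\n",
"        from transformers import adapters  # present when adapter-transformers is installed\n",
"        print(f\"Adapterhub version is: {getattr(adapters, '__version__', 'unknown')}\")\n",
"    except ImportError:\n",
"        pass\n",
"\n",
"\n",
"show_environment()"
]
},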
  96. {
  97. "cell_type": "markdown",
  98. "metadata": {},
  99. "source": [
  100. "# Dataset"
  101. ]
  102. },
  103. {
  104. "cell_type": "code",
  105. "execution_count": 31,
  106. "metadata": {
  107. "tags": []
  108. },
  109. "outputs": [],
  110. "source": [
  111. "from _datasets import AutoLoad\n",
  112. "from config import load_config\n",
  113. "from _models import BertAdapterModelWrapper, TokenizerMan\n",
  114. "\n",
  115. "\n",
  116. "config = load_config('config.yaml')"
  117. ]
  118. },
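{
"cell_type": "markdown",
"metadata": {},
"source": [
"`config.yaml` itself is not shown in this notebook. The later cells only rely on attribute-style access to a few fields; the sketch below (field names taken from those accesses, values assumed) shows the minimal shape involved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of the config shape used below; values are placeholders, and the real\n",
"# config additionally exposes hf_trainer_params.to_dict() for the Trainer arguments.\n",
"from types import SimpleNamespace\n",
"\n",
"example_config = SimpleNamespace(\n",
"    base_model=SimpleNamespace(kind='bert', name='bert-base-uncased', mask_token_id=103),\n",
"    tasks=['glue:cola'],\n",
")\n",
"print(example_config.base_model.name, example_config.tasks)"
]
},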
  119. {
  120. "cell_type": "code",
  121. "execution_count": 39,
  122. "metadata": {
  123. "tags": []
  124. },
  125. "outputs": [
  126. {
  127. "name": "stderr",
  128. "output_type": "stream",
  129. "text": [
  130. "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
  131. "Model config BertConfig {\n",
  132. " \"architectures\": [\n",
  133. " \"BertForMaskedLM\"\n",
  134. " ],\n",
  135. " \"attention_probs_dropout_prob\": 0.1,\n",
  136. " \"classifier_dropout\": null,\n",
  137. " \"gradient_checkpointing\": false,\n",
  138. " \"hidden_act\": \"gelu\",\n",
  139. " \"hidden_dropout_prob\": 0.1,\n",
  140. " \"hidden_size\": 768,\n",
  141. " \"initializer_range\": 0.02,\n",
  142. " \"intermediate_size\": 3072,\n",
  143. " \"layer_norm_eps\": 1e-12,\n",
  144. " \"max_position_embeddings\": 512,\n",
  145. " \"model_type\": \"bert\",\n",
  146. " \"num_attention_heads\": 12,\n",
  147. " \"num_hidden_layers\": 12,\n",
  148. " \"pad_token_id\": 0,\n",
  149. " \"position_embedding_type\": \"absolute\",\n",
  150. " \"transformers_version\": \"4.26.1\",\n",
  151. " \"type_vocab_size\": 2,\n",
  152. " \"use_cache\": true,\n",
  153. " \"vocab_size\": 30522\n",
  154. "}\n",
  155. "\n",
  156. "loading weights file model.safetensors from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/model.safetensors\n",
  157. "Generate config GenerationConfig {\n",
  158. " \"pad_token_id\": 0,\n",
  159. " \"transformers_version\": \"4.26.1\"\n",
  160. "}\n",
  161. "\n",
  162. "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertAdapterModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
  163. "- This IS expected if you are initializing BertAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
  164. "- This IS NOT expected if you are initializing BertAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
  165. "All the weights of BertAdapterModel were initialized from the model checkpoint at bert-base-uncased.\n",
  166. "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertAdapterModel for predictions without further training.\n",
  167. "Generation config file not found, using a generation config created from the model config.\n",
  168. "loading file vocab.txt from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/vocab.txt\n",
  169. "loading file tokenizer.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/tokenizer.json\n",
  170. "loading file added_tokens.json from cache at None\n",
  171. "loading file special_tokens_map.json from cache at None\n",
  172. "loading file tokenizer_config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/tokenizer_config.json\n",
  173. "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
  174. "Model config BertConfig {\n",
  175. " \"_name_or_path\": \"bert-base-uncased\",\n",
  176. " \"architectures\": [\n",
  177. " \"BertForMaskedLM\"\n",
  178. " ],\n",
  179. " \"attention_probs_dropout_prob\": 0.1,\n",
  180. " \"classifier_dropout\": null,\n",
  181. " \"gradient_checkpointing\": false,\n",
  182. " \"hidden_act\": \"gelu\",\n",
  183. " \"hidden_dropout_prob\": 0.1,\n",
  184. " \"hidden_size\": 768,\n",
  185. " \"initializer_range\": 0.02,\n",
  186. " \"intermediate_size\": 3072,\n",
  187. " \"layer_norm_eps\": 1e-12,\n",
  188. " \"max_position_embeddings\": 512,\n",
  189. " \"model_type\": \"bert\",\n",
  190. " \"num_attention_heads\": 12,\n",
  191. " \"num_hidden_layers\": 12,\n",
  192. " \"pad_token_id\": 0,\n",
  193. " \"position_embedding_type\": \"absolute\",\n",
  194. " \"transformers_version\": \"4.26.1\",\n",
  195. " \"type_vocab_size\": 2,\n",
  196. " \"use_cache\": true,\n",
  197. " \"vocab_size\": 30522\n",
  198. "}\n",
  199. "\n"
  200. ]
  201. }
  202. ],
  203. "source": [
  204. "# import transformers\n",
  205. "# transformers.logging.set_verbosity_debug()\n",
  206. "adapter_wrapper = BertAdapterModelWrapper(\n",
  207. " base_model_name=config.base_model.name,\n",
  208. " mask_token_id=config.base_model.mask_token_id\n",
  209. ")\n",
  210. "tokenizer_man = TokenizerMan(config.base_model.kind, config.base_model.name)"
  211. ]
  212. },
  213. {
  214. "cell_type": "code",
  215. "execution_count": 40,
  216. "metadata": {
  217. "tags": []
  218. },
  219. "outputs": [],
  220. "source": [
  221. "auto_loader = AutoLoad()"
  222. ]
  223. },
  224. {
  225. "cell_type": "code",
  226. "execution_count": 41,
  227. "metadata": {
  228. "tags": []
  229. },
  230. "outputs": [
  231. {
  232. "data": {
  233. "application/vnd.jupyter.widget-view+json": {
  234. "model_id": "f983a58646a54aa6841312408f00f491",
  235. "version_major": 2,
  236. "version_minor": 0
  237. },
  238. "text/plain": [
  239. "Map: 0%| | 0/8551 [00:00<?, ? examples/s]"
  240. ]
  241. },
  242. "metadata": {},
  243. "output_type": "display_data"
  244. },
  245. {
  246. "data": {
  247. "application/vnd.jupyter.widget-view+json": {
  248. "model_id": "99ea0309b4384a0ab7a458710ae2e443",
  249. "version_major": 2,
  250. "version_minor": 0
  251. },
  252. "text/plain": [
  253. "Map: 0%| | 0/1043 [00:00<?, ? examples/s]"
  254. ]
  255. },
  256. "metadata": {},
  257. "output_type": "display_data"
  258. },
  259. {
  260. "data": {
  261. "application/vnd.jupyter.widget-view+json": {
  262. "model_id": "d041fd8948044b5e8b0f761079a04894",
  263. "version_major": 2,
  264. "version_minor": 0
  265. },
  266. "text/plain": [
  267. "Map: 0%| | 0/1063 [00:00<?, ? examples/s]"
  268. ]
  269. },
  270. "metadata": {},
  271. "output_type": "display_data"
  272. },
  273. {
  274. "name": "stderr",
  275. "output_type": "stream",
  276. "text": [
  277. "Adding adapter 'glue:cola'.\n",
  278. "Adding head 'glue:cola' with config {'head_type': 'classification', 'num_labels': 2, 'layers': 2, 'activation_function': 'tanh', 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'use_pooler': False, 'bias': True}.\n",
  279. "PyTorch: setting up devices\n",
  280. "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
  281. "/home/mohalisad/anaconda3/envs/lll/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
  282. " warnings.warn(\n",
  283. "***** Running training *****\n",
  284. " Num examples = 8551\n",
  285. " Num Epochs = 15\n",
  286. " Instantaneous batch size per device = 32\n",
  287. " Total train batch size (w. parallel, distributed & accumulation) = 32\n",
  288. " Gradient Accumulation steps = 1\n",
  289. " Total optimization steps = 4020\n",
  290. " Number of trainable parameters = 1486658\n",
  291. "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
  292. ]
  293. },
  294. {
  295. "data": {
  296. "text/html": [
  297. "\n",
  298. " <div>\n",
  299. " \n",
  300. " <progress value='4020' max='4020' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
  301. " [4020/4020 01:08, Epoch 15/15]\n",
  302. " </div>\n",
  303. " <table border=\"1\" class=\"dataframe\">\n",
  304. " <thead>\n",
  305. " <tr style=\"text-align: left;\">\n",
  306. " <th>Epoch</th>\n",
  307. " <th>Training Loss</th>\n",
  308. " <th>Validation Loss</th>\n",
  309. " <th>Accuracy</th>\n",
  310. " <th>F1-score-1</th>\n",
  311. " <th>F1-score-ma</th>\n",
  312. " </tr>\n",
  313. " </thead>\n",
  314. " <tbody>\n",
  315. " <tr>\n",
  316. " <td>1</td>\n",
  317. " <td>No log</td>\n",
  318. " <td>0.521243</td>\n",
  319. " <td>0.772771</td>\n",
  320. " <td>0.854512</td>\n",
  321. " <td>0.667956</td>\n",
  322. " </tr>\n",
  323. " <tr>\n",
  324. " <td>2</td>\n",
  325. " <td>0.484900</td>\n",
  326. " <td>0.475989</td>\n",
  327. " <td>0.795781</td>\n",
  328. " <td>0.866290</td>\n",
  329. " <td>0.717121</td>\n",
  330. " </tr>\n",
  331. " <tr>\n",
  332. " <td>3</td>\n",
  333. " <td>0.484900</td>\n",
  334. " <td>0.473902</td>\n",
  335. " <td>0.799616</td>\n",
  336. " <td>0.868471</td>\n",
  337. " <td>0.723974</td>\n",
  338. " </tr>\n",
  339. " <tr>\n",
  340. " <td>4</td>\n",
  341. " <td>0.390000</td>\n",
  342. " <td>0.454408</td>\n",
  343. " <td>0.815916</td>\n",
  344. " <td>0.877707</td>\n",
  345. " <td>0.752807</td>\n",
  346. " </tr>\n",
  347. " <tr>\n",
  348. " <td>5</td>\n",
  349. " <td>0.390000</td>\n",
  350. " <td>0.460564</td>\n",
  351. " <td>0.822627</td>\n",
  352. " <td>0.880414</td>\n",
  353. " <td>0.768593</td>\n",
  354. " </tr>\n",
  355. " <tr>\n",
  356. " <td>6</td>\n",
  357. " <td>0.330900</td>\n",
  358. " <td>0.421414</td>\n",
  359. " <td>0.831256</td>\n",
  360. " <td>0.883752</td>\n",
  361. " <td>0.788030</td>\n",
  362. " </tr>\n",
  363. " <tr>\n",
  364. " <td>7</td>\n",
  365. " <td>0.330900</td>\n",
  366. " <td>0.452820</td>\n",
  367. " <td>0.833174</td>\n",
  368. " <td>0.885375</td>\n",
  369. " <td>0.789519</td>\n",
  370. " </tr>\n",
  371. " <tr>\n",
  372. " <td>8</td>\n",
  373. " <td>0.292000</td>\n",
  374. " <td>0.465746</td>\n",
  375. " <td>0.826462</td>\n",
  376. " <td>0.881777</td>\n",
  377. " <td>0.777825</td>\n",
  378. " </tr>\n",
  379. " <tr>\n",
  380. " <td>9</td>\n",
  381. " <td>0.292000</td>\n",
  382. " <td>0.491992</td>\n",
  383. " <td>0.832215</td>\n",
  384. " <td>0.885396</td>\n",
  385. " <td>0.786169</td>\n",
  386. " </tr>\n",
  387. " <tr>\n",
  388. " <td>10</td>\n",
  389. " <td>0.255500</td>\n",
  390. " <td>0.508437</td>\n",
  391. " <td>0.827421</td>\n",
  392. " <td>0.883117</td>\n",
  393. " <td>0.776723</td>\n",
  394. " </tr>\n",
  395. " <tr>\n",
  396. " <td>11</td>\n",
  397. " <td>0.255500</td>\n",
  398. " <td>0.519635</td>\n",
  399. " <td>0.837009</td>\n",
  400. " <td>0.888889</td>\n",
  401. " <td>0.791567</td>\n",
  402. " </tr>\n",
  403. " <tr>\n",
  404. " <td>12</td>\n",
  405. " <td>0.232300</td>\n",
  406. " <td>0.522434</td>\n",
  407. " <td>0.828380</td>\n",
  408. " <td>0.883388</td>\n",
  409. " <td>0.779262</td>\n",
  410. " </tr>\n",
  411. " <tr>\n",
  412. " <td>13</td>\n",
  413. " <td>0.232300</td>\n",
  414. " <td>0.532363</td>\n",
  415. " <td>0.835091</td>\n",
  416. " <td>0.886991</td>\n",
  417. " <td>0.791013</td>\n",
  418. " </tr>\n",
  419. " <tr>\n",
  420. " <td>14</td>\n",
  421. " <td>0.219900</td>\n",
  422. " <td>0.557935</td>\n",
  423. " <td>0.831256</td>\n",
  424. " <td>0.885566</td>\n",
  425. " <td>0.782199</td>\n",
  426. " </tr>\n",
  427. " <tr>\n",
  428. " <td>15</td>\n",
  429. " <td>0.202800</td>\n",
  430. " <td>0.547973</td>\n",
  431. " <td>0.832215</td>\n",
  432. " <td>0.885845</td>\n",
  433. " <td>0.784695</td>\n",
  434. " </tr>\n",
  435. " </tbody>\n",
  436. "</table><p>"
  437. ],
  438. "text/plain": [
  439. "<IPython.core.display.HTML object>"
  440. ]
  441. },
  442. "metadata": {},
  443. "output_type": "display_data"
  444. },
  445. {
  446. "name": "stderr",
  447. "output_type": "stream",
  448. "text": [
  449. "***** Running Evaluation *****\n",
  450. " Num examples = 1043\n",
  451. " Batch size = 32\n",
  452. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268\n",
  453. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/adapter_config.json\n",
  454. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_adapter.bin\n",
  455. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/head_config.json\n",
  456. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_model_head.bin\n",
  457. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/head_config.json\n",
  458. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_model_head.bin\n",
  459. "***** Running Evaluation *****\n",
  460. " Num examples = 1043\n",
  461. " Batch size = 32\n",
  462. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536\n",
  463. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/adapter_config.json\n",
  464. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_adapter.bin\n",
  465. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/head_config.json\n",
  466. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_model_head.bin\n",
  467. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/head_config.json\n",
  468. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_model_head.bin\n",
  469. "***** Running Evaluation *****\n",
  470. " Num examples = 1043\n",
  471. " Batch size = 32\n",
  472. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804\n",
  473. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/adapter_config.json\n",
  474. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_adapter.bin\n",
  475. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/head_config.json\n",
  476. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_model_head.bin\n",
  477. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/head_config.json\n",
  478. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_model_head.bin\n",
  479. "***** Running Evaluation *****\n",
  480. " Num examples = 1043\n",
  481. " Batch size = 32\n",
  482. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072\n",
  483. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/adapter_config.json\n",
  484. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_adapter.bin\n",
  485. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/head_config.json\n",
  486. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_model_head.bin\n",
  487. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/head_config.json\n",
  488. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_model_head.bin\n",
  489. "***** Running Evaluation *****\n",
  490. " Num examples = 1043\n",
  491. " Batch size = 32\n",
  492. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340\n",
  493. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/adapter_config.json\n",
  494. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_adapter.bin\n",
  495. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/head_config.json\n",
  496. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_model_head.bin\n",
  497. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/head_config.json\n",
  498. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_model_head.bin\n",
  499. "***** Running Evaluation *****\n",
  500. " Num examples = 1043\n",
  501. " Batch size = 32\n",
  502. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608\n",
  503. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/adapter_config.json\n",
  504. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_adapter.bin\n",
  505. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/head_config.json\n",
  506. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_model_head.bin\n",
  507. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/head_config.json\n",
  508. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_model_head.bin\n",
  509. "***** Running Evaluation *****\n",
  510. " Num examples = 1043\n",
  511. " Batch size = 32\n",
  512. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876\n",
  513. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/adapter_config.json\n",
  514. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_adapter.bin\n",
  515. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/head_config.json\n",
  516. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_model_head.bin\n",
  517. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/head_config.json\n",
  518. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_model_head.bin\n",
  519. "***** Running Evaluation *****\n",
  520. " Num examples = 1043\n",
  521. " Batch size = 32\n",
  522. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144\n",
  523. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/adapter_config.json\n",
  524. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_adapter.bin\n",
  525. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/head_config.json\n",
  526. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_model_head.bin\n",
  527. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/head_config.json\n",
  528. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_model_head.bin\n",
  529. "***** Running Evaluation *****\n",
  530. " Num examples = 1043\n",
  531. " Batch size = 32\n",
  532. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412\n",
  533. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/adapter_config.json\n",
  534. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_adapter.bin\n",
  535. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/head_config.json\n",
  536. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_model_head.bin\n",
  537. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/head_config.json\n",
  538. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_model_head.bin\n",
  539. "***** Running Evaluation *****\n",
  540. " Num examples = 1043\n",
  541. " Batch size = 32\n",
  542. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680\n",
  543. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/adapter_config.json\n",
  544. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_adapter.bin\n",
  545. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/head_config.json\n",
  546. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_model_head.bin\n",
  547. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/head_config.json\n",
  548. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_model_head.bin\n",
  549. "***** Running Evaluation *****\n",
  550. " Num examples = 1043\n",
  551. " Batch size = 32\n",
  552. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948\n",
  553. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/adapter_config.json\n",
  554. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_adapter.bin\n",
  555. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/head_config.json\n",
  556. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_model_head.bin\n",
  557. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/head_config.json\n",
  558. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_model_head.bin\n",
  559. "***** Running Evaluation *****\n",
  560. " Num examples = 1043\n",
  561. " Batch size = 32\n",
  562. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216\n",
  563. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/adapter_config.json\n",
  564. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_adapter.bin\n",
  565. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/head_config.json\n",
  566. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_model_head.bin\n",
  567. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/head_config.json\n",
  568. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_model_head.bin\n",
  569. "***** Running Evaluation *****\n",
  570. " Num examples = 1043\n",
  571. " Batch size = 32\n",
  572. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484\n",
  573. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/adapter_config.json\n",
  574. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_adapter.bin\n",
  575. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/head_config.json\n",
  576. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_model_head.bin\n",
  577. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/head_config.json\n",
  578. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_model_head.bin\n",
  579. "***** Running Evaluation *****\n",
  580. " Num examples = 1043\n",
  581. " Batch size = 32\n",
  582. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752\n",
  583. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/adapter_config.json\n",
  584. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_adapter.bin\n",
  585. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/head_config.json\n",
  586. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_model_head.bin\n",
  587. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/head_config.json\n",
  588. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_model_head.bin\n",
  589. "***** Running Evaluation *****\n",
  590. " Num examples = 1043\n",
  591. " Batch size = 32\n",
  592. "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020\n",
  593. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/adapter_config.json\n",
  594. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_adapter.bin\n",
  595. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/head_config.json\n",
  596. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_model_head.bin\n",
  597. "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/head_config.json\n",
  598. "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_model_head.bin\n",
  599. "\n",
  600. "\n",
  601. "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
  602. "\n",
  603. "\n"
  604. ]
  605. }
  606. ],
  607. "source": [
  608. "for task_name in config.tasks:\n",
  609. " loader_out = auto_loader.get_and_map(tokenizer_man.tokenizer, task_name)\n",
  610. " num_labels = len(loader_out['output']['range'])\n",
  611. " adapter_wrapper.add_classification_adapter(task_name, num_labels=num_labels)\n",
  612. " adapter_wrapper.finetune_adapter(\n",
  613. " task_name,\n",
  614. " loader_out['train'],\n",
  615. " loader_out['valid'],\n",
  616. " tokenizer_man.get_col_fn(),\n",
  617. " config.hf_trainer_params.to_dict()\n",
  618. " )"
  619. ]
  620. },
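{
"cell_type": "markdown",
"metadata": {},
"source": [
"The wrapper methods used above live in `_models` and are not shown here. Under adapter-transformers 3.2 they plausibly reduce to the standard adapter API sketched below (the exact wrapper internals are an assumption; the adapter and head names match the log output)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of what add_classification_adapter / finetune_adapter plausibly wrap.\n",
"from transformers import BertAdapterModel, TrainingArguments\n",
"from transformers.adapters import AdapterTrainer\n",
"\n",
"model = BertAdapterModel.from_pretrained('bert-base-uncased')\n",
"task = 'glue:cola'\n",
"model.add_adapter(task)                             # cf. \"Adding adapter 'glue:cola'.\"\n",
"model.add_classification_head(task, num_labels=2)   # cf. \"Adding head 'glue:cola' ...\"\n",
"model.train_adapter(task)                           # freeze the base model; train only the adapter\n",
"\n",
"# Training itself would mirror the wrapper call above (left commented to avoid re-running it):\n",
"# trainer = AdapterTrainer(\n",
"#     model=model,\n",
"#     args=TrainingArguments(output_dir='cp3', num_train_epochs=15, per_device_train_batch_size=32),\n",
"#     train_dataset=loader_out['train'],\n",
"#     eval_dataset=loader_out['valid'],\n",
"#     data_collator=tokenizer_man.get_col_fn(),\n",
"# )\n",
"# trainer.train()"
]
},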
  621. {
  622. "cell_type": "markdown",
  623. "metadata": {},
  624. "source": [
  625. "# Opendelta"
  626. ]
  627. },
  628. {
  629. "cell_type": "code",
  630. "execution_count": 24,
  631. "metadata": {
  632. "tags": []
  633. },
  634. "outputs": [],
  635. "source": [
  636. "from bigmodelvis import Visualization\n",
  637. "from transformers import BertForSequenceClassification\n",
  638. "from opendelta import AdapterModel"
  639. ]
  640. },
  641. {
  642. "cell_type": "code",
  643. "execution_count": 42,
  644. "metadata": {
  645. "tags": []
  646. },
  647. "outputs": [
  648. {
  649. "name": "stderr",
  650. "output_type": "stream",
  651. "text": [
  652. "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
  653. "Model config BertConfig {\n",
  654. " \"architectures\": [\n",
  655. " \"BertForMaskedLM\"\n",
  656. " ],\n",
  657. " \"attention_probs_dropout_prob\": 0.1,\n",
  658. " \"classifier_dropout\": null,\n",
  659. " \"gradient_checkpointing\": false,\n",
  660. " \"hidden_act\": \"gelu\",\n",
  661. " \"hidden_dropout_prob\": 0.1,\n",
  662. " \"hidden_size\": 768,\n",
  663. " \"initializer_range\": 0.02,\n",
  664. " \"intermediate_size\": 3072,\n",
  665. " \"layer_norm_eps\": 1e-12,\n",
  666. " \"max_position_embeddings\": 512,\n",
  667. " \"model_type\": \"bert\",\n",
  668. " \"num_attention_heads\": 12,\n",
  669. " \"num_hidden_layers\": 12,\n",
  670. " \"pad_token_id\": 0,\n",
  671. " \"position_embedding_type\": \"absolute\",\n",
  672. " \"transformers_version\": \"4.26.1\",\n",
  673. " \"type_vocab_size\": 2,\n",
  674. " \"use_cache\": true,\n",
  675. " \"vocab_size\": 30522\n",
  676. "}\n",
  677. "\n",
  678. "loading weights file model.safetensors from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/model.safetensors\n",
  679. "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
  680. "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
  681. "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
  682. "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
  683. "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
  684. ]
  685. }
  686. ],
  687. "source": [
  688. "base_model = BertForSequenceClassification.from_pretrained(config.base_model.name)"
  689. ]
  690. },
  691. {
  692. "cell_type": "code",
  693. "execution_count": 43,
  694. "metadata": {
  695. "tags": []
  696. },
  697. "outputs": [
  698. {
  699. "data": {
  700. "text/html": [
  701. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
  702. "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
  703. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
  704. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[30522, 768]</span>\n",
  705. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[512, 768]</span>\n",
  706. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768]</span>\n",
  707. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
  708. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
  709. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
  710. "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
  711. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
  712. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
  713. "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
  714. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
  715. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  716. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
  717. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
  718. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
  719. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
  720. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[3072, 768] bias:[3072]</span>\n",
  721. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
  722. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 3072] bias:[768]</span>\n",
  723. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
  724. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
  725. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
  726. "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  727. "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">classifier </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
  728. "</pre>\n"
  729. ],
  730. "text/plain": [
  731. "\u001b[37mroot\u001b[0m\n",
  732. "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
  733. "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
  734. "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[30522, 768]\u001b[0m\n",
  735. "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[512, 768]\u001b[0m\n",
  736. "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[2, 768]\u001b[0m\n",
  737. "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  738. "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
  739. "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
  740. "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
  741. "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
  742. "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
  743. "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  744. "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
  745. "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  746. "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
  747. "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  748. "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  749. "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
  750. "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[3072, 768] \u001b[0m\u001b[36mbias:[3072]\u001b[0m\n",
  751. "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
  752. "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 3072] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  753. "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  754. "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
  755. "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  756. "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  757. "└── \u001b[37mclassifier \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
  758. ]
  759. },
  760. "metadata": {},
  761. "output_type": "display_data"
  762. }
  763. ],
  764. "source": [
  765. "Visualization(base_model).structure_graph();"
  766. ]
  767. },
  768. {
  769. "cell_type": "code",
  770. "execution_count": 44,
  771. "metadata": {
  772. "tags": []
  773. },
  774. "outputs": [],
  775. "source": [
  776. "delta_model = AdapterModel(base_model, bottleneck_dim=48)\n",
  777. "# leave the delta tuning modules and the newly initialized classification head tunable.\n",
  778. "delta_model.freeze_module(exclude=[\"deltas\", \"classifier\"])"
  779. ]
  780. },
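{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (not in the original notebook): after `freeze_module`, only the newly added adapter (delta) modules and the classification head should remain trainable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count trainable vs. total parameters after freeze_module.\n",
"trainable = sum(p.numel() for p in base_model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in base_model.parameters())\n",
"print(f'trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)')"
]
},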
  781. {
  782. "cell_type": "code",
  783. "execution_count": 45,
  784. "metadata": {
  785. "tags": []
  786. },
  787. "outputs": [
  788. {
  789. "data": {
  790. "text/html": [
  791. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
  792. "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
  793. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
  794. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[30522, 768]</span>\n",
  795. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[512, 768]</span>\n",
  796. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[2, 768]</span>\n",
  797. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  798. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
  799. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
  800. "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
  801. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
  802. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
  803. "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  804. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
  805. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  806. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
  807. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  808. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter </span><span style=\"color: #008000; text-decoration-color: #008000\">(AdapterLayer)</span>\n",
  809. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">modulelist </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
  810. "│ │ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">down_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[48, 768] bias:[48]</span>\n",
  811. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">up_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[768, 48] bias:[768]</span>\n",
  812. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  813. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
  814. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[3072, 768] bias:[3072]</span>\n",
  815. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
  816. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 3072] bias:[768]</span>\n",
  817. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter </span><span style=\"color: #008000; text-decoration-color: #008000\">(AdapterLayer)</span>\n",
  818. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">modulelist </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
  819. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">down_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[48, 768] bias:[48]</span>\n",
  820. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">up_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[768, 48] bias:[768]</span>\n",
  821. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  822. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
  823. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  824. "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  825. "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">classifier </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
  826. "</pre>\n"
  827. ],
  828. "text/plain": [
  829. "\u001b[37mroot\u001b[0m\n",
  830. "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
  831. "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
  832. "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[30522, 768]\u001b[0m\n",
  833. "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[512, 768]\u001b[0m\n",
  834. "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[2, 768]\u001b[0m\n",
  835. "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  836. "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
  837. "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
  838. "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
  839. "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
  840. "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
  841. "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  842. "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
  843. "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  844. "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
  845. "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  846. "│ │ │ │ └── \u001b[37madapter \u001b[0m\u001b[32m(AdapterLayer)\u001b[0m\n",
  847. "│ │ │ │ └── \u001b[37mmodulelist \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
  848. "│ │ │ │ ├── \u001b[37mdown_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[48, 768] \u001b[0m\u001b[38;2;175;0;255mbias:[48]\u001b[0m\n",
  849. "│ │ │ │ └── \u001b[37mup_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[768, 48] \u001b[0m\u001b[38;2;175;0;255mbias:[768]\u001b[0m\n",
  850. "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  851. "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
  852. "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[3072, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[3072]\u001b[0m\n",
  853. "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
  854. "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 3072] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  855. "│ │ │ └── \u001b[37madapter \u001b[0m\u001b[32m(AdapterLayer)\u001b[0m\n",
  856. "│ │ │ └── \u001b[37mmodulelist \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
  857. "│ │ │ ├── \u001b[37mdown_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[48, 768] \u001b[0m\u001b[38;2;175;0;255mbias:[48]\u001b[0m\n",
  858. "│ │ │ └── \u001b[37mup_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[768, 48] \u001b[0m\u001b[38;2;175;0;255mbias:[768]\u001b[0m\n",
  859. "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  860. "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
  861. "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  862. "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  863. "└── \u001b[37mclassifier \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
  864. ]
  865. },
  866. "metadata": {},
  867. "output_type": "display_data"
  868. }
  869. ],
  870. "source": [
  871. "Visualization(base_model).structure_graph();"
  872. ]
  873. },
  874. {
  875. "cell_type": "code",
  876. "execution_count": null,
  877. "metadata": {
  878. "ExecuteTime": {
  879. "end_time": "2023-08-13T16:06:44.674950Z",
  880. "start_time": "2023-08-13T16:06:42.233454Z"
  881. }
  882. },
  883. "outputs": [],
  884. "source": [
  885. "from transformers import TrainingArguments, Trainer\n",
  886. "from sklearn.metrics import classification_report\n",
  887. "\n",
  888. "\n",
  889. "def compute_metrics(pred):\n",
  890. " true_labels = pred.label_ids.ravel()\n",
  891. " pred_labels = pred.predictions.argmax(-1).ravel()\n",
  892. " report = classification_report(true_labels, pred_labels, output_dict=True)\n",
  893. " return {\n",
  894. " 'accuracy': report['accuracy'],\n",
  895. " 'f1-score-1': report['1']['f1-score'],\n",
  896. " 'f1-score-ma': report['macro avg']['f1-score']\n",
  897. " }\n",
  898. "\n",
  899. "\n",
  900. "def train_model(input_model, task_name, train_dataset, eval_dataset, col_fn):\n",
  901. " training_args = TrainingArguments(\n",
  902. " evaluation_strategy=\"epoch\",\n",
  903. " save_strategy=\"epoch\",\n",
  904. " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
  905. " remove_unused_columns=False,\n",
  906. " **config.hf_trainer_params.to_dict()\n",
  907. " )\n",
  908. "\n",
  909. " trainer = Trainer(\n",
  910. " model=input_model,\n",
  911. " args=training_args,\n",
  912. " train_dataset=train_dataset,\n",
  913. " eval_dataset=eval_dataset,\n",
  914. " data_collator=col_fn,\n",
  915. " compute_metrics=compute_metrics\n",
  916. " )\n",
  917. " trainer.train()\n",
  918. "\n",
  919. "\n",
  920. "for task_name in config.tasks:\n",
  921. " loader_out = auto_loader.get_and_map(tokenizer_man.tokenizer, task_name)\n",
  922. " num_labels = len(loader_out['output']['range'])\n",
  923. " train_model(\n",
  924. " base_model,\n",
  925. " task_name,\n",
  926. " loader_out['train'],\n",
  927. " loader_out['valid'],\n",
  928. " tokenizer_man.get_col_fn()\n",
  929. " )"
  930. ]
  931. },
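{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal post-training evaluation sketch (not part of the original run): it assumes `train_model` is adjusted to return its `Trainer`, or that a `trainer` from the last run is still in scope, and it reuses the same `classification_report` that backs `compute_metrics` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: `trainer` and `eval_dataset` are assumed to be available,\n",
"# e.g. by having train_model return its Trainer instance.\n",
"from sklearn.metrics import classification_report\n",
"\n",
"\n",
"def evaluate_model(trainer, eval_dataset):\n",
" # Trainer.predict returns an object with .predictions (logits) and .label_ids\n",
" pred = trainer.predict(eval_dataset)\n",
" true_labels = pred.label_ids.ravel()\n",
" pred_labels = pred.predictions.argmax(-1).ravel()\n",
" print(classification_report(true_labels, pred_labels))"
]
},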
  932. {
  933. "cell_type": "code",
  934. "execution_count": 47,
  935. "metadata": {
  936. "tags": []
  937. },
  938. "outputs": [
  939. {
  940. "data": {
  941. "text/html": [
  942. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
  943. "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
  944. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
  945. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[30522, 768]</span>\n",
  946. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[512, 768]</span>\n",
  947. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[2, 768]</span>\n",
  948. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  949. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
  950. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
  951. "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
  952. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
  953. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
  954. "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  955. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
  956. "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  957. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
  958. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  959. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  960. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
  961. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[3072, 768] bias:[3072]</span>\n",
  962. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
  963. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 3072] bias:[768]</span>\n",
  964. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
  965. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapters </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleDict)</span>\n",
  966. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">glue:cola </span><span style=\"color: #008000; text-decoration-color: #008000\">(Adapter)</span>\n",
  967. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">non_linearity </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
  968. "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter_down </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
  969. "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">0 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[48, 768] bias:[48]</span>\n",
  970. "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">1 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
  971. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter_up </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 48] bias:[768]</span>\n",
  972. "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
  973. "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
  974. "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
  975. "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">heads </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleDict)</span>\n",
  976. " └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">glue:cola </span><span style=\"color: #008000; text-decoration-color: #008000\">(ClassificationHead)</span>\n",
  977. " ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">1 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
  978. " ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">2 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
  979. " └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">4 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
  980. "</pre>\n"
  981. ],
  982. "text/plain": [
  983. "\u001b[37mroot\u001b[0m\n",
  984. "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
  985. "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
  986. "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[30522, 768]\u001b[0m\n",
  987. "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[512, 768]\u001b[0m\n",
  988. "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[2, 768]\u001b[0m\n",
  989. "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  990. "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
  991. "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
  992. "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
  993. "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
  994. "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
  995. "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  996. "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
  997. "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  998. "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
  999. "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  1000. "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  1001. "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
  1002. "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[3072, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[3072]\u001b[0m\n",
  1003. "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
  1004. "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 3072] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  1005. "│ │ ├── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  1006. "│ │ └── \u001b[37madapters \u001b[0m\u001b[32m(ModuleDict)\u001b[0m\n",
  1007. "│ │ └── \u001b[37mglue:cola \u001b[0m\u001b[32m(Adapter)\u001b[0m\n",
  1008. "│ │ ├── \u001b[37mnon_linearity \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
  1009. "│ │ ├── \u001b[37madapter_down \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
  1010. "│ │ │ ├── \u001b[37m0 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[48, 768] \u001b[0m\u001b[36mbias:[48]\u001b[0m\n",
  1011. "│ │ │ └── \u001b[37m1 \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
  1012. "│ │ └── \u001b[37madapter_up \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 48] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  1013. "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
  1014. "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
  1015. "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
  1016. "└── \u001b[37mheads \u001b[0m\u001b[32m(ModuleDict)\u001b[0m\n",
  1017. " └── \u001b[37mglue:cola \u001b[0m\u001b[32m(ClassificationHead)\u001b[0m\n",
  1018. " ├── \u001b[37m1 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
  1019. " ├── \u001b[37m2 \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
  1020. " └── \u001b[37m4 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
  1021. ]
  1022. },
  1023. "metadata": {},
  1024. "output_type": "display_data"
  1025. }
  1026. ],
  1027. "source": [
  1028. "Visualization(adapter_wrapper.model).structure_graph();"
  1029. ]
  1030. },
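{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check to go with the structure graph above: counting trainable vs. total parameters should show that only the adapter and head weights require gradients. This is a sketch assuming `adapter_wrapper.model` is the adapter-equipped model visualized above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: count trainable vs. total parameters of the adapter-equipped model\n",
"model = adapter_wrapper.model\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f\"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)\")"
]
},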
  1031. {
  1032. "cell_type": "code",
  1033. "execution_count": null,
  1034. "metadata": {
  1035. "ExecuteTime": {
  1036. "end_time": "2023-08-15T13:11:54.968862Z",
  1037. "start_time": "2023-08-15T13:11:54.946870Z"
  1038. }
  1039. },
  1040. "outputs": [],
  1041. "source": [
  1042. "results"
  1043. ]
  1044. },
  1045. {
  1046. "cell_type": "code",
  1047. "execution_count": null,
  1048. "metadata": {
  1049. "ExecuteTime": {
  1050. "end_time": "2023-08-15T13:23:50.492273Z",
  1051. "start_time": "2023-08-15T13:22:40.985364Z"
  1052. }
  1053. },
  1054. "outputs": [],
  1055. "source": [
  1056. "from _datasets import GLUEHelper\n",
  1057. " \n",
  1058. "gl_helper = GLUEHelper()"
  1059. ]
  1060. },
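{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small exploratory sketch, assuming `GLUEHelper.datasets` is a mapping from task name to a `DatasetDict` (as it is used in the cells below): list the loaded tasks and their split sizes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: assumes gl_helper.datasets maps task name -> DatasetDict (see usage below)\n",
"for name, ds in gl_helper.datasets.items():\n",
" print(name, {split: len(ds[split]) for split in ds})"
]
},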
  1061. {
  1062. "cell_type": "code",
  1063. "execution_count": null,
  1064. "metadata": {
  1065. "ExecuteTime": {
  1066. "end_time": "2023-08-15T13:46:17.380290Z",
  1067. "start_time": "2023-08-15T13:46:17.346993Z"
  1068. }
  1069. },
  1070. "outputs": [],
  1071. "source": [
  1072. "for n in range(0, 1000):\n",
  1073. " out = gl_helper.datasets['stsb']['train'][n]\n",
  1074. " if out['label'] == 0.:\n",
  1075. " print(out)\n",
  1076. " break"
  1077. ]
  1078. },
  1079. {
  1080. "cell_type": "code",
  1081. "execution_count": null,
  1082. "metadata": {},
  1083. "outputs": [],
  1084. "source": [
  1085. "from evaluate import load\n",
  1086. "glue_metric = load('glue', 'stsb')"
  1087. ]
  1088. },
  1089. {
  1090. "cell_type": "code",
  1091. "execution_count": null,
  1092. "metadata": {},
  1093. "outputs": [],
  1094. "source": [
  1095. "results = glue_metric.compute(predictions=[-0.5, -0.3], references=[-0.5, 1])\n",
  1096. "results"
  1097. ]
  1098. },
  1099. {
  1100. "cell_type": "code",
  1101. "execution_count": null,
  1102. "metadata": {
  1103. "ExecuteTime": {
  1104. "end_time": "2023-08-13T18:17:59.084998Z",
  1105. "start_time": "2023-08-13T18:17:59.050653Z"
  1106. }
  1107. },
  1108. "outputs": [],
  1109. "source": [
  1110. "gl_helper.datasets['mnli']"
  1111. ]
  1112. },
  1113. {
  1114. "cell_type": "code",
  1115. "execution_count": null,
  1116. "metadata": {
  1117. "ExecuteTime": {
  1118. "end_time": "2023-08-13T18:17:59.157406Z",
  1119. "start_time": "2023-08-13T18:17:59.081370Z"
  1120. }
  1121. },
  1122. "outputs": [],
  1123. "source": [
  1124. "gl_helper.datasets['mnli_matched']\n"
  1125. ]
  1126. },
  1127. {
  1128. "cell_type": "code",
  1129. "execution_count": null,
  1130. "metadata": {
  1131. "ExecuteTime": {
  1132. "end_time": "2023-08-13T18:18:01.203910Z",
  1133. "start_time": "2023-08-13T18:18:01.171842Z"
  1134. }
  1135. },
  1136. "outputs": [],
  1137. "source": [
  1138. "gl_helper.datasets['mnli_mismatched']\n"
  1139. ]
  1140. },
  1141. {
  1142. "cell_type": "code",
  1143. "execution_count": null,
  1144. "metadata": {
  1145. "ExecuteTime": {
  1146. "end_time": "2023-08-13T18:30:16.905587Z",
  1147. "start_time": "2023-08-13T18:30:16.775197Z"
  1148. }
  1149. },
  1150. "outputs": [],
  1151. "source": [
  1152. "import transformers\n",
  1153. "\n",
  1154. "\n",
  1155. "print(transformers.__version__)"
  1156. ]
  1157. },
  1158. {
  1159. "cell_type": "code",
  1160. "execution_count": null,
  1161. "metadata": {
  1162. "ExecuteTime": {
  1163. "end_time": "2023-08-13T18:29:49.383120Z",
  1164. "start_time": "2023-08-13T18:29:40.017083Z"
  1165. }
  1166. },
  1167. "outputs": [],
  1168. "source": [
  1169. "pip install adapter-transformers"
  1170. ]
  1171. },
  1172. {
  1173. "cell_type": "code",
  1174. "execution_count": null,
  1175. "metadata": {},
  1176. "outputs": [],
  1177. "source": []
  1178. }
  1179. ],
  1180. "metadata": {
  1181. "kernelspec": {
  1182. "display_name": "Python [conda env:lll]",
  1183. "language": "python",
  1184. "name": "conda-env-lll-py"
  1185. },
  1186. "language_info": {
  1187. "codemirror_mode": {
  1188. "name": "ipython",
  1189. "version": 3
  1190. },
  1191. "file_extension": ".py",
  1192. "mimetype": "text/x-python",
  1193. "name": "python",
  1194. "nbconvert_exporter": "python",
  1195. "pygments_lexer": "ipython3",
  1196. "version": "3.9.17"
  1197. }
  1198. },
  1199. "nbformat": 4,
  1200. "nbformat_minor": 4
  1201. }