04_T5_custom.ipynb
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "896de91a-4ab9-40f5-a3c1-914535b6e0a7",
  6. "metadata": {},
  7. "source": [
  8. "# intro"
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": 1,
  14. "id": "5f17aae6-73f5-4793-95a3-09147ea89e04",
  15. "metadata": {
  16. "tags": []
  17. },
  18. "outputs": [
  19. {
  20. "name": "stdout",
  21. "output_type": "stream",
  22. "text": [
  23. "Python version is: 3.10.11\n",
  24. "Scikit-learn version is: 1.2.2\n",
  25. "Torch version is: 1.13.1+cu117\n",
  26. "Nvidia device is: NVIDIA GeForce RTX 4090\n",
  27. "Transformers version is: 4.32.1\n",
  28. "Adapterhub not found!!!\n"
  29. ]
  30. }
  31. ],
  32. "source": [
  33. "from typing import Optional\n",
  34. "\n",
  35. "import numpy as np\n",
  36. "from tqdm.notebook import tqdm\n",
  37. "\n",
  38. "import torch\n",
  39. "import torch.nn as nn\n",
  40. "from transformers import T5TokenizerFast, T5ForConditionalGeneration\n",
  41. "\n",
  42. "from _utils import print_system_info, generate_dataloader\n",
  43. "from _datasets import AutoLoad\n",
  44. "from _mydelta import T5Wrapper, auto_freeze\n",
  45. "from _trainer import train_loop, valid_loop\n",
  46. "\n",
  47. "print_system_info()"
  48. ]
  49. },
  50. {
  51. "cell_type": "code",
  52. "execution_count": 2,
  53. "id": "fb5ef784-fef0-4b7b-98e7-ec5d3575a9a8",
  54. "metadata": {
  55. "tags": []
  56. },
  57. "outputs": [],
  58. "source": [
  59. "from types import SimpleNamespace\n",
  60. "config = SimpleNamespace(\n",
  61. " model_name='google/t5-base-lm-adapt',\n",
  62. " n_tokens=30,\n",
  63. " n_layers=6,\n",
  64. " random_seed=42,\n",
  65. " task=['glue:cola'],\n",
  66. " hot_modules=['sadcl'],\n",
  67. " train_batch_size=32,\n",
  68. " valid_batch_size=32,\n",
  69. " balancify_sample=False,\n",
  70. " learning_rate=0.01,\n",
  71. " num_epochs=200\n",
  72. ")"
  73. ]
  74. },
  75. {
  76. "cell_type": "code",
  77. "execution_count": 3,
  78. "id": "d3802d01-7c5a-4c11-beaf-f683a2fb9d80",
  79. "metadata": {
  80. "tags": []
  81. },
  82. "outputs": [],
  83. "source": [
  84. "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  85. "\n",
  86. "np.random.seed(config.random_seed)\n",
  87. "slected_tokens = torch.from_numpy(np.random.randint(0, 32128, size=(config.n_tokens,)))"
  88. ]
  89. },
  90. {
  91. "cell_type": "markdown",
  92. "id": "1e785d49-beca-4333-986e-b198bbaadf7d",
  93. "metadata": {},
  94. "source": [
  95. "# load model and data"
  96. ]
  97. },
  98. {
  99. "cell_type": "code",
  100. "execution_count": 4,
  101. "id": "afcc6244-978a-425a-9fa9-8b11dd0df8ba",
  102. "metadata": {
  103. "tags": []
  104. },
  105. "outputs": [],
  106. "source": [
  107. "model = T5ForConditionalGeneration.from_pretrained(config.model_name)\n",
  108. "tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": 5,
  114. "id": "894a8474-e2e1-4f9d-b9ab-58d911808ec0",
  115. "metadata": {
  116. "tags": []
  117. },
  118. "outputs": [
  119. {
  120. "name": "stdout",
  121. "output_type": "stream",
  122. "text": [
  123. "encoder.block.6.soft_prompt.sadcl_learned_embedding\n",
  124. "encoder.block.7.soft_prompt.sadcl_learned_embedding\n",
  125. "encoder.block.8.soft_prompt.sadcl_learned_embedding\n",
  126. "encoder.block.9.soft_prompt.sadcl_learned_embedding\n",
  127. "encoder.block.10.soft_prompt.sadcl_learned_embedding\n",
  128. "encoder.block.11.soft_prompt.sadcl_learned_embedding\n"
  129. ]
  130. }
  131. ],
  132. "source": [
  133. "delta_module = T5Wrapper.mutate(\n",
  134. " model=model,\n",
  135. " config=config,\n",
  136. " slected_tokens=slected_tokens\n",
  137. ")\n",
  138. "auto_freeze(model, config.hot_modules, verbose=True)"
  139. ]
  140. },
  141. {
  142. "cell_type": "code",
  143. "execution_count": 15,
  144. "id": "9453d3cc-c04c-4a27-83aa-eaac3e49c14e",
  145. "metadata": {
  146. "tags": []
  147. },
  148. "outputs": [
  149. {
  150. "name": "stdout",
  151. "output_type": "stream",
  152. "text": [
  153. "shared.weight\n",
  154. "encoder.block.0.layer.0.SelfAttention.q.weight\n",
  155. "encoder.block.0.layer.0.SelfAttention.k.weight\n",
  156. "encoder.block.0.layer.0.SelfAttention.v.weight\n",
  157. "encoder.block.0.layer.0.SelfAttention.o.weight\n",
  158. "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
  159. "encoder.block.0.layer.0.layer_norm.weight\n",
  160. "encoder.block.0.layer.1.DenseReluDense.wi_0.weight\n",
  161. "encoder.block.0.layer.1.DenseReluDense.wi_1.weight\n",
  162. "encoder.block.0.layer.1.DenseReluDense.wo.weight\n",
  163. "encoder.block.0.layer.1.layer_norm.weight\n",
  164. "encoder.block.1.layer.0.SelfAttention.q.weight\n",
  165. "encoder.block.1.layer.0.SelfAttention.k.weight\n",
  166. "encoder.block.1.layer.0.SelfAttention.v.weight\n",
  167. "encoder.block.1.layer.0.SelfAttention.o.weight\n",
  168. "encoder.block.1.layer.0.layer_norm.weight\n",
  169. "encoder.block.1.layer.1.DenseReluDense.wi_0.weight\n",
  170. "encoder.block.1.layer.1.DenseReluDense.wi_1.weight\n",
  171. "encoder.block.1.layer.1.DenseReluDense.wo.weight\n",
  172. "encoder.block.1.layer.1.layer_norm.weight\n",
  173. "encoder.block.2.layer.0.SelfAttention.q.weight\n",
  174. "encoder.block.2.layer.0.SelfAttention.k.weight\n",
  175. "encoder.block.2.layer.0.SelfAttention.v.weight\n",
  176. "encoder.block.2.layer.0.SelfAttention.o.weight\n",
  177. "encoder.block.2.layer.0.layer_norm.weight\n",
  178. "encoder.block.2.layer.1.DenseReluDense.wi_0.weight\n",
  179. "encoder.block.2.layer.1.DenseReluDense.wi_1.weight\n",
  180. "encoder.block.2.layer.1.DenseReluDense.wo.weight\n",
  181. "encoder.block.2.layer.1.layer_norm.weight\n",
  182. "encoder.block.3.layer.0.SelfAttention.q.weight\n",
  183. "encoder.block.3.layer.0.SelfAttention.k.weight\n",
  184. "encoder.block.3.layer.0.SelfAttention.v.weight\n",
  185. "encoder.block.3.layer.0.SelfAttention.o.weight\n",
  186. "encoder.block.3.layer.0.layer_norm.weight\n",
  187. "encoder.block.3.layer.1.DenseReluDense.wi_0.weight\n",
  188. "encoder.block.3.layer.1.DenseReluDense.wi_1.weight\n",
  189. "encoder.block.3.layer.1.DenseReluDense.wo.weight\n",
  190. "encoder.block.3.layer.1.layer_norm.weight\n",
  191. "encoder.block.4.layer.0.SelfAttention.q.weight\n",
  192. "encoder.block.4.layer.0.SelfAttention.k.weight\n",
  193. "encoder.block.4.layer.0.SelfAttention.v.weight\n",
  194. "encoder.block.4.layer.0.SelfAttention.o.weight\n",
  195. "encoder.block.4.layer.0.layer_norm.weight\n",
  196. "encoder.block.4.layer.1.DenseReluDense.wi_0.weight\n",
  197. "encoder.block.4.layer.1.DenseReluDense.wi_1.weight\n",
  198. "encoder.block.4.layer.1.DenseReluDense.wo.weight\n",
  199. "encoder.block.4.layer.1.layer_norm.weight\n",
  200. "encoder.block.5.layer.0.SelfAttention.q.weight\n",
  201. "encoder.block.5.layer.0.SelfAttention.k.weight\n",
  202. "encoder.block.5.layer.0.SelfAttention.v.weight\n",
  203. "encoder.block.5.layer.0.SelfAttention.o.weight\n",
  204. "encoder.block.5.layer.0.layer_norm.weight\n",
  205. "encoder.block.5.layer.1.DenseReluDense.wi_0.weight\n",
  206. "encoder.block.5.layer.1.DenseReluDense.wi_1.weight\n",
  207. "encoder.block.5.layer.1.DenseReluDense.wo.weight\n",
  208. "encoder.block.5.layer.1.layer_norm.weight\n",
  209. "encoder.block.6.original_module.layer.0.SelfAttention.q.weight\n",
  210. "encoder.block.6.original_module.layer.0.SelfAttention.k.weight\n",
  211. "encoder.block.6.original_module.layer.0.SelfAttention.v.weight\n",
  212. "encoder.block.6.original_module.layer.0.SelfAttention.o.weight\n",
  213. "encoder.block.6.original_module.layer.0.layer_norm.weight\n",
  214. "encoder.block.6.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  215. "encoder.block.6.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  216. "encoder.block.6.original_module.layer.1.DenseReluDense.wo.weight\n",
  217. "encoder.block.6.original_module.layer.1.layer_norm.weight\n",
  218. "encoder.block.6.soft_prompt.sadcl_learned_embedding\n",
  219. "encoder.block.7.original_module.layer.0.SelfAttention.q.weight\n",
  220. "encoder.block.7.original_module.layer.0.SelfAttention.k.weight\n",
  221. "encoder.block.7.original_module.layer.0.SelfAttention.v.weight\n",
  222. "encoder.block.7.original_module.layer.0.SelfAttention.o.weight\n",
  223. "encoder.block.7.original_module.layer.0.layer_norm.weight\n",
  224. "encoder.block.7.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  225. "encoder.block.7.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  226. "encoder.block.7.original_module.layer.1.DenseReluDense.wo.weight\n",
  227. "encoder.block.7.original_module.layer.1.layer_norm.weight\n",
  228. "encoder.block.7.soft_prompt.sadcl_learned_embedding\n",
  229. "encoder.block.8.original_module.layer.0.SelfAttention.q.weight\n",
  230. "encoder.block.8.original_module.layer.0.SelfAttention.k.weight\n",
  231. "encoder.block.8.original_module.layer.0.SelfAttention.v.weight\n",
  232. "encoder.block.8.original_module.layer.0.SelfAttention.o.weight\n",
  233. "encoder.block.8.original_module.layer.0.layer_norm.weight\n",
  234. "encoder.block.8.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  235. "encoder.block.8.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  236. "encoder.block.8.original_module.layer.1.DenseReluDense.wo.weight\n",
  237. "encoder.block.8.original_module.layer.1.layer_norm.weight\n",
  238. "encoder.block.8.soft_prompt.sadcl_learned_embedding\n",
  239. "encoder.block.9.original_module.layer.0.SelfAttention.q.weight\n",
  240. "encoder.block.9.original_module.layer.0.SelfAttention.k.weight\n",
  241. "encoder.block.9.original_module.layer.0.SelfAttention.v.weight\n",
  242. "encoder.block.9.original_module.layer.0.SelfAttention.o.weight\n",
  243. "encoder.block.9.original_module.layer.0.layer_norm.weight\n",
  244. "encoder.block.9.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  245. "encoder.block.9.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  246. "encoder.block.9.original_module.layer.1.DenseReluDense.wo.weight\n",
  247. "encoder.block.9.original_module.layer.1.layer_norm.weight\n",
  248. "encoder.block.9.soft_prompt.sadcl_learned_embedding\n",
  249. "encoder.block.10.original_module.layer.0.SelfAttention.q.weight\n",
  250. "encoder.block.10.original_module.layer.0.SelfAttention.k.weight\n",
  251. "encoder.block.10.original_module.layer.0.SelfAttention.v.weight\n",
  252. "encoder.block.10.original_module.layer.0.SelfAttention.o.weight\n",
  253. "encoder.block.10.original_module.layer.0.layer_norm.weight\n",
  254. "encoder.block.10.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  255. "encoder.block.10.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  256. "encoder.block.10.original_module.layer.1.DenseReluDense.wo.weight\n",
  257. "encoder.block.10.original_module.layer.1.layer_norm.weight\n",
  258. "encoder.block.10.soft_prompt.sadcl_learned_embedding\n",
  259. "encoder.block.11.original_module.layer.0.SelfAttention.q.weight\n",
  260. "encoder.block.11.original_module.layer.0.SelfAttention.k.weight\n",
  261. "encoder.block.11.original_module.layer.0.SelfAttention.v.weight\n",
  262. "encoder.block.11.original_module.layer.0.SelfAttention.o.weight\n",
  263. "encoder.block.11.original_module.layer.0.layer_norm.weight\n",
  264. "encoder.block.11.original_module.layer.1.DenseReluDense.wi_0.weight\n",
  265. "encoder.block.11.original_module.layer.1.DenseReluDense.wi_1.weight\n",
  266. "encoder.block.11.original_module.layer.1.DenseReluDense.wo.weight\n",
  267. "encoder.block.11.original_module.layer.1.layer_norm.weight\n",
  268. "encoder.block.11.soft_prompt.sadcl_learned_embedding\n",
  269. "encoder.final_layer_norm.weight\n",
  270. "decoder.block.0.layer.0.SelfAttention.q.weight\n",
  271. "decoder.block.0.layer.0.SelfAttention.k.weight\n",
  272. "decoder.block.0.layer.0.SelfAttention.v.weight\n",
  273. "decoder.block.0.layer.0.SelfAttention.o.weight\n",
  274. "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
  275. "decoder.block.0.layer.0.layer_norm.weight\n",
  276. "decoder.block.0.layer.1.EncDecAttention.q.weight\n",
  277. "decoder.block.0.layer.1.EncDecAttention.k.weight\n",
  278. "decoder.block.0.layer.1.EncDecAttention.v.weight\n",
  279. "decoder.block.0.layer.1.EncDecAttention.o.weight\n",
  280. "decoder.block.0.layer.1.layer_norm.weight\n",
  281. "decoder.block.0.layer.2.DenseReluDense.wi_0.weight\n",
  282. "decoder.block.0.layer.2.DenseReluDense.wi_1.weight\n",
  283. "decoder.block.0.layer.2.DenseReluDense.wo.weight\n",
  284. "decoder.block.0.layer.2.layer_norm.weight\n",
  285. "decoder.block.1.layer.0.SelfAttention.q.weight\n",
  286. "decoder.block.1.layer.0.SelfAttention.k.weight\n",
  287. "decoder.block.1.layer.0.SelfAttention.v.weight\n",
  288. "decoder.block.1.layer.0.SelfAttention.o.weight\n",
  289. "decoder.block.1.layer.0.layer_norm.weight\n",
  290. "decoder.block.1.layer.1.EncDecAttention.q.weight\n",
  291. "decoder.block.1.layer.1.EncDecAttention.k.weight\n",
  292. "decoder.block.1.layer.1.EncDecAttention.v.weight\n",
  293. "decoder.block.1.layer.1.EncDecAttention.o.weight\n",
  294. "decoder.block.1.layer.1.layer_norm.weight\n",
  295. "decoder.block.1.layer.2.DenseReluDense.wi_0.weight\n",
  296. "decoder.block.1.layer.2.DenseReluDense.wi_1.weight\n",
  297. "decoder.block.1.layer.2.DenseReluDense.wo.weight\n",
  298. "decoder.block.1.layer.2.layer_norm.weight\n",
  299. "decoder.block.2.layer.0.SelfAttention.q.weight\n",
  300. "decoder.block.2.layer.0.SelfAttention.k.weight\n",
  301. "decoder.block.2.layer.0.SelfAttention.v.weight\n",
  302. "decoder.block.2.layer.0.SelfAttention.o.weight\n",
  303. "decoder.block.2.layer.0.layer_norm.weight\n",
  304. "decoder.block.2.layer.1.EncDecAttention.q.weight\n",
  305. "decoder.block.2.layer.1.EncDecAttention.k.weight\n",
  306. "decoder.block.2.layer.1.EncDecAttention.v.weight\n",
  307. "decoder.block.2.layer.1.EncDecAttention.o.weight\n",
  308. "decoder.block.2.layer.1.layer_norm.weight\n",
  309. "decoder.block.2.layer.2.DenseReluDense.wi_0.weight\n",
  310. "decoder.block.2.layer.2.DenseReluDense.wi_1.weight\n",
  311. "decoder.block.2.layer.2.DenseReluDense.wo.weight\n",
  312. "decoder.block.2.layer.2.layer_norm.weight\n",
  313. "decoder.block.3.layer.0.SelfAttention.q.weight\n",
  314. "decoder.block.3.layer.0.SelfAttention.k.weight\n",
  315. "decoder.block.3.layer.0.SelfAttention.v.weight\n",
  316. "decoder.block.3.layer.0.SelfAttention.o.weight\n",
  317. "decoder.block.3.layer.0.layer_norm.weight\n",
  318. "decoder.block.3.layer.1.EncDecAttention.q.weight\n",
  319. "decoder.block.3.layer.1.EncDecAttention.k.weight\n",
  320. "decoder.block.3.layer.1.EncDecAttention.v.weight\n",
  321. "decoder.block.3.layer.1.EncDecAttention.o.weight\n",
  322. "decoder.block.3.layer.1.layer_norm.weight\n",
  323. "decoder.block.3.layer.2.DenseReluDense.wi_0.weight\n",
  324. "decoder.block.3.layer.2.DenseReluDense.wi_1.weight\n",
  325. "decoder.block.3.layer.2.DenseReluDense.wo.weight\n",
  326. "decoder.block.3.layer.2.layer_norm.weight\n",
  327. "decoder.block.4.layer.0.SelfAttention.q.weight\n",
  328. "decoder.block.4.layer.0.SelfAttention.k.weight\n",
  329. "decoder.block.4.layer.0.SelfAttention.v.weight\n",
  330. "decoder.block.4.layer.0.SelfAttention.o.weight\n",
  331. "decoder.block.4.layer.0.layer_norm.weight\n",
  332. "decoder.block.4.layer.1.EncDecAttention.q.weight\n",
  333. "decoder.block.4.layer.1.EncDecAttention.k.weight\n",
  334. "decoder.block.4.layer.1.EncDecAttention.v.weight\n",
  335. "decoder.block.4.layer.1.EncDecAttention.o.weight\n",
  336. "decoder.block.4.layer.1.layer_norm.weight\n",
  337. "decoder.block.4.layer.2.DenseReluDense.wi_0.weight\n",
  338. "decoder.block.4.layer.2.DenseReluDense.wi_1.weight\n",
  339. "decoder.block.4.layer.2.DenseReluDense.wo.weight\n",
  340. "decoder.block.4.layer.2.layer_norm.weight\n",
  341. "decoder.block.5.layer.0.SelfAttention.q.weight\n",
  342. "decoder.block.5.layer.0.SelfAttention.k.weight\n",
  343. "decoder.block.5.layer.0.SelfAttention.v.weight\n",
  344. "decoder.block.5.layer.0.SelfAttention.o.weight\n",
  345. "decoder.block.5.layer.0.layer_norm.weight\n",
  346. "decoder.block.5.layer.1.EncDecAttention.q.weight\n",
  347. "decoder.block.5.layer.1.EncDecAttention.k.weight\n",
  348. "decoder.block.5.layer.1.EncDecAttention.v.weight\n",
  349. "decoder.block.5.layer.1.EncDecAttention.o.weight\n",
  350. "decoder.block.5.layer.1.layer_norm.weight\n",
  351. "decoder.block.5.layer.2.DenseReluDense.wi_0.weight\n",
  352. "decoder.block.5.layer.2.DenseReluDense.wi_1.weight\n",
  353. "decoder.block.5.layer.2.DenseReluDense.wo.weight\n",
  354. "decoder.block.5.layer.2.layer_norm.weight\n",
  355. "decoder.block.6.layer.0.SelfAttention.q.weight\n",
  356. "decoder.block.6.layer.0.SelfAttention.k.weight\n",
  357. "decoder.block.6.layer.0.SelfAttention.v.weight\n",
  358. "decoder.block.6.layer.0.SelfAttention.o.weight\n",
  359. "decoder.block.6.layer.0.layer_norm.weight\n",
  360. "decoder.block.6.layer.1.EncDecAttention.q.weight\n",
  361. "decoder.block.6.layer.1.EncDecAttention.k.weight\n",
  362. "decoder.block.6.layer.1.EncDecAttention.v.weight\n",
  363. "decoder.block.6.layer.1.EncDecAttention.o.weight\n",
  364. "decoder.block.6.layer.1.layer_norm.weight\n",
  365. "decoder.block.6.layer.2.DenseReluDense.wi_0.weight\n",
  366. "decoder.block.6.layer.2.DenseReluDense.wi_1.weight\n",
  367. "decoder.block.6.layer.2.DenseReluDense.wo.weight\n",
  368. "decoder.block.6.layer.2.layer_norm.weight\n",
  369. "decoder.block.7.layer.0.SelfAttention.q.weight\n",
  370. "decoder.block.7.layer.0.SelfAttention.k.weight\n",
  371. "decoder.block.7.layer.0.SelfAttention.v.weight\n",
  372. "decoder.block.7.layer.0.SelfAttention.o.weight\n",
  373. "decoder.block.7.layer.0.layer_norm.weight\n",
  374. "decoder.block.7.layer.1.EncDecAttention.q.weight\n",
  375. "decoder.block.7.layer.1.EncDecAttention.k.weight\n",
  376. "decoder.block.7.layer.1.EncDecAttention.v.weight\n",
  377. "decoder.block.7.layer.1.EncDecAttention.o.weight\n",
  378. "decoder.block.7.layer.1.layer_norm.weight\n",
  379. "decoder.block.7.layer.2.DenseReluDense.wi_0.weight\n",
  380. "decoder.block.7.layer.2.DenseReluDense.wi_1.weight\n",
  381. "decoder.block.7.layer.2.DenseReluDense.wo.weight\n",
  382. "decoder.block.7.layer.2.layer_norm.weight\n",
  383. "decoder.block.8.layer.0.SelfAttention.q.weight\n",
  384. "decoder.block.8.layer.0.SelfAttention.k.weight\n",
  385. "decoder.block.8.layer.0.SelfAttention.v.weight\n",
  386. "decoder.block.8.layer.0.SelfAttention.o.weight\n",
  387. "decoder.block.8.layer.0.layer_norm.weight\n",
  388. "decoder.block.8.layer.1.EncDecAttention.q.weight\n",
  389. "decoder.block.8.layer.1.EncDecAttention.k.weight\n",
  390. "decoder.block.8.layer.1.EncDecAttention.v.weight\n",
  391. "decoder.block.8.layer.1.EncDecAttention.o.weight\n",
  392. "decoder.block.8.layer.1.layer_norm.weight\n",
  393. "decoder.block.8.layer.2.DenseReluDense.wi_0.weight\n",
  394. "decoder.block.8.layer.2.DenseReluDense.wi_1.weight\n",
  395. "decoder.block.8.layer.2.DenseReluDense.wo.weight\n",
  396. "decoder.block.8.layer.2.layer_norm.weight\n",
  397. "decoder.block.9.layer.0.SelfAttention.q.weight\n",
  398. "decoder.block.9.layer.0.SelfAttention.k.weight\n",
  399. "decoder.block.9.layer.0.SelfAttention.v.weight\n",
  400. "decoder.block.9.layer.0.SelfAttention.o.weight\n",
  401. "decoder.block.9.layer.0.layer_norm.weight\n",
  402. "decoder.block.9.layer.1.EncDecAttention.q.weight\n",
  403. "decoder.block.9.layer.1.EncDecAttention.k.weight\n",
  404. "decoder.block.9.layer.1.EncDecAttention.v.weight\n",
  405. "decoder.block.9.layer.1.EncDecAttention.o.weight\n",
  406. "decoder.block.9.layer.1.layer_norm.weight\n",
  407. "decoder.block.9.layer.2.DenseReluDense.wi_0.weight\n",
  408. "decoder.block.9.layer.2.DenseReluDense.wi_1.weight\n",
  409. "decoder.block.9.layer.2.DenseReluDense.wo.weight\n",
  410. "decoder.block.9.layer.2.layer_norm.weight\n",
  411. "decoder.block.10.layer.0.SelfAttention.q.weight\n",
  412. "decoder.block.10.layer.0.SelfAttention.k.weight\n",
  413. "decoder.block.10.layer.0.SelfAttention.v.weight\n",
  414. "decoder.block.10.layer.0.SelfAttention.o.weight\n",
  415. "decoder.block.10.layer.0.layer_norm.weight\n",
  416. "decoder.block.10.layer.1.EncDecAttention.q.weight\n",
  417. "decoder.block.10.layer.1.EncDecAttention.k.weight\n",
  418. "decoder.block.10.layer.1.EncDecAttention.v.weight\n",
  419. "decoder.block.10.layer.1.EncDecAttention.o.weight\n",
  420. "decoder.block.10.layer.1.layer_norm.weight\n",
  421. "decoder.block.10.layer.2.DenseReluDense.wi_0.weight\n",
  422. "decoder.block.10.layer.2.DenseReluDense.wi_1.weight\n",
  423. "decoder.block.10.layer.2.DenseReluDense.wo.weight\n",
  424. "decoder.block.10.layer.2.layer_norm.weight\n",
  425. "decoder.block.11.layer.0.SelfAttention.q.weight\n",
  426. "decoder.block.11.layer.0.SelfAttention.k.weight\n",
  427. "decoder.block.11.layer.0.SelfAttention.v.weight\n",
  428. "decoder.block.11.layer.0.SelfAttention.o.weight\n",
  429. "decoder.block.11.layer.0.layer_norm.weight\n",
  430. "decoder.block.11.layer.1.EncDecAttention.q.weight\n",
  431. "decoder.block.11.layer.1.EncDecAttention.k.weight\n",
  432. "decoder.block.11.layer.1.EncDecAttention.v.weight\n",
  433. "decoder.block.11.layer.1.EncDecAttention.o.weight\n",
  434. "decoder.block.11.layer.1.layer_norm.weight\n",
  435. "decoder.block.11.layer.2.DenseReluDense.wi_0.weight\n",
  436. "decoder.block.11.layer.2.DenseReluDense.wi_1.weight\n",
  437. "decoder.block.11.layer.2.DenseReluDense.wo.weight\n",
  438. "decoder.block.11.layer.2.layer_norm.weight\n",
  439. "decoder.final_layer_norm.weight\n",
  440. "lm_head.weight\n"
  441. ]
  442. }
  443. ],
  444. "source": [
  445. "for x, y in model.named_parameters():\n",
  446. " print(x)"
  447. ]
  448. },
  449. {
  450. "cell_type": "code",
  451. "execution_count": 6,
  452. "id": "4a34e1f1-1fc1-4577-a87d-efeac33894b1",
  453. "metadata": {},
  454. "outputs": [
  455. {
  456. "name": "stderr",
  457. "output_type": "stream",
  458. "text": [
  459. "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
  460. ]
  461. },
  462. {
  463. "data": {
  464. "application/vnd.jupyter.widget-view+json": {
  465. "model_id": "22d7491179634c75ab8a5c70e9e4188f",
  466. "version_major": 2,
  467. "version_minor": 0
  468. },
  469. "text/plain": [
  470. " 0%| | 0/3 [00:00<?, ?it/s]"
  471. ]
  472. },
  473. "metadata": {},
  474. "output_type": "display_data"
  475. },
  476. {
  477. "name": "stderr",
  478. "output_type": "stream",
  479. "text": [
  480. "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-63df8ebe4567b55a.arrow\n",
  481. "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-bb3872c77bcda3cd.arrow\n"
  482. ]
  483. }
  484. ],
  485. "source": [
  486. "data_loader = AutoLoad(tokenizer)\n",
  487. "dataset = data_loader.get_and_map(config.task[0])\n",
  488. "train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)"
  489. ]
  490. },
  491. {
  492. "cell_type": "code",
  493. "execution_count": 7,
  494. "id": "cf5aea38-4866-4026-b6d4-a8e8b50153b0",
  495. "metadata": {
  496. "tags": []
  497. },
  498. "outputs": [],
  499. "source": [
  500. "# model(**next(iter(train_loader))).loss.backward()\n",
  501. "# for i in range(6, 12):\n",
  502. "# o = model.encoder.block[i].soft_prompt.sadcl_learned_embedding.grad.abs().sum().item()\n",
  503. "# print(i, o)"
  504. ]
  505. },
  506. {
  507. "cell_type": "markdown",
  508. "id": "6281dbae-3023-4e95-82c9-c9d818c37622",
  509. "metadata": {},
  510. "source": [
  511. "# train model"
  512. ]
  513. },
  514. {
  515. "cell_type": "code",
  516. "execution_count": 10,
  517. "id": "dd92aff9-e4cb-4b1b-aece-0a7eee27e0e4",
  518. "metadata": {
  519. "tags": []
  520. },
  521. "outputs": [
  522. {
  523. "name": "stderr",
  524. "output_type": "stream",
  525. "text": [
  526. "\n",
  527. "KeyboardInterrupt\n",
  528. "\n"
  529. ]
  530. }
  531. ],
  532. "source": [
  533. "import wandb\n",
  534. "wandb.init(\n",
  535. " # set the wandb project where this run will be logged\n",
  536. " project=\"my-awesome-project\",\n",
  537. " # track hyperparameters and run metadata\n",
  538. " config=config.__dict__\n",
  539. ")"
  540. ]
  541. },
  542. {
  543. "cell_type": "code",
  544. "execution_count": 8,
  545. "id": "74f04c24-2298-4152-abde-c1ee6a0ea739",
  546. "metadata": {},
  547. "outputs": [
  548. {
  549. "name": "stderr",
  550. "output_type": "stream",
  551. "text": [
  552. " 0%| | 0/268 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
  553. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:08<00:00, 32.56it/s]\n"
  554. ]
  555. },
  556. {
  557. "name": "stdout",
  558. "output_type": "stream",
  559. "text": [
  560. "{'train_loss': 8.963883996009827, 'valid_loss': 6.972279635342685, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
  561. ]
  562. },
  563. {
  564. "name": "stderr",
  565. "output_type": "stream",
  566. "text": [
  567. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
  568. ]
  569. },
  570. {
  571. "name": "stdout",
  572. "output_type": "stream",
  573. "text": [
  574. "{'train_loss': 7.36324492141382, 'valid_loss': 5.521347826177424, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
  575. ]
  576. },
  577. {
  578. "name": "stderr",
  579. "output_type": "stream",
  580. "text": [
  581. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.97it/s]\n"
  582. ]
  583. },
  584. {
  585. "name": "stdout",
  586. "output_type": "stream",
  587. "text": [
  588. "{'train_loss': 6.192735992260833, 'valid_loss': 4.384567484711155, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
  589. ]
  590. },
  591. {
  592. "name": "stderr",
  593. "output_type": "stream",
  594. "text": [
  595. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.89it/s]\n"
  596. ]
  597. },
  598. {
  599. "name": "stdout",
  600. "output_type": "stream",
  601. "text": [
  602. "{'train_loss': 5.118913385405469, 'valid_loss': 3.335551644816543, 'valid_accuracy': 0.05465004793863854, 'valid_f1-score-1': 0.14091470951792337, 'valid_f1-score-ma': 0.009394313967861558}\n"
  603. ]
  604. },
  605. {
  606. "name": "stderr",
  607. "output_type": "stream",
  608. "text": [
  609. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.93it/s]\n"
  610. ]
  611. },
  612. {
  613. "name": "stdout",
  614. "output_type": "stream",
  615. "text": [
  616. "{'train_loss': 4.148920764674002, 'valid_loss': 2.2682720783985024, 'valid_accuracy': 0.174496644295302, 'valid_f1-score-1': 0.36804853387259856, 'valid_f1-score-ma': 0.02164991375721168}\n"
  617. ]
  618. },
  619. {
  620. "name": "stderr",
  621. "output_type": "stream",
  622. "text": [
  623. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.94it/s]\n"
  624. ]
  625. },
  626. {
  627. "name": "stdout",
  628. "output_type": "stream",
  629. "text": [
  630. "{'train_loss': 3.1643025679374808, 'valid_loss': 1.2784492048350247, 'valid_accuracy': 0.5148609779482263, 'valid_f1-score-1': 0.7208053691275169, 'valid_f1-score-ma': 0.05544656685596284}\n"
  631. ]
  632. },
  633. {
  634. "name": "stderr",
  635. "output_type": "stream",
  636. "text": [
  637. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.87it/s]\n"
  638. ]
  639. },
  640. {
  641. "name": "stdout",
  642. "output_type": "stream",
  643. "text": [
  644. "{'train_loss': 2.235221519843856, 'valid_loss': 0.6245457141688375, 'valid_accuracy': 0.6625119846596357, 'valid_f1-score-1': 0.7915151515151515, 'valid_f1-score-ma': 0.09733333333333333}\n"
  645. ]
  646. },
  647. {
  648. "name": "stderr",
  649. "output_type": "stream",
  650. "text": [
  651. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.91it/s]\n"
  652. ]
  653. },
  654. {
  655. "name": "stdout",
  656. "output_type": "stream",
  657. "text": [
  658. "{'train_loss': 1.5051592252592543, 'valid_loss': 0.4341738431742697, 'valid_accuracy': 0.6768935762224353, 'valid_f1-score-1': 0.8077147866744595, 'valid_f1-score-ma': 0.0810680363745307}\n"
  659. ]
  660. },
  661. {
  662. "name": "stderr",
  663. "output_type": "stream",
  664. "text": [
  665. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.99it/s]\n"
  666. ]
  667. },
  668. {
  669. "name": "stdout",
  670. "output_type": "stream",
  671. "text": [
  672. "{'train_loss': 1.0515665002723238, 'valid_loss': 0.3996329452052261, 'valid_accuracy': 0.6826462128475551, 'valid_f1-score-1': 0.8151116199198626, 'valid_f1-score-ma': 0.08151116199198626}\n"
  673. ]
  674. },
  675. {
  676. "name": "stderr",
  677. "output_type": "stream",
  678. "text": [
  679. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.04it/s]\n"
  680. ]
  681. },
  682. {
  683. "name": "stdout",
  684. "output_type": "stream",
  685. "text": [
  686. "{'train_loss': 0.8272433252032123, 'valid_loss': 0.3832718174565922, 'valid_accuracy': 0.6855225311601151, 'valid_f1-score-1': 0.8162100456621004, 'valid_f1-score-ma': 0.11660143509458577}\n"
  687. ]
  688. },
  689. {
  690. "name": "stderr",
  691. "output_type": "stream",
  692. "text": [
  693. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
  694. ]
  695. },
  696. {
  697. "name": "stdout",
  698. "output_type": "stream",
  699. "text": [
  700. "{'train_loss': 0.7063313028705653, 'valid_loss': 0.36713372216080176, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8175598631698974, 'valid_f1-score-ma': 0.16351197263397949}\n"
  701. ]
  702. },
  703. {
  704. "name": "stderr",
  705. "output_type": "stream",
  706. "text": [
  707. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.92it/s]\n"
  708. ]
  709. },
  710. {
  711. "name": "stdout",
  712. "output_type": "stream",
  713. "text": [
  714. "{'train_loss': 0.6392747724234168, 'valid_loss': 0.36563855906327564, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8170940170940172, 'valid_f1-score-ma': 0.2042735042735043}\n"
  715. ]
  716. },
  717. {
  718. "name": "stderr",
  719. "output_type": "stream",
  720. "text": [
  721. "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
  722. ]
  723. },
  724. {
  725. "name": "stdout",
  726. "output_type": "stream",
  727. "text": [
  728. "{'train_loss': 0.5930970842713741, 'valid_loss': 0.3603471038919507, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8170940170940172, 'valid_f1-score-ma': 0.2042735042735043}\n"
  729. ]
  730. },
  731. {
  732. "name": "stderr",
  733. "output_type": "stream",
  734. "text": [
  735. " 40%|█████████████████████████████████████▏ | 106/268 [00:03<00:04, 33.89it/s]\n",
  736. "\n",
  737. "KeyboardInterrupt\n",
  738. "\n"
  739. ]
  740. }
  741. ],
  742. "source": [
  743. "optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n",
  744. "\n",
  745. "model.to(DEVICE)\n",
  746. "\n",
  747. "for epoch in range(config.num_epochs):\n",
  748. " train_out = train_loop(model=model, loader=train_loader, optimizer=optimizer)\n",
  749. " valid_out = valid_loop(model=model, loader=valid_loader)\n",
  750. " wandb.log({\n",
  751. " **train_out,\n",
  752. " **valid_out\n",
  753. " })\n",
  754. " \n",
  755. "wandb.finish()"
  756. ]
  757. },
  758. {
  759. "cell_type": "code",
  760. "execution_count": null,
  761. "id": "4368eb4c-fb7b-41bf-89e3-bf20cdfec967",
  762. "metadata": {},
  763. "outputs": [],
  764. "source": [
  765. "# pip uninstall bitsandbytes -y"
  766. ]
  767. },
  768. {
  769. "cell_type": "code",
  770. "execution_count": null,
  771. "id": "0de2e02c-c5fb-4d13-81fd-5ecb53d42b6c",
  772. "metadata": {
  773. "tags": []
  774. },
  775. "outputs": [],
  776. "source": [
  777. "dataset['train'].set_format(columns=['label', 'labels'])"
  778. ]
  779. },
  780. {
  781. "cell_type": "code",
  782. "execution_count": null,
  783. "id": "687f7994-e875-4f2e-b151-89460bf78eea",
  784. "metadata": {},
  785. "outputs": [],
  786. "source": [
  787. "dataset['train'][0:100]"
  788. ]
  789. },
  790. {
  791. "cell_type": "code",
  792. "execution_count": null,
  793. "id": "5493b857-49e0-4963-b220-7f00422b7511",
  794. "metadata": {
  795. "tags": []
  796. },
  797. "outputs": [],
  798. "source": [
  799. "from datasets import load_dataset\n",
  800. "x = load_dataset(\"glue\", \"sst2\")"
  801. ]
  802. },
  803. {
  804. "cell_type": "code",
  805. "execution_count": null,
  806. "id": "d055f9b6-294e-4f53-9941-21ecc040e92b",
  807. "metadata": {
  808. "tags": []
  809. },
  810. "outputs": [],
  811. "source": [
  812. "Counter(x['train']['label'])"
  813. ]
  814. },
  815. {
  816. "cell_type": "code",
  817. "execution_count": null,
  818. "id": "abf72943-aa98-4543-beec-942f6f601b89",
  819. "metadata": {
  820. "tags": []
  821. },
  822. "outputs": [],
  823. "source": [
  824. "g = x['train']\n",
  825. "l = g.features['label']"
  826. ]
  827. },
  828. {
  829. "cell_type": "code",
  830. "execution_count": null,
  831. "id": "919fb5a0-fc9c-4fcc-9fe6-21cff8960b51",
  832. "metadata": {
  833. "tags": []
  834. },
  835. "outputs": [],
  836. "source": [
  837. "l.int2str(1)"
  838. ]
  839. },
  840. {
  841. "cell_type": "code",
  842. "execution_count": null,
  843. "id": "b286f811-82bc-4f59-914f-e1c5cd5cd1ef",
  844. "metadata": {
  845. "tags": []
  846. },
  847. "outputs": [],
  848. "source": [
  849. "29780 / (29780 + 37569)"
  850. ]
  851. },
  852. {
  853. "cell_type": "code",
  854. "execution_count": 11,
  855. "id": "b581a87a-120d-4f7a-a8f4-b39c6a6d1843",
  856. "metadata": {
  857. "tags": []
  858. },
  859. "outputs": [],
  860. "source": [
  861. "from types import SimpleNamespace\n",
  862. "config = SimpleNamespace(\n",
  863. " model_name='google/t5-base-lm-adapt',\n",
  864. " peft_params={\n",
  865. " 'n_tokens': 30,\n",
  866. " 'n_layers': 6\n",
  867. " },\n",
  868. " random_seed=42,\n",
  869. " task=['glue:cola'],\n",
  870. " hot_modules=['sadcl'],\n",
  871. " train_batch_size=32,\n",
  872. " valid_batch_size=32,\n",
  873. " balancify_sample=False,\n",
  874. " learning_rate=0.01,\n",
  875. " num_epochs=50\n",
  876. ")"
  877. ]
  878. },
  879. {
  880. "cell_type": "code",
  881. "execution_count": 13,
  882. "id": "e7dbb2d9-d545-48e1-a0ac-6d79258a393b",
  883. "metadata": {
  884. "tags": []
  885. },
  886. "outputs": [
  887. {
  888. "name": "stdout",
  889. "output_type": "stream",
  890. "text": [
  891. "{\"model_name\": \"google/t5-base-lm-adapt\", \"peft_params\": {\"n_tokens\": 30, \"n_layers\": 6}, \"random_seed\": 42, \"task\": [\"glue:cola\"], \"hot_modules\": [\"sadcl\"], \"train_batch_size\": 32, \"valid_batch_size\": 32, \"balancify_sample\": false, \"learning_rate\": 0.01, \"num_epochs\": 50}\n"
  892. ]
  893. }
  894. ],
  895. "source": [
  896. "import json\n",
  897. "print(json.dumps(config.__dict__))"
  898. ]
  899. },
  900. {
  901. "cell_type": "code",
  902. "execution_count": 1,
  903. "id": "d18551c3-68e8-4ee0-8936-65bdec51f4eb",
  904. "metadata": {},
  905. "outputs": [],
  906. "source": [
  907. "from transformers import T5TokenizerFast, T5ForConditionalGeneration\n"
  908. ]
  909. },
  910. {
  911. "cell_type": "code",
  912. "execution_count": 2,
  913. "id": "ddac3321-27f4-4b89-aab6-f91ae8bbc86a",
  914. "metadata": {
  915. "tags": []
  916. },
  917. "outputs": [],
  918. "source": [
  919. "tokenizer = T5TokenizerFast.from_pretrained(\"google/t5-large-lm-adapt\", model_max_length=2048)\n"
  920. ]
  921. },
  922. {
  923. "cell_type": "code",
  924. "execution_count": 3,
  925. "id": "7456f182-5d9d-44da-b36d-2ec4052fbaf6",
  926. "metadata": {
  927. "tags": []
  928. },
  929. "outputs": [],
  930. "source": [
  931. "import numpy as np"
  932. ]
  933. },
  934. {
  935. "cell_type": "code",
  936. "execution_count": 4,
  937. "id": "c476e007-c4fe-4eb3-939c-25d0b0add711",
  938. "metadata": {
  939. "tags": []
  940. },
  941. "outputs": [
  942. {
  943. "data": {
  944. "text/plain": [
  945. "array([23830, 2611, 19567, 10149, 20142, 6737, 26963, 6788, 3871,\n",
  946. " 28330, 724, 7406, 11474, 18399, 2289, 25511, 25299, 23308,\n",
  947. " 25412, 370, 32091, 28829, 6148, 29154, 30369, 12979, 8560,\n",
  948. " 6872, 23228, 8051, 19537, 3741, 22206, 20744, 17051, 27857,\n",
  949. " 3830, 15329, 21857, 8296, 10768, 7854, 5710, 5405, 27449,\n",
  950. " 11528, 8599, 12695, 15427, 23726, 389, 3231, 15270, 26906,\n",
  951. " 23085, 15113, 31792, 8766, 9814, 15904, 6320, 23716, 19682,\n",
  952. " 2690, 30766, 21262, 11415, 2523, 26538, 3647, 13971, 21655,\n",
  953. " 287, 19479, 28945, 25134, 17673, 9792, 17556, 31293, 25795,\n",
  954. " 2753, 8955, 21049, 28409, 24281, 3610, 26070, 2189, 25611,\n",
  955. " 9641, 23766, 29195, 779, 18660, 10731, 19732, 1664, 2176,\n",
  956. " 2254])"
  957. ]
  958. },
  959. "execution_count": 4,
  960. "metadata": {},
  961. "output_type": "execute_result"
  962. }
  963. ],
  964. "source": [
  965. "np.random.randint(0, tokenizer.vocab_size, size=(100,))"
  966. ]
  967. },
  968. {
  969. "cell_type": "code",
  970. "execution_count": 1,
  971. "id": "e83e5c34-0860-4e2a-9b2a-016f37b35003",
  972. "metadata": {},
  973. "outputs": [],
  974. "source": [
  975. "import torch"
  976. ]
  977. },
  978. {
  979. "cell_type": "code",
  980. "execution_count": 2,
  981. "id": "0f032c7f-ddd4-4350-81c5-9296c8376d8c",
  982. "metadata": {
  983. "tags": []
  984. },
  985. "outputs": [],
  986. "source": [
  987. "w = torch.load('best.pt')"
  988. ]
  989. },
  990. {
  991. "cell_type": "code",
  992. "execution_count": 8,
  993. "id": "112c06a2-1295-4989-9ce6-5f6204e809ef",
  994. "metadata": {
  995. "tags": []
  996. },
  997. "outputs": [
  998. {
  999. "data": {
  1000. "text/plain": [
  1001. "tensor([[ 0.5470, -0.8095, -1.4617, ..., 0.8100, -1.1746, 0.5768],\n",
  1002. " [-0.9284, -0.6230, -2.4697, ..., 0.3947, -0.5427, -0.3088],\n",
  1003. " [ 1.4407, 0.8760, 0.2499, ..., 0.1860, -0.3176, 2.0041],\n",
  1004. " ...,\n",
  1005. " [ 0.8714, 1.1013, -2.7711, ..., -0.2819, 0.7087, -0.6164],\n",
  1006. " [ 0.8026, -0.7928, -0.8946, ..., -1.5204, 1.0164, -1.3527],\n",
  1007. " [ 0.4650, -2.1778, 0.0213, ..., -1.1430, -2.3895, -0.0235]],\n",
  1008. " device='cuda:0')"
  1009. ]
  1010. },
  1011. "execution_count": 8,
  1012. "metadata": {},
  1013. "output_type": "execute_result"
  1014. }
  1015. ],
  1016. "source": [
  1017. "w.pop('sadcl_learned_embedding')"
  1018. ]
  1019. },
  1020. {
  1021. "cell_type": "code",
  1022. "execution_count": 9,
  1023. "id": "4055f80b-ea2f-44db-bf32-98ea3ffe9597",
  1024. "metadata": {
  1025. "tags": []
  1026. },
  1027. "outputs": [
  1028. {
  1029. "data": {
  1030. "text/plain": [
  1031. "OrderedDict([('sadcl_mlp.0.weight',\n",
  1032. " tensor([[ 0.1171, -0.7743, 0.5095, ..., -1.0615, 1.5754, 0.7036],\n",
  1033. " [-0.2675, 0.0969, 0.0543, ..., 0.7276, -0.0671, 0.8296],\n",
  1034. " [-0.2987, -0.0700, -1.0519, ..., 0.6090, 0.0193, 0.0410],\n",
  1035. " ...,\n",
  1036. " [-0.1463, -0.8924, 0.7947, ..., 0.2265, -0.6957, 0.5928],\n",
  1037. " [-0.4365, -0.9251, -1.0378, ..., -0.8628, -0.5243, 0.0860],\n",
  1038. " [ 0.4860, 0.0648, -0.9160, ..., -0.5342, 0.1072, -0.1397]],\n",
  1039. " device='cuda:0')),\n",
  1040. " ('sadcl_mlp.0.bias',\n",
  1041. " tensor([-0.6311, -1.0433, -1.0390, -1.6997, -1.0766, -0.2802, -0.9433, -0.7127,\n",
  1042. " 0.5315, -1.0400, -0.3756, -0.2602, -0.7607, 0.7578, -0.7066, -0.3561,\n",
  1043. " -0.5580, -0.7671, -0.2557, -1.6528, -0.1438, -0.4875, -0.6291, -1.2763,\n",
  1044. " -0.2484, -0.6396, -0.7225, -0.8314, -1.3913, -0.7696, 0.0864, -0.7268,\n",
  1045. " -0.7812, -1.0606, -0.9011, 0.3322, 0.5159, -0.4453, -0.6409, 0.0714,\n",
  1046. " -0.2788, -0.1620, -0.9408, 0.1440, -0.8897, -0.9288, -1.2605, -1.2384,\n",
  1047. " -0.0090, -0.0661, -0.5203, -1.5729, -0.5143, -0.4943, -0.9472, -0.8107,\n",
  1048. " -0.5748, -1.1438, -0.8919, -0.8606, -1.0831, -1.4380, -1.0802, 0.0522,\n",
  1049. " 0.0785, -2.4277, -1.0447, -0.3124, 0.1173, -0.8195, -0.0623, -0.1913,\n",
  1050. " -1.4551, -0.0732, -1.1574, -0.2217, -0.6697, -0.5846, -0.2473, -0.0144,\n",
  1051. " -1.2317, -0.5024, -0.2301, 0.2265, -0.6478, -0.8726, -0.8367, -0.0312,\n",
  1052. " -0.4783, -0.3132, -0.6115, -1.5002, -0.6820, -0.9731, -0.6438, -0.8716,\n",
  1053. " -0.2628, -0.8308, -0.8588, 0.8616, -0.3398, 0.2025, -0.6247, -0.4494,\n",
  1054. " -1.2737, -0.9406, -0.5297, -0.4886, -1.6481, -2.5021, -0.1344, -0.8274,\n",
  1055. " -1.6135, -0.9598, -0.8659, -1.3385, -1.4567, -1.0869, -0.1999, -1.3751,\n",
  1056. " -0.4536, -1.0839, -1.0037, -0.0429, -0.5243, -0.8836, -0.9716, -1.1037],\n",
  1057. " device='cuda:0')),\n",
  1058. " ('sadcl_mlp.2.weight',\n",
  1059. " tensor([[ 0.3282, 0.2861, 0.4277, ..., -1.1185, 0.3197, 0.6003],\n",
  1060. " [-0.9305, 0.1462, -0.4269, ..., -0.1129, -0.7909, -0.6872],\n",
  1061. " [ 0.0067, -0.7521, -1.6837, ..., -0.2374, -0.2790, -0.9895],\n",
  1062. " ...,\n",
  1063. " [-1.8292, 0.9060, 1.3090, ..., -0.0273, -1.0552, -0.2187],\n",
  1064. " [-0.3804, 0.0945, 0.0337, ..., -1.6941, 0.0693, -0.0288],\n",
  1065. " [-1.3038, 0.2590, -0.2965, ..., 0.9425, -0.0090, -1.2449]],\n",
  1066. " device='cuda:0')),\n",
  1067. " ('sadcl_mlp.3.weight',\n",
  1068. " tensor([2.4987, 2.5515, 2.0518, ..., 1.9383, 1.2583, 1.1634], device='cuda:0')),\n",
  1069. " ('sadcl_mlp.3.bias',\n",
  1070. " tensor([ 1.0355, -0.7098, -1.4075, ..., -0.3350, -0.7165, -0.9371],\n",
  1071. " device='cuda:0'))])"
  1072. ]
  1073. },
  1074. "execution_count": 9,
  1075. "metadata": {},
  1076. "output_type": "execute_result"
  1077. }
  1078. ],
  1079. "source": [
  1080. "w"
  1081. ]
  1082. },
  1083. {
  1084. "cell_type": "code",
  1085. "execution_count": null,
  1086. "id": "fa50b148-1f34-41c9-b7f4-c26c3e4cbce6",
  1087. "metadata": {},
  1088. "outputs": [],
  1089. "source": []
  1090. }
  1091. ],
  1092. "metadata": {
  1093. "kernelspec": {
  1094. "display_name": "Python [conda env:deep]",
  1095. "language": "python",
  1096. "name": "conda-env-deep-py"
  1097. },
  1098. "language_info": {
  1099. "codemirror_mode": {
  1100. "name": "ipython",
  1101. "version": 3
  1102. },
  1103. "file_extension": ".py",
  1104. "mimetype": "text/x-python",
  1105. "name": "python",
  1106. "nbconvert_exporter": "python",
  1107. "pygments_lexer": "ipython3",
  1108. "version": "3.10.11"
  1109. }
  1110. },
  1111. "nbformat": 4,
  1112. "nbformat_minor": 5
  1113. }