You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

explore_ds.ipynb 15KB

3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "3526e83a-baa5-4278-81ce-e142e0a6d208",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "import sys\n",
  13. "from pathlib import Path\n",
  14. "sys.path.append(Path('./').absolute().parent.__str__())\n",
  15. "from _datasets import AutoLoad"
  16. ]
  17. },
  18. {
  19. "cell_type": "code",
  20. "execution_count": 48,
  21. "id": "5a0264f8-4b67-44e2-8aa9-468ae8b249b5",
  22. "metadata": {
  23. "tags": []
  24. },
  25. "outputs": [
  26. {
  27. "name": "stdout",
  28. "output_type": "stream",
  29. "text": [
  30. "(12, 15)\n",
  31. "{'a': 'b'}\n"
  32. ]
  33. }
  34. ],
  35. "source": [
  36. "class Test():\n",
  37. " def __new__(cls, *args, **kwargs):\n",
  38. " print(args)\n",
  39. " print(kwargs)\n",
  40. "Test(12, 15, a='b')"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 10,
  46. "id": "f0d8ead2-cfa6-4044-8e7a-6b7146bea9cd",
  47. "metadata": {
  48. "tags": []
  49. },
  50. "outputs": [],
  51. "source": [
# Build the dataset loader around a T5 tokenizer.
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained('google/t5-small-lm-adapt')
# NOTE(review): private flag, presumably read by AutoLoad to pick
# seq2seq-style preprocessing — confirm against _datasets.AutoLoad.
tokenizer._is_seq2seq = True
loader = AutoLoad(tokenizer=tokenizer)
  57. ]
  58. },
  59. {
  60. "cell_type": "code",
  61. "execution_count": 19,
  62. "id": "07c556fd-780d-4aee-a5e9-ad81a474d94b",
  63. "metadata": {
  64. "tags": []
  65. },
  66. "outputs": [
  67. {
  68. "data": {
  69. "text/plain": [
  70. "['sentence1', 'sentence2']"
  71. ]
  72. },
  73. "execution_count": 19,
  74. "metadata": {},
  75. "output_type": "execute_result"
  76. }
  77. ],
  78. "source": [
# Which input columns the GLUE helper uses for STS-B
# (output: ['sentence1', 'sentence2']).
loader.glue_helper.get_task_input('stsb')
  80. ]
  81. },
  82. {
  83. "cell_type": "code",
  84. "execution_count": 11,
  85. "id": "04feb162-ef3f-42a8-ab00-23d3faea5209",
  86. "metadata": {
  87. "tags": []
  88. },
  89. "outputs": [
  90. {
  91. "data": {
  92. "application/vnd.jupyter.widget-view+json": {
  93. "model_id": "8165afbb7bcb474e80b9538b0c0c39da",
  94. "version_major": 2,
  95. "version_minor": 0
  96. },
  97. "text/plain": [
  98. "Map: 0%| | 0/5749 [00:00<?, ? examples/s]"
  99. ]
  100. },
  101. "metadata": {},
  102. "output_type": "display_data"
  103. },
  104. {
  105. "data": {
  106. "application/vnd.jupyter.widget-view+json": {
  107. "model_id": "95318c2e7b684eabb280fd34d014f1d3",
  108. "version_major": 2,
  109. "version_minor": 0
  110. },
  111. "text/plain": [
  112. "Map: 0%| | 0/1500 [00:00<?, ? examples/s]"
  113. ]
  114. },
  115. "metadata": {},
  116. "output_type": "display_data"
  117. },
  118. {
  119. "data": {
  120. "application/vnd.jupyter.widget-view+json": {
  121. "model_id": "0e47b3895f4d4f77920c8d82579ec683",
  122. "version_major": 2,
  123. "version_minor": 0
  124. },
  125. "text/plain": [
  126. "Map: 0%| | 0/1500 [00:00<?, ? examples/s]"
  127. ]
  128. },
  129. "metadata": {},
  130. "output_type": "display_data"
  131. }
  132. ],
  133. "source": [
# Tokenize/map all three GLUE STS-B splits with the loader
# (the three Map progress bars above: 5749 / 1500 / 1500 examples).
ds = loader.get_and_map('glue:stsb')
  135. ]
  136. },
  137. {
  138. "cell_type": "code",
  139. "execution_count": 43,
  140. "id": "9dcf1e0c-e703-4e30-9dab-bfc54cde7d3f",
  141. "metadata": {
  142. "tags": []
  143. },
  144. "outputs": [
  145. {
  146. "data": {
  147. "application/vnd.jupyter.widget-view+json": {
  148. "model_id": "e703362287be445fa8f3949c592b1c26",
  149. "version_major": 2,
  150. "version_minor": 0
  151. },
  152. "text/plain": [
  153. "Downloading data: 0%| | 0.00/51.8M [00:00<?, ?B/s]"
  154. ]
  155. },
  156. "metadata": {},
  157. "output_type": "display_data"
  158. },
  159. {
  160. "data": {
  161. "application/vnd.jupyter.widget-view+json": {
  162. "model_id": "2d231baabf80401eacf8c400a811c5ac",
  163. "version_major": 2,
  164. "version_minor": 0
  165. },
  166. "text/plain": [
  167. "Generating train split: 0%| | 0/100730 [00:00<?, ? examples/s]"
  168. ]
  169. },
  170. "metadata": {},
  171. "output_type": "display_data"
  172. },
  173. {
  174. "data": {
  175. "application/vnd.jupyter.widget-view+json": {
  176. "model_id": "6c699b3fdf1e468e9ef8a442651d1f7c",
  177. "version_major": 2,
  178. "version_minor": 0
  179. },
  180. "text/plain": [
  181. "Generating validation split: 0%| | 0/10000 [00:00<?, ? examples/s]"
  182. ]
  183. },
  184. "metadata": {},
  185. "output_type": "display_data"
  186. },
  187. {
  188. "data": {
  189. "application/vnd.jupyter.widget-view+json": {
  190. "model_id": "91acd57830124beeb29c9869f3b67788",
  191. "version_major": 2,
  192. "version_minor": 0
  193. },
  194. "text/plain": [
  195. "Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]"
  196. ]
  197. },
  198. "metadata": {},
  199. "output_type": "display_data"
  200. }
  201. ],
  202. "source": [
# Load the raw SuperGLUE ReCoRD dataset for inspection.
# NOTE(review): this rebinds `ds`, silently shadowing the mapped GLUE
# dataset from the earlier cell — a distinct name (e.g. `record_ds`)
# would avoid hidden-state bugs on partial re-runs.
from datasets import load_dataset

ds = load_dataset('super_glue', 'record')
  206. ]
  207. },
  208. {
  209. "cell_type": "code",
  210. "execution_count": 46,
  211. "id": "c4d652d7-8237-4e5a-85e5-faf39a88eea5",
  212. "metadata": {
  213. "tags": []
  214. },
  215. "outputs": [
  216. {
  217. "data": {
  218. "text/plain": [
  219. "{'passage': \"For everyone who has ever thought about shooting their boss - metaphorically, o fcourse - this one is for you. An employee of a Texas armored car company got to do just that this week to 'demonstrate that they take client safety seriously'. And to further that demonstration, the CEO was sitting alone inside the Mercedes-Benz as 12 rounds from an AK-47 rained down upon the SUV. The company, Texas Armoring Corporation, has supplied protected vehicles to the Pope, celebrities like rapper T.I. and actor Steven Segal and oil executives in West Africa, according to My San Antonio. Texas Armoring Corp. & Jason Forston.\\n@highlight\\nTexas Armoring Corporation created a video to show the effectiveness of their armored\\n@highlight\\nCEO R. Trent Kimball sat in the drivers seat of a Mercedes-Benz SUV\\n@highlight\\nTotal of 12 rounds fired at the windscreen\\n@highlight\\nCompany known for working with celebrities, oil barons and even the Pope\",\n",
  220. " 'query': \"'When it comes to assuring our clients' safety, we take product testing extremely seriously,' @placeholder says in a video taken of the display.\",\n",
  221. " 'entities': ['Steven Segal',\n",
  222. " 'Texas Armoring Corp.',\n",
  223. " 'Trent Kimball',\n",
  224. " 'Texas Armoring Corporation',\n",
  225. " 'Texas',\n",
  226. " 'AK-47',\n",
  227. " 'Pope',\n",
  228. " 'Mercedes-Benz',\n",
  229. " 'San Antonio',\n",
  230. " 'West Africa',\n",
  231. " 'rapper T.I.',\n",
  232. " 'Jason Forston'],\n",
  233. " 'entity_spans': {'text': ['Texas',\n",
  234. " 'Mercedes-Benz',\n",
  235. " 'AK-47',\n",
  236. " 'Texas Armoring Corporation',\n",
  237. " 'Pope',\n",
  238. " 'rapper T.I.',\n",
  239. " 'Steven Segal',\n",
  240. " 'West Africa',\n",
  241. " 'San Antonio',\n",
  242. " 'Texas Armoring Corp.',\n",
  243. " 'Jason Forston',\n",
  244. " 'Texas Armoring Corporation',\n",
  245. " 'Trent Kimball',\n",
  246. " 'Mercedes-Benz',\n",
  247. " 'Pope'],\n",
  248. " 'start': [128,\n",
  249. " 313,\n",
  250. " 348,\n",
  251. " 393,\n",
  252. " 460,\n",
  253. " 483,\n",
  254. " 505,\n",
  255. " 540,\n",
  256. " 569,\n",
  257. " 582,\n",
  258. " 605,\n",
  259. " 631,\n",
  260. " 735,\n",
  261. " 778,\n",
  262. " 929],\n",
  263. " 'end': [133,\n",
  264. " 326,\n",
  265. " 353,\n",
  266. " 419,\n",
  267. " 464,\n",
  268. " 494,\n",
  269. " 517,\n",
  270. " 551,\n",
  271. " 580,\n",
  272. " 602,\n",
  273. " 618,\n",
  274. " 657,\n",
  275. " 748,\n",
  276. " 791,\n",
  277. " 933]},\n",
  278. " 'answers': ['Trent Kimball'],\n",
  279. " 'idx': {'passage': 4, 'query': 10}}"
  280. ]
  281. },
  282. "execution_count": 46,
  283. "metadata": {},
  284. "output_type": "execute_result"
  285. }
  286. ],
  287. "source": [
# Peek at one ReCoRD training example
# (passage / query / entities / entity_spans / answers / idx fields).
ds['train'][10]
  289. ]
  290. },
  291. {
  292. "cell_type": "code",
  293. "execution_count": 31,
  294. "id": "c77ab84e-1cd2-4038-9354-b7f2668bc99d",
  295. "metadata": {
  296. "tags": []
  297. },
  298. "outputs": [],
  299. "source": [
# NOTE(review): mid-notebook import; conventionally this belongs in the
# top import cell so a fresh-kernel run surfaces dependencies early.
from evaluate import load
  301. ]
  302. },
  303. {
  304. "cell_type": "code",
  305. "execution_count": 38,
  306. "id": "dc4b8326-43c7-4941-aae5-3cbea1f793cb",
  307. "metadata": {
  308. "tags": []
  309. },
  310. "outputs": [
  311. {
  312. "data": {
  313. "text/plain": [
  314. "{'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}"
  315. ]
  316. },
  317. "execution_count": 38,
  318. "metadata": {},
  319. "output_type": "execute_result"
  320. }
  321. ],
  322. "source": [
# Probe the SuperGLUE 'multirc' metric's expected input format: a single
# wrong prediction yields all-zero scores (exact_match / f1_m / f1_a = 0.0).
metric = load('super_glue', 'multirc')
metric.compute(
    # multirc predictions are dicts carrying the (paragraph, question, answer) idx
    predictions=[{'prediction': 0, 'idx':{'paragraph': 0, 'question': 0, 'answer': 2}}],
    references=[1]
)
  328. ]
  329. },
  330. {
  331. "cell_type": "code",
  332. "execution_count": 39,
  333. "id": "13da4dac-ae6f-4a36-a6ed-ebf077eef625",
  334. "metadata": {
  335. "tags": []
  336. },
  337. "outputs": [
  338. {
  339. "data": {
  340. "text/plain": [
  341. "EvaluationModule(name: \"super_glue\", module_type: \"metric\", features: {'predictions': {'idx': {'answer': Value(dtype='int64', id=None), 'paragraph': Value(dtype='int64', id=None), 'question': Value(dtype='int64', id=None)}, 'prediction': Value(dtype='int64', id=None)}, 'references': Value(dtype='int64', id=None)}, usage: \"\"\"\n",
  342. "Compute SuperGLUE evaluation metric associated to each SuperGLUE dataset.\n",
  343. "Args:\n",
  344. " predictions: list of predictions to score. Depending on the SuperGlUE subset:\n",
  345. " - for 'record': list of question-answer dictionaries with the following keys:\n",
  346. " - 'idx': index of the question as specified by the dataset\n",
  347. " - 'prediction_text': the predicted answer text\n",
  348. " - for 'multirc': list of question-answer dictionaries with the following keys:\n",
  349. " - 'idx': index of the question-answer pair as specified by the dataset\n",
  350. " - 'prediction': the predicted answer label\n",
  351. " - otherwise: list of predicted labels\n",
  352. " references: list of reference labels. Depending on the SuperGLUE subset:\n",
  353. " - for 'record': list of question-answers dictionaries with the following keys:\n",
  354. " - 'idx': index of the question as specified by the dataset\n",
  355. " - 'answers': list of possible answers\n",
  356. " - otherwise: list of reference labels\n",
  357. "Returns: depending on the SuperGLUE subset:\n",
  358. " - for 'record':\n",
  359. " - 'exact_match': Exact match between answer and gold answer\n",
  360. " - 'f1': F1 score\n",
  361. " - for 'multirc':\n",
  362. " - 'exact_match': Exact match between answer and gold answer\n",
  363. " - 'f1_m': Per-question macro-F1 score\n",
  364. " - 'f1_a': Average F1 score over all answers\n",
  365. " - for 'axb':\n",
  366. " 'matthews_correlation': Matthew Correlation\n",
  367. " - for 'cb':\n",
  368. " - 'accuracy': Accuracy\n",
  369. " - 'f1': F1 score\n",
  370. " - for all others:\n",
  371. " - 'accuracy': Accuracy\n",
  372. "Examples:\n",
  373. "\n",
  374. " >>> super_glue_metric = evaluate.load('super_glue', 'copa') # any of [\"copa\", \"rte\", \"wic\", \"wsc\", \"wsc.fixed\", \"boolq\", \"axg\"]\n",
  375. " >>> predictions = [0, 1]\n",
  376. " >>> references = [0, 1]\n",
  377. " >>> results = super_glue_metric.compute(predictions=predictions, references=references)\n",
  378. " >>> print(results)\n",
  379. " {'accuracy': 1.0}\n",
  380. "\n",
  381. " >>> super_glue_metric = evaluate.load('super_glue', 'cb')\n",
  382. " >>> predictions = [0, 1]\n",
  383. " >>> references = [0, 1]\n",
  384. " >>> results = super_glue_metric.compute(predictions=predictions, references=references)\n",
  385. " >>> print(results)\n",
  386. " {'accuracy': 1.0, 'f1': 1.0}\n",
  387. "\n",
  388. " >>> super_glue_metric = evaluate.load('super_glue', 'record')\n",
  389. " >>> predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]\n",
  390. " >>> references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]\n",
  391. " >>> results = super_glue_metric.compute(predictions=predictions, references=references)\n",
  392. " >>> print(results)\n",
  393. " {'exact_match': 1.0, 'f1': 1.0}\n",
  394. "\n",
  395. " >>> super_glue_metric = evaluate.load('super_glue', 'multirc')\n",
  396. " >>> predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]\n",
  397. " >>> references = [0, 1]\n",
  398. " >>> results = super_glue_metric.compute(predictions=predictions, references=references)\n",
  399. " >>> print(results)\n",
  400. " {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}\n",
  401. "\n",
  402. " >>> super_glue_metric = evaluate.load('super_glue', 'axb')\n",
  403. " >>> references = [0, 1]\n",
  404. " >>> predictions = [0, 1]\n",
  405. " >>> results = super_glue_metric.compute(predictions=predictions, references=references)\n",
  406. " >>> print(results)\n",
  407. " {'matthews_correlation': 1.0}\n",
  408. "\"\"\", stored examples: 0)"
  409. ]
  410. },
  411. "execution_count": 39,
  412. "metadata": {},
  413. "output_type": "execute_result"
  414. }
  415. ],
  416. "source": [
# Display the metric object; its repr documents the prediction/reference
# formats each SuperGLUE subset expects.
metric
  418. ]
  419. },
  420. {
  421. "cell_type": "code",
  422. "execution_count": 29,
  423. "id": "020f35a1-09ec-4ef3-94f4-28144778a3ab",
  424. "metadata": {
  425. "tags": []
  426. },
  427. "outputs": [
  428. {
  429. "name": "stdout",
  430. "output_type": "stream",
  431. "text": [
  432. "0.1\n",
  433. "0.1\n",
  434. "0.1\n",
  435. "0.1\n",
  436. "0.1\n",
  437. "0.1\n",
  438. "0.1\n",
  439. "0.1\n",
  440. "0.1\n",
  441. "0.1\n",
  442. "0.1\n",
  443. "0.1\n",
  444. "0.1\n",
  445. "0.1\n",
  446. "0.1\n",
  447. "0.1\n",
  448. "0.1\n",
  449. "0.1\n",
  450. "0.1\n",
  451. "0.1\n",
  452. "0.1\n",
  453. "0.1\n",
  454. "0.1\n",
  455. "0.1\n",
  456. "0.1\n",
  457. "0.1\n",
  458. "0.1\n",
  459. "0.1\n",
  460. "0.1\n",
  461. "0.1\n",
  462. "0.1\n",
  463. "0.1\n",
  464. "0.1\n",
  465. "0.1\n",
  466. "0.1\n",
  467. "0.1\n",
  468. "0.1\n",
  469. "0.1\n",
  470. "0.1\n",
  471. "0.1\n",
  472. "0.1\n",
  473. "0.1\n",
  474. "0.1\n",
  475. "0.1\n",
  476. "0.1\n",
  477. "0.1\n",
  478. "0.1\n",
  479. "0.1\n",
  480. "0.1\n",
  481. "0.1\n",
  482. "0.1\n",
  483. "0.1\n",
  484. "0.1\n",
  485. "0.1\n",
  486. "0.1\n",
  487. "0.1\n",
  488. "0.1\n",
  489. "0.1\n"
  490. ]
  491. }
  492. ],
  493. "source": [
from transformers import T5ForConditionalGeneration
import torch

# Load the T5-small LM-adapted checkpoint used to try out dropout removal.
model = T5ForConditionalGeneration.from_pretrained('google/t5-small-lm-adapt')
  499. "def mutate_remove_dropout(model):\n",
  500. " for module in model.modules():\n",
  501. " if isinstance(module, torch.nn.Dropout):\n",
  502. " module._backup_p = module.p\n",
  503. " module.p = 0\n",
  504. " print(module._backup_p)\n",
# Zero out all dropout rates on the loaded model in place.
mutate_remove_dropout(model)
  506. ]
  507. },
  508. {
  509. "cell_type": "code",
  510. "execution_count": null,
  511. "id": "146e1eb3-f6a6-41d2-ab84-13b62de8983a",
  512. "metadata": {},
  513. "outputs": [],
  514. "source": []
  515. }
  516. ],
  517. "metadata": {
  518. "kernelspec": {
  519. "display_name": "Python [conda env:deep]",
  520. "language": "python",
  521. "name": "conda-env-deep-py"
  522. },
  523. "language_info": {
  524. "codemirror_mode": {
  525. "name": "ipython",
  526. "version": 3
  527. },
  528. "file_extension": ".py",
  529. "mimetype": "text/x-python",
  530. "name": "python",
  531. "nbconvert_exporter": "python",
  532. "pygments_lexer": "ipython3",
  533. "version": "3.10.13"
  534. }
  535. },
  536. "nbformat": 4,
  537. "nbformat_minor": 5
  538. }