
table2.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "135746cc-454c-41a2-977c-cf633899f002",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "class WandBWrapper:\n",
    "    def __init__(self, prefix=''):\n",
    "        import wandb\n",
    "        self.api = wandb.Api()\n",
    "        self.prefix = prefix\n",
    "\n",
  23. " def get_runs(self, name):\n",
  24. " return self.api.runs(f\"{self.prefix}{name}\")\n",
  25. " \n",
  26. " def _preprocess_config(self, run):\n",
  27. " return {\n",
  28. " k: v for k,v in run.config.items()\n",
  29. " if not k.startswith('_')\n",
  30. " }\n",
  31. " \n",
  32. " def sort_valid_columns(self, cols):\n",
  33. " priority = {\n",
  34. " 'matthews_correlation': 0,\n",
  35. " 'f1': 1,\n",
  36. " 'f1_a':1,\n",
  37. " 'accuracy': 2,\n",
  38. " 'exact_match': 3,\n",
  39. " 'pearson': 5,\n",
  40. " 'spearmanr': 6\n",
  41. " }\n",
  42. " \n",
  43. " for col in cols: # mnli dirty fix\n",
  44. " if 'matched_accuracy' in col:\n",
  45. " return ['valid_mean']\n",
  46. " \n",
  47. " cols = [col for col in cols if 'f1_m' not in col]\n",
  48. " \n",
  49. " stripper = lambda x: x[x.find('_') + 1:]\n",
  50. " return list(sorted(cols, key=lambda x: priority[stripper(x)]))\n",
  51. " \n",
  52. " def _best_in_history(self, run, key):\n",
  53. " history = run.history()\n",
  54. " all_valid_columns = [col for col in history.columns if 'valid' in col and 'mean' not in col]\n",
  55. " best_row_idx = history[key].astype('float').fillna(0).argmax()\n",
  56. " all_valid_columns = self.sort_valid_columns(all_valid_columns)\n",
  57. " return [max(float(history[key][best_row_idx]), 0) for key in all_valid_columns]\n",
  58. " \n",
  59. " def get_full_history(self, runs, tasks, model_size=''):\n",
  60. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  61. " return {\n",
  62. " task_name: pd.DataFrame({\n",
  63. " run.name: run.history()['valid_mean']\n",
  64. " for run in self.get_runs(task_name)\n",
  65. " if run.name in runs\n",
  66. " })[runs]\n",
  67. " for task_name in task_names\n",
  68. " }\n",
  69. " \n",
  70. " def get_runs_best(self, name, run_name_filter=None):\n",
  71. " runs = self.get_runs(name)\n",
  72. " return {\n",
  73. " run.name: self._best_in_history(run, 'valid_mean')\n",
  74. " for run in runs\n",
  75. " if run_name_filter is None or run.name in run_name_filter\n",
  76. " }\n",
  77. " \n",
  78. " def get_runs_tasks_df(self, runs, tasks, model_size=''):\n",
  79. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  80. " results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n",
  81. " return pd.DataFrame(results).T[runs].T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a4ddeace-44eb-4a2d-b215-b3d9af067204",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
  93. "attempt = {\n",
  94. " 'qqp': ['-', 0.903], # F1/acc\n",
  95. " 'qnli': [0.930],\n",
  96. " 'mnli': [0.843],\n",
  97. " 'sst2': [0.932],\n",
  98. " 'stsb': [0.897, '-'], # Pearson / rho\n",
  99. " 'mrpc': ['-', 0.857], # F1/acc\n",
  100. " 'cola': [0.574],\n",
  101. " 'multirc': [0.744, \"-\"], # F1a / EM\n",
  102. " 'rte': [0.734],\n",
  103. " 'cb': [\"-\", 0.786], # F1/acc\n",
  104. " 'copa': '-',\n",
  105. " 'wic': [0.668],\n",
  106. " 'boolq': [0.788],\n",
  107. "}\n",
  108. "residual = {\n",
  109. " 'qqp': \"-\",\n",
  110. " 'qnli': \"-\",\n",
  111. " 'mnli': \"-\",\n",
  112. " 'sst2': \"-\",\n",
  113. " 'stsb': \"-\",\n",
  114. " 'mrpc': \"-\",\n",
  115. " 'cola': \"-\",\n",
  116. " 'multirc': [0.593],\n",
  117. " 'rte': [0.704],\n",
  118. " 'cb': [0.792],\n",
  119. " 'copa': [0.583],\n",
  120. " 'wic': [0.668],\n",
  121. " 'boolq': [0.779],\n",
  122. "}"
  123. ]
  124. },
  125. {
  126. "cell_type": "code",
  127. "execution_count": 3,
  128. "id": "28243b98-8fa8-4fc0-a348-b905c126bdd7",
  129. "metadata": {
  130. "tags": []
  131. },
  132. "outputs": [],
  133. "source": [
  134. "import json\n",
  135. "import numpy as np\n",
  136. "from pathlib import Path \n",
  137. "\n",
  138. "def load_gpt_score(base_path, task_name):\n",
  139. " base_path = Path(base_path)\n",
  140. " if task_name == 'mnli':\n",
  141. " matched = json.loads((base_path / f'{task_name}_matched.json').read_text())\n",
  142. " mismatched = json.loads((base_path / f'{task_name}_mismatched.json').read_text())\n",
  143. " return [np.mean([*matched.values(), *mismatched.values()])]\n",
  144. " \n",
  145. " performance = json.loads((base_path / f'{task_name}.json').read_text())\n",
  146. " \n",
  147. " key_priority = {\n",
  148. " 'matthews_correlation': 0,\n",
  149. " 'f1': 1,\n",
  150. " 'f1_a':1,\n",
  151. " 'accuracy': 2,\n",
  152. " 'exact_match': 3,\n",
  153. " 'pearson': 5,\n",
  154. " 'spearmanr': 6\n",
  155. " }\n",
  156. " \n",
  157. " performance_keys = list(performance.keys())\n",
  158. " if 'f1_m' in performance_keys:\n",
  159. " performance_keys.pop(performance_keys.index('f1_m'))\n",
  160. " performance_keys.sort(key=lambda x: key_priority[x])\n",
  161. " \n",
  162. " return [float(performance[key]) for key in performance_keys]\n",
  163. "\n",
  164. "tasks = [\n",
  165. " 'qqp', # new datasets\n",
  166. " 'qnli', # new datasets\n",
  167. " 'mnli', # new datasets\n",
  168. " 'sst2', # new datasets\n",
  169. " 'stsb', # new datasets\n",
  170. " 'mrpc',\n",
  171. " 'cola',\n",
  172. " 'multirc', # new datasets\n",
  173. " 'rte',\n",
  174. " 'cb',\n",
  175. " 'copa',\n",
  176. " 'wic',\n",
  177. " 'boolq',\n",
  178. "]\n",
  179. "\n",
  180. "gpt_performances = {task: load_gpt_score('openai', task) for task in tasks}"
  181. ]
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5ac2b609-3fb8-4206-a20b-36b2282f3372",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tasks = {\n",
    "    # 'glue-wnli',\n",
    "    # 'glue-rte',\n",
    "    'glue-qqp': 'qqp',    # new datasets\n",
    "    'glue-qnli': 'qnli',  # new datasets\n",
    "    'glue-mnli': 'mnli',  # new datasets\n",
    "    'glue-sst2': 'sst2',  # new datasets\n",
    "    'glue-stsb': 'stsb',  # new datasets\n",
    "    'glue-mrpc': 'mrpc',\n",
    "    'glue-cola': 'cola',\n",
    "    'superglue-multirc': 'multirc',  # new datasets\n",
    "    'superglue-rte': 'rte',\n",
    "    'superglue-cb': 'cb',\n",
    "    'superglue-copa': 'copa',\n",
    "    'superglue-wic': 'wic',\n",
    "    'superglue-boolq': 'boolq',\n",
    "}\n",
    "\n",
    "runs = [\n",
    "    '10_combine_128',\n",
    "]\n",
  213. "\n",
  214. "base_lmt5_df = WandBWrapper(\"mohalisad/hzi_cluster_t5_\").get_runs_tasks_df(\n",
  215. " runs=runs, tasks=tasks.keys(), model_size='base'\n",
  216. ")\n",
  217. "base_lmt5_df['base_superglue-cb']['10_combine_128'] = [0.7826, 0.8214]\n",
  218. "small_lmt5_df = WandBWrapper(\"mohalisad/hzi_cluster_t5_\").get_runs_tasks_df(\n",
  219. " runs=runs,\n",
  220. " tasks=tasks.keys(),\n",
  221. " model_size='small'\n",
  222. ")\n",
  223. "small_lmt5_softmax_df = WandBWrapper(\"mohalisad/iclr_softmax_effect_t5_\").get_runs_tasks_df(\n",
  224. " runs=runs,\n",
  225. " tasks=tasks.keys(),\n",
  226. " model_size='small'\n",
  227. ")\n",
  228. "base_origt5_df = WandBWrapper(\"iclr_orig_t5_t5_\").get_runs_tasks_df(\n",
  229. " runs=runs, tasks=tasks, model_size='base'\n",
  230. ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4e6da93-1cad-4310-9e54-f6a5f0c87a58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "base_lmt5_df.columns = list(tasks.values())\n",
    "small_lmt5_df.columns = list(tasks.values())\n",
    "small_lmt5_softmax_df.columns = list(tasks.values())\n",
    "base_origt5_df.columns = list(tasks.values())\n",
    "\n",
    "attempt_df = pd.Series(attempt).to_frame().T\n",
    "residual_df = pd.Series(residual).to_frame().T\n",
    "gpt_df = pd.Series(gpt_performances).to_frame().T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a58a4bbc-7b62-4c5a-b69c-27252598232b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def my_concat(**kwargs):\n",
    "    merged_df = pd.concat(\n",
    "        list(kwargs.values()),\n",
    "        ignore_index=True\n",
    "    )\n",
    "    merged_df['name'] = list(kwargs.keys())\n",
    "    merged_df.set_index('name', inplace=True)\n",
    "    return merged_df\n",
    "\n",
    "comp_orig_df = my_concat(\n",
    "    superpos=base_origt5_df,\n",
    "    attempt=attempt_df,\n",
    "    residual=residual_df\n",
    ")\n",
    "comp_softmax_df = my_concat(\n",
    "    superpos=small_lmt5_df,\n",
    "    superpos_softmax=small_lmt5_softmax_df,\n",
    ")\n",
    "comb_base_df = my_concat(\n",
    "    superpos=base_lmt5_df\n",
    ")\n",
    "comp_gpt_df = my_concat(\n",
    "    gpt=gpt_df\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b7cbb0bd-0dbe-4f98-9f28-9e1f60d43b1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import itertools\n",
    "\n",
  299. "def _tblr_args(rows_count_seq):\n",
  300. " top_rows = list(np.cumsum([4, *rows_count_seq]))\n",
  301. " top_rows_str = ', '.join(map(str, top_rows[:-1]))\n",
  302. " bold_line = ', '.join(map(str, top_rows))\n",
  303. " return r\"\"\"column{2-18} = {c},\n",
  304. " cell{1}{2, 3, 4} = {r=3}{b},\n",
  305. " cell{1}{5} = {c=7}{c},\n",
  306. " cell{1}{12} = {c=6}{},\n",
  307. " vline{2, 3, 4, 5,12,18} = {1-3}{},\n",
  308. " hline{2} = {4-17}{},\n",
  309. " row{%s} = {c},\n",
  310. " cell{%s}{1} = {c=18}{},\n",
  311. " hline{%s} = {-}{2px},,\"\"\" % (top_rows_str, top_rows_str, bold_line)\n",
  312. "\n",
  313. "def _head_rows():\n",
  314. " return [\n",
  315. " r\"&\\rot{\\eztb{\\# Prompts}} & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\",\n",
  316. " r\"Task→ &&&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\",\n",
  317. " r\"Method↓ &&&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\"\n",
  318. " ]\n",
  319. "\n",
  320. "def _section_row(name):\n",
  321. " return name\n",
  322. "\n",
  323. "def to_pure_number(item):\n",
  324. " if isinstance(item, list):\n",
  325. " item = [x for x in item if x != '-']\n",
  326. " if len(item) == 0:\n",
  327. " return '-'\n",
  328. " return sum(item) / len(item)\n",
  329. " return item\n",
  330. "\n",
  331. "def to_pure_numbers(numbers):\n",
  332. " return np.array([\n",
  333. " to_pure_number(list_item)\n",
  334. " for list_item in numbers\n",
  335. " ])\n",
  336. "\n",
  337. "def _convert_single_number(single_number):\n",
  338. " if single_number == '-':\n",
  339. " return '-'\n",
  340. " if isinstance(single_number, str):\n",
  341. " print(single_number)\n",
  342. " return f\"{100 * single_number:.1f}\"\n",
  343. "\n",
  344. "def _convert_number(n):\n",
  345. " if not isinstance(n, list):\n",
  346. " n = [n]\n",
  347. " number_str = \"/\".join([_convert_single_number(n_item) for n_item in n])\n",
  348. " if to_pure_number(n) == 0:\n",
  349. " return f'{number_str} $\\\\dag$'\n",
  350. " return number_str\n",
  351. "\n",
  352. "def _get_mark(mark_bool):\n",
  353. " if mark_bool is None:\n",
  354. " return \"\"\n",
  355. " return \"\\\\cmark\" if mark_bool else \"\\\\xmark\"\n",
  356. "\n",
  357. "def _normal_row(name, prompt_count, is_softmax, is_dropout, numbers, bold_mask=None):\n",
  358. " numbers_str = [_convert_number(n) for n in numbers]\n",
  359. " if bold_mask is not None:\n",
  360. " for idx, bold_state in enumerate(bold_mask):\n",
  361. " if bold_state:\n",
  362. " numbers_str[idx] = \"\\\\textbf{\" + numbers_str[idx] + \"}\"\n",
  363. " \n",
  364. " prompt_count = str(prompt_count) if prompt_count is not None else \"\"\n",
  365. " return \" & \".join([name, prompt_count, _get_mark(is_softmax), _get_mark(is_dropout), *numbers_str])\n",
  366. "\n",
  367. "def _compute_mean(numbers):\n",
  368. " return np.array([[\n",
  369. " '-'\n",
  370. " if '-' in list(row)\n",
  371. " else to_pure_numbers(row).mean()\n",
  372. " for row in numbers\n",
  373. " ]], dtype=object).T\n",
  374. "\n",
  375. "def generate_rows(names, prompt_counts, softmaxes, dropouts, numbers, first_row_bold=False):\n",
  376. " mean = _compute_mean(numbers)\n",
  377. " numbers = np.concatenate((numbers, mean), axis=1)\n",
  378. " \n",
  379. " if first_row_bold:\n",
  380. " mask = np.zeros_like(numbers)\n",
  381. " mask[0, :] = 1\n",
  382. " mask = mask.astype(bool)\n",
  383. " args_zip = zip(names, prompt_counts, softmaxes, dropouts, numbers, mask)\n",
  384. " else:\n",
  385. " args_zip = zip(names, prompt_counts, softmaxes, dropouts, numbers)\n",
  386. " \n",
  387. " rows = [\n",
  388. " _normal_row(*args)\n",
  389. " for args in args_zip\n",
  390. " ]\n",
  391. " return rows\n",
  392. " \n",
  393. "def generate_table(input_dict):\n",
  394. " all_rows = [(_section_row(key), *val) for (key, val) in input_dict.items()]\n",
  395. " rows_count_seq = [len(row) for row in all_rows]\n",
  396. " all_rows_flatten = itertools.chain.from_iterable(all_rows)\n",
  397. " end_line = '\\\\\\\\\\n'\n",
  398. " rows = [\n",
  399. " *_head_rows(),\n",
  400. " *all_rows_flatten\n",
  401. " ]\n",
  402. " return r\"\"\"\\begin{tblr}{\n",
  403. " %s\n",
  404. "}\n",
  405. "%s\n",
  406. "\\end{tblr}\n",
  407. "\"\"\" % (_tblr_args(rows_count_seq), end_line.join(rows + [\"\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f760915e-5c07-4aed-b0b8-1d46a5002bd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tblr}{\n",
      " column{2-18} = {c},\n",
      " cell{1}{2, 3, 4} = {r=3}{b},\n",
      " cell{1}{5} = {c=7}{c},\n",
      " cell{1}{12} = {c=6}{},\n",
      " vline{2, 3, 4, 5,12,18} = {1-3}{},\n",
      " hline{2} = {4-17}{},\n",
      " row{4, 8, 11, 13} = {c},\n",
      " cell{4, 8, 11, 13}{1} = {c=18}{},\n",
      " hline{4, 8, 11, 13, 15} = {-}{2px},\n",
  430. "}\n",
  431. "&\\rot{\\eztb{\\# Prompts}} & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\\\\\n",
  432. "Task→ &&&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\\\\\n",
  433. "Method↓ &&&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\\\\\n",
  434. "T5 Base\\\\\n",
  435. "SuperPos PT & 10 & \\xmark & \\xmark & \\textbf{87.8/90.8} & \\textbf{93.5} & \\textbf{86.0} & \\textbf{94.4} & \\textbf{90.2/90.1} & \\textbf{92.4/89.5} & \\textbf{59.7} & \\textbf{77.7/40.9} & \\textbf{80.1} & \\textbf{97.4/96.4} & \\textbf{66.0} & \\textbf{67.6} & \\textbf{81.3} & \\textbf{81.2}\\\\\n",
  436. "ATTEMPT $\\star$ & 100 & \\cmark & \\cmark & -/90.3 & 93.0 & 84.3 & 93.2 & 89.7/- & -/85.7 & 57.4 & 74.4/- & 73.4 & -/78.6 & - & 66.8 & 78.8 & -\\\\\n",
  437. "Residual PT $\\star$ & 10 & \\xmark & \\cmark & - & - & - & - & - & - & - & 59.3 & 70.4 & 79.2 & 58.3 & 66.8 & 77.9 & -\\\\\n",
  438. "T5v1.1 Small LM-Adapted\\\\\n",
  439. "SuperPos PT & 10 & \\xmark & \\xmark & \\textbf{79.1/83.3} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0/84.0} & \\textbf{89.9/85.8} & \\textbf{38.9} & \\textbf{66.6/16.7} & \\textbf{64.6} & \\textbf{73.6/76.8} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
  440. "SuperPos PT & 10 & \\cmark & \\xmark & 69.6/75.2 & 76.0 & 42.7 & 82.9 & 45.5/43.3 & 82.4/73.0 & 4.6 & 47.5/0.9 & 52.0 & 49.9/71.4 & 57.0 & 56.4 & 62.3 & 54.9\\\\\n",
  441. "T5v1.1 Base LM-Adapted\\\\\n",
  442. "SuperPos PT & 10 & \\xmark & \\xmark & 81.9/86.3 & 89.8 & 81.0 & 94.2 & 88.6/88.5 & 89.7/85.5 & 56.5 & 72.9/24.9 & 70.4 & 78.3/82.1 & 62.0 & 67.6 & 74.0 & 75.8\\\\\n",
  443. "GPT-3.5-Turbo\\\\\n",
  444. "1 Shot & & & & 76.3/79.2 & 70.9 & 58.5 & 94.0 & 34.6/34.1 & 84.6/77.0 & 46.1 & 77.9/34.1 & 70.8 & 55.6/62.5 & 95.0 & 58.8 & 69.6 & 67.1\\\\\n",
  445. "\n",
  446. "\\end{tblr}\n",
  447. "\n"
  448. ]
  449. }
  450. ],
  451. "source": [
  452. "comp_orig_rows = generate_rows(\n",
  453. " names=['SuperPos PT', 'ATTEMPT $\\star$', 'Residual PT $\\star$'],\n",
  454. " prompt_counts=[10, 100, 10],\n",
  455. " softmaxes=[False, True, False],\n",
  456. " dropouts=[False, True, True],\n",
  457. " numbers=comp_orig_df.to_numpy(),\n",
  458. " first_row_bold=True\n",
  459. ")\n",
  460. "comp_softmax_rows = generate_rows(\n",
  461. " names=['SuperPos PT', 'SuperPos PT'],\n",
  462. " prompt_counts=[10, 10],\n",
  463. " softmaxes=[False, True],\n",
  464. " dropouts=[False, False],\n",
  465. " numbers=comp_softmax_df.to_numpy(),\n",
  466. " first_row_bold=True\n",
  467. ")\n",
  468. "comb_base_rows = generate_rows(\n",
  469. " names=['SuperPos PT'],\n",
  470. " prompt_counts=[10],\n",
  471. " softmaxes=[False],\n",
  472. " dropouts=[False],\n",
  473. " numbers=comb_base_df.to_numpy()\n",
  474. ")\n",
  475. "comp_gpt_rows = generate_rows(\n",
  476. " names=['1 Shot'],\n",
  477. " prompt_counts=[None],\n",
  478. " softmaxes=[None],\n",
  479. " dropouts=[None],\n",
  480. " numbers=comp_gpt_df.to_numpy()\n",
  481. ")\n",
  482. "\n",
  483. "\n",
  484. "print(generate_table({\n",
  485. " 'T5 Base': comp_orig_rows,\n",
  486. " 'T5v1.1 Small LM-Adapted': comp_softmax_rows,\n",
  487. " 'T5v1.1 Base LM-Adapted': comb_base_rows,\n",
  488. " 'GPT-3.5-Turbo': comp_gpt_rows\n",
  489. "}))"
  490. ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "624c8219-2f9f-4321-9bb4-e5c9f4c8a2d8",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'base_df' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mbase_df\u001b[49m\u001b[38;5;241m.\u001b[39mto_numpy()\n",
      "\u001b[0;31mNameError\u001b[0m: name 'base_df' is not defined"
     ]
    }
   ],
   "source": [
    "base_df.to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9559566-d8fb-4310-ad31-fb204877609f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ad4c6b-7de1-483a-993e-f4f3332a65c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
  535. "pd.DataFrame({'a': [1, 2., '-'], 'b': [0, 5, 1]}).to_numpy()[0].mean()"
  536. ]
  537. },
  538. {
  539. "cell_type": "code",
  540. "execution_count": null,
  541. "id": "a68c7196-462b-407f-b84a-98265296b612",
  542. "metadata": {},
  543. "outputs": [],
  544. "source": []
  545. }
  546. ],
  547. "metadata": {
  548. "kernelspec": {
  549. "display_name": "Python [conda env:deep]",
  550. "language": "python",
  551. "name": "conda-env-deep-py"
  552. },
  553. "language_info": {
  554. "codemirror_mode": {
  555. "name": "ipython",
  556. "version": 3
  557. },
  558. "file_extension": ".py",
  559. "mimetype": "text/x-python",
  560. "name": "python",
  561. "nbconvert_exporter": "python",
  562. "pygments_lexer": "ipython3",
  563. "version": "3.10.13"
  564. }
  565. },
  566. "nbformat": 4,
  567. "nbformat_minor": 5
  568. }