You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

softmax.ipynb 9.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "54a7edcf-605f-40f1-9e89-d62067f55dd3",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "import pandas as pd \n",
  13. "import numpy as np\n",
  14. "from latex_table import generate_table, generate_rows\n",
  15. "import matplotlib.pyplot as plt\n",
  16. "from matplotlib.ticker import FormatStrFormatter\n",
  17. "\n",
  18. "class WandBWrapper:\n",
  19. " def __init__(self, prefix=''):\n",
  20. " import wandb\n",
  21. " self.api = wandb.Api()\n",
  22. " self.prefix = prefix\n",
  23. " \n",
  24. " def get_runs(self, name):\n",
  25. " return self.api.runs(f\"{self.prefix}{name}\")\n",
  26. " \n",
  27. " def _preprocess_config(self, run):\n",
  28. " return {\n",
  29. " k: v for k,v in run.config.items()\n",
  30. " if not k.startswith('_')\n",
  31. " }\n",
  32. " \n",
  33. " def _best_in_history(self, run, key):\n",
  34. " out = run.history()[key].astype('float').fillna(0).max()\n",
  35. " return max(out, 0)\n",
  36. " \n",
  37. " def get_full_history(self, runs, tasks, model_size=''):\n",
  38. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  39. " return {\n",
  40. " task_name: pd.DataFrame({\n",
  41. " run.name: run.history()['valid_mean']\n",
  42. " for run in self.get_runs(task_name)\n",
  43. " if run.name in runs\n",
  44. " })[runs]\n",
  45. " for task_name in task_names\n",
  46. " }\n",
  47. " \n",
  48. " def get_runs_best(self, name, run_name_filter=None):\n",
  49. " runs = self.get_runs(name)\n",
  50. " return {\n",
  51. " run.name: self._best_in_history(run, 'valid_mean')\n",
  52. " for run in runs\n",
  53. " if run_name_filter is None or run.name in run_name_filter\n",
  54. " }\n",
  55. " \n",
  56. " def get_runs_tasks_df(self, runs, tasks, model_size=''):\n",
  57. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  58. " results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n",
  59. " return pd.DataFrame(results).T[runs].T"
  60. ]
  61. },
  62. {
  63. "cell_type": "code",
  64. "execution_count": 2,
  65. "id": "1d044235-2d14-4e4b-ad87-2077c9cd89a4",
  66. "metadata": {
  67. "tags": []
  68. },
  69. "outputs": [],
  70. "source": [
  71. "tasks = [\n",
  72. " # 'glue-wnli',\n",
  73. " # 'glue-rte',\n",
  74. " 'glue-qqp', # new datasets\n",
  75. " 'glue-qnli', # new datasets\n",
  76. " 'glue-mnli', # new datasets\n",
  77. " 'glue-sst2', # new datasets\n",
  78. " 'glue-stsb', # new datasets\n",
  79. " 'glue-mrpc',\n",
  80. " 'glue-cola',\n",
  81. " 'superglue-multirc', # new datasets\n",
  82. " 'superglue-rte',\n",
  83. " 'superglue-cb',\n",
  84. " 'superglue-copa',\n",
  85. " 'superglue-wic',\n",
  86. " 'superglue-boolq',\n",
  87. "]\n",
  88. "\n",
  89. "runs = [\n",
  90. " '10_combine_128',\n",
  91. "] \n",
  92. "\n",
  93. "# small_df_softmax = WandBWrapper(\"mohalisad/iclr_softmax_effect_t5_\").get_runs_tasks_df(\n",
  94. "# runs=runs,\n",
  95. "# tasks=tasks,\n",
  96. "# model_size='small'\n",
  97. "# )\n",
  98. "small_df_no_softmax = WandBWrapper(\"mohalisad/hzi_cluster_t5_\").get_runs_tasks_df(\n",
  99. " runs=runs,\n",
  100. " tasks=tasks,\n",
  101. " model_size='small'\n",
  102. ")\n"
  103. ]
  104. },
  105. {
  106. "cell_type": "code",
  107. "execution_count": 7,
  108. "id": "7300ed8f-4477-4e4c-b818-c265c3f02aae",
  109. "metadata": {
  110. "tags": []
  111. },
  112. "outputs": [],
  113. "source": [
  114. "small_df = pd.concat([small_df_no_softmax, small_df_no_softmax], ignore_index=True)\n",
  115. "small_df['name'] = ['softmax', 'no_softmax']\n",
  116. "small_df.set_index('name', inplace=True)"
  117. ]
  118. },
  119. {
  120. "cell_type": "code",
  121. "execution_count": 10,
  122. "id": "fe96e491-24ce-4cb8-a25e-0db9cb98435d",
  123. "metadata": {
  124. "tags": []
  125. },
  126. "outputs": [],
  127. "source": [
  128. "import numpy as np\n",
  129. "\n",
  130. "def _tblr_args():\n",
  131. " return r\"\"\"column{2-16} = {c},\n",
  132. " cell{1}{3} = {r=3}{b},\n",
  133. " cell{1}{4} = {c=7}{c},\n",
  134. " cell{1}{11} = {c=6}{},\n",
  135. " vline{3, 4,11,17} = {1-3}{},\n",
  136. " hline{2} = {3-15}{},\n",
  137. " row{4, 7} = {c},\n",
  138. " cell{4, 7}{1} = {c=16}{},\n",
  139. " hline{6, 9} = {-}{},\n",
  140. " hline{4, 7, 10} = {-}{2px},,\"\"\"\n",
  141. "\n",
  142. "def _head_rows():\n",
  143. " return [\n",
  144. " r\" & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\",\n",
  145. " r\"Task→ &&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\",\n",
  146. " r\"Method↓ &&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\"\n",
  147. " ]\n",
  148. "\n",
  149. "def _section_row(name):\n",
  150. " return name + \"&&&&&&& &&&&&&&&&\"\n",
  151. "\n",
  152. "def _convert_number(n):\n",
  153. " if n == 0:\n",
  154. " return '0.0 $\\\\dag$'\n",
  155. " return f\"{100 * n:.1f}\"\n",
  156. "\n",
  157. "def _normal_row(name, is_softmax, is_dropout, numbers, bold_mask=None):\n",
  158. " numbers_str = [_convert_number(n) for n in numbers]\n",
  159. " if bold_mask is not None:\n",
  160. " for idx, bold_state in enumerate(bold_mask):\n",
  161. " if bold_state:\n",
  162. " numbers_str[idx] = \"\\\\textbf{\" + numbers_str[idx] + \"}\"\n",
  163. " \n",
  164. " soft_mark = \"\\\\cmark\" if is_softmax else \"\\\\xmark\"\n",
  165. " drop_mark = \"\\\\cmark\" if is_dropout else \"\\\\xmark\"\n",
  166. " return \" & \".join([name, soft_mark, drop_mark, *numbers_str])\n",
  167. " \n",
  168. "def generate_rows(names, softmaxes, dropouts, numbers):\n",
  169. " mean = numbers.mean(axis=1, keepdims=True)\n",
  170. " numbers = np.concatenate((numbers, mean), axis=1)\n",
  171. " pefts = numbers\n",
  172. " pefts_best = pefts.max(axis=0)\n",
  173. " \n",
  174. " rows = [\n",
  175. " _normal_row(name, is_softmax, drop, peft_row, peft_row == pefts_best)\n",
  176. " for (name, is_softmax, drop, peft_row) in zip(names, softmaxes, dropouts, pefts)\n",
  177. " ]\n",
  178. " return rows\n",
  179. " \n",
  180. "def generate_table(rows1_key, rows1, rows2_key, rows2):\n",
  181. " end_line = '\\\\\\\\\\n'\n",
  182. " rows = [\n",
  183. " *_head_rows(),\n",
  184. " _section_row(rows1_key),\n",
  185. " *rows1,\n",
  186. " _section_row(rows2_key),\n",
  187. " *rows2,\n",
  188. " ]\n",
  189. " return r\"\"\"\\begin{tblr}{\n",
  190. " %s\n",
  191. "}\n",
  192. "%s\n",
  193. "\\end{tblr}\n",
  194. "\"\"\" % (_tblr_args(), end_line.join(rows + [\"\"]))"
  195. ]
  196. },
  197. {
  198. "cell_type": "code",
  199. "execution_count": 11,
  200. "id": "ac11ea00-a9af-4454-982f-2aed9b552e5e",
  201. "metadata": {},
  202. "outputs": [
  203. {
  204. "name": "stdout",
  205. "output_type": "stream",
  206. "text": [
  207. "\\begin{tblr}{\n",
  208. " column{2-16} = {c},\n",
  209. " cell{1}{3} = {r=3}{b},\n",
  210. " cell{1}{4} = {c=7}{c},\n",
  211. " cell{1}{11} = {c=6}{},\n",
  212. " vline{3, 4,11,17} = {1-3}{},\n",
  213. " hline{2} = {3-15}{},\n",
  214. " row{4, 7} = {c},\n",
  215. " cell{4, 7}{1} = {c=16}{},\n",
  216. " hline{6, 9} = {-}{},\n",
  217. " hline{4, 7, 10} = {-}{2px},,\n",
  218. "}\n",
  219. " & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\\\\\n",
  220. "Task→ &&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\\\\\n",
  221. "Method↓ &&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\\\\\n",
  222. "T5v1.1 Small LM-Adapted&&&&&&& &&&&&&&&&\\\\\n",
  223. "SuperPos PT & \\cmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
  224. "SuperPos PT & \\xmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
  225. "T5v1.1 Base LM-Adapted&&&&&&& &&&&&&&&&\\\\\n",
  226. "SuperPos PT & \\cmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
  227. "SuperPos PT & \\xmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
  228. "\n",
  229. "\\end{tblr}\n",
  230. "\n"
  231. ]
  232. }
  233. ],
  234. "source": [
  235. "dropouts = [False, False]\n",
  236. "softmaxes = [True, False]\n",
  237. "names = ['SuperPos PT'] * 2\n",
  238. "# base_rows = generate_rows(names, softmaxes, dropouts, base_df.to_numpy())\n",
  239. "small_rows = generate_rows(names, softmaxes, dropouts, small_df.to_numpy())\n",
  240. "print(generate_table('T5v1.1 Small LM-Adapted', small_rows, 'T5v1.1 Base LM-Adapted', small_rows))"
  241. ]
  242. },
  243. {
  244. "cell_type": "code",
  245. "execution_count": null,
  246. "id": "e138dc33-5b68-4b27-95e9-39c76f4cbc37",
  247. "metadata": {},
  248. "outputs": [],
  249. "source": []
  250. }
  251. ],
  252. "metadata": {
  253. "kernelspec": {
  254. "display_name": "Python [conda env:flash]",
  255. "language": "python",
  256. "name": "conda-env-flash-py"
  257. },
  258. "language_info": {
  259. "codemirror_mode": {
  260. "name": "ipython",
  261. "version": 3
  262. },
  263. "file_extension": ".py",
  264. "mimetype": "text/x-python",
  265. "name": "python",
  266. "nbconvert_exporter": "python",
  267. "pygments_lexer": "ipython3",
  268. "version": "3.10.13"
  269. }
  270. },
  271. "nbformat": 4,
  272. "nbformat_minor": 5
  273. }