{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "54a7edcf-605f-40f1-9e89-d62067f55dd3", "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd \n", "import numpy as np\n", "from latex_table import generate_table, generate_rows\n", "import matplotlib.pyplot as plt\n", "from matplotlib.ticker import FormatStrFormatter\n", "\n", "class WandBWrapper:\n", " def __init__(self, prefix=''):\n", " import wandb\n", " self.api = wandb.Api()\n", " self.prefix = prefix\n", " \n", " def get_runs(self, name):\n", " return self.api.runs(f\"{self.prefix}{name}\")\n", " \n", " def _preprocess_config(self, run):\n", " return {\n", " k: v for k,v in run.config.items()\n", " if not k.startswith('_')\n", " }\n", " \n", " def _best_in_history(self, run, key):\n", " out = run.history()[key].astype('float').fillna(0).max()\n", " return max(out, 0)\n", " \n", " def get_full_history(self, runs, tasks, model_size=''):\n", " task_names = [model_size + '_' + task_name for task_name in tasks]\n", " return {\n", " task_name: pd.DataFrame({\n", " run.name: run.history()['valid_mean']\n", " for run in self.get_runs(task_name)\n", " if run.name in runs\n", " })[runs]\n", " for task_name in task_names\n", " }\n", " \n", " def get_runs_best(self, name, run_name_filter=None):\n", " runs = self.get_runs(name)\n", " return {\n", " run.name: self._best_in_history(run, 'valid_mean')\n", " for run in runs\n", " if run_name_filter is None or run.name in run_name_filter\n", " }\n", " \n", " def get_runs_tasks_df(self, runs, tasks, model_size=''):\n", " task_names = [model_size + '_' + task_name for task_name in tasks]\n", " results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n", " return pd.DataFrame(results).T[runs].T" ] }, { "cell_type": "code", "execution_count": 2, "id": "1d044235-2d14-4e4b-ad87-2077c9cd89a4", "metadata": { "tags": [] }, "outputs": [], "source": [ "tasks = [\n", " # 'glue-wnli',\n", " # 'glue-rte',\n", " 'glue-qqp', # new datasets\n", " 'glue-qnli', # new datasets\n", " 'glue-mnli', # new datasets\n", " 'glue-sst2', # new datasets\n", " 'glue-stsb', # new datasets\n", " 'glue-mrpc',\n", " 'glue-cola',\n", " 'superglue-multirc', # new datasets\n", " 'superglue-rte',\n", " 'superglue-cb',\n", " 'superglue-copa',\n", " 'superglue-wic',\n", " 'superglue-boolq',\n", "]\n", "\n", "runs = [\n", " '10_combine_128',\n", "] \n", "\n", "# small_df_softmax = WandBWrapper(\"mohalisad/iclr_softmax_effect_t5_\").get_runs_tasks_df(\n", "# runs=runs,\n", "# tasks=tasks,\n", "# model_size='small'\n", "# )\n", "small_df_no_softmax = WandBWrapper(\"mohalisad/hzi_cluster_t5_\").get_runs_tasks_df(\n", " runs=runs,\n", " tasks=tasks,\n", " model_size='small'\n", ")\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "7300ed8f-4477-4e4c-b818-c265c3f02aae", "metadata": { "tags": [] }, "outputs": [], "source": [ "small_df = pd.concat([small_df_no_softmax, small_df_no_softmax], ignore_index=True)\n", "small_df['name'] = ['softmax', 'no_softmax']\n", "small_df.set_index('name', inplace=True)" ] }, { "cell_type": "code", "execution_count": 10, "id": "fe96e491-24ce-4cb8-a25e-0db9cb98435d", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def _tblr_args():\n", " return r\"\"\"column{2-16} = {c},\n", " cell{1}{3} = {r=3}{b},\n", " cell{1}{4} = {c=7}{c},\n", " cell{1}{11} = {c=6}{},\n", " vline{3, 4,11,17} = {1-3}{},\n", " hline{2} = {3-15}{},\n", " row{4, 7} = {c},\n", " cell{4, 7}{1} = {c=16}{},\n", " hline{6, 9} = {-}{},\n", " 
{ "cell_type": "code", "execution_count": 10, "id": "fe96e491-24ce-4cb8-a25e-0db9cb98435d", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def _tblr_args():\n", "    return r\"\"\"column{2-16} = {c},\n", "    cell{1}{3} = {r=3}{b},\n", "    cell{1}{4} = {c=7}{c},\n", "    cell{1}{11} = {c=6}{},\n", "    vline{3, 4,11,17} = {1-3}{},\n", "    hline{2} = {3-15}{},\n", "    row{4, 7} = {c},\n", "    cell{4, 7}{1} = {c=16}{},\n", "    hline{6, 9} = {-}{},\n", "    hline{4, 7, 10} = {-}{2px},\"\"\"\n", "\n", "def _head_rows():\n", "    return [\n", "        r\" & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\",\n", "        r\"Task→ &&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\",\n", "        r\"Method↓ &&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\"\n", "    ]\n", "\n", "def _section_row(name):\n", "    return name + \"&&&&&&& &&&&&&&&&\"\n", "\n", "def _convert_number(n):\n", "    # A score of exactly 0 marks a missing run; flag it with a dagger.\n", "    if n == 0:\n", "        return '0.0 $\\\\dag$'\n", "    return f\"{100 * n:.1f}\"\n", "\n", "def _normal_row(name, is_softmax, is_dropout, numbers, bold_mask=None):\n", "    numbers_str = [_convert_number(n) for n in numbers]\n", "    if bold_mask is not None:\n", "        for idx, bold_state in enumerate(bold_mask):\n", "            if bold_state:\n", "                numbers_str[idx] = \"\\\\textbf{\" + numbers_str[idx] + \"}\"\n", "\n", "    soft_mark = \"\\\\cmark\" if is_softmax else \"\\\\xmark\"\n", "    drop_mark = \"\\\\cmark\" if is_dropout else \"\\\\xmark\"\n", "    return \" & \".join([name, soft_mark, drop_mark, *numbers_str])\n", "\n", "def generate_rows(names, softmaxes, dropouts, numbers):\n", "    # Append the per-row mean as the final 'Avg.' column and bold the best value in each column.\n", "    mean = numbers.mean(axis=1, keepdims=True)\n", "    numbers = np.concatenate((numbers, mean), axis=1)\n", "    pefts = numbers\n", "    pefts_best = pefts.max(axis=0)\n", "\n", "    rows = [\n", "        _normal_row(name, is_softmax, drop, peft_row, peft_row == pefts_best)\n", "        for (name, is_softmax, drop, peft_row) in zip(names, softmaxes, dropouts, pefts)\n", "    ]\n", "    return rows\n", "\n", "def generate_table(rows1_key, rows1, rows2_key, rows2):\n", "    end_line = '\\\\\\\\\\n'\n", "    rows = [\n", "        *_head_rows(),\n", "        _section_row(rows1_key),\n", "        *rows1,\n", "        _section_row(rows2_key),\n", "        *rows2,\n", "    ]\n", "    return r\"\"\"\\begin{tblr}{\n", "    %s\n", "}\n", "%s\n", "\\end{tblr}\n", "\"\"\" % (_tblr_args(), end_line.join(rows + [\"\"]))" ] },
{ "cell_type": "code", "execution_count": 11, "id": "ac11ea00-a9af-4454-982f-2aed9b552e5e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\\begin{tblr}{\n", "    column{2-16} = {c},\n", "    cell{1}{3} = {r=3}{b},\n", "    cell{1}{4} = {c=7}{c},\n", "    cell{1}{11} = {c=6}{},\n", "    vline{3, 4,11,17} = {1-3}{},\n", "    hline{2} = {3-15}{},\n", "    row{4, 7} = {c},\n", "    cell{4, 7}{1} = {c=16}{},\n", "    hline{6, 9} = {-}{},\n", "    hline{4, 7, 10} = {-}{2px},\n", "}\n", " & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\\\\\n", "Task→ &&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\\\\\n", "Method↓ &&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\\\\\n", "T5v1.1 Small LM-Adapted&&&&&&& &&&&&&&&&\\\\\n", "SuperPos PT & \\cmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n", "SuperPos PT & \\xmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n", "T5v1.1 Base LM-Adapted&&&&&&& &&&&&&&&&\\\\\n", "SuperPos PT & \\cmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n", "SuperPos PT & \\xmark & \\xmark & \\textbf{81.2} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0} & \\textbf{87.9} & \\textbf{38.9} & \\textbf{41.6} & \\textbf{64.6} & \\textbf{75.2} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n", "\n", "\\end{tblr}\n", "\n" ] } ], "source": [ "dropouts = [False, False]\n", "softmaxes = [True, False]\n", "names = ['SuperPos PT'] * 2\n", "# base_rows = generate_rows(names, softmaxes, dropouts, base_df.to_numpy())\n", "small_rows = generate_rows(names, softmaxes, dropouts, small_df.to_numpy())\n", "# NOTE: base-model results are not loaded in this notebook, so small_rows is reused as a\n", "# placeholder for both the Small and Base sections of the table below.\n", "print(generate_table('T5v1.1 Small LM-Adapted', small_rows, 'T5v1.1 Base LM-Adapted', small_rows))" ] },
{ "cell_type": "code", "execution_count": null, "id": "e138dc33-5b68-4b27-95e9-39c76f4cbc37", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:flash]", "language": "python", "name": "conda-env-flash-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }