{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-intro",
   "metadata": {},
   "source": [
    "# GLUE / SuperGLUE results table\n",
    "\n",
    "This notebook does four things:\n",
    "\n",
    "- Collects the best validation scores of SuperPos PT runs from Weights & Biases.\n",
    "- Combines them with hard-coded baseline numbers (ATTEMPT, Residual PT).\n",
    "- Adds GPT-3.5-Turbo one-shot scores read from `openai/*.json`.\n",
    "- Prints the comparison as a LaTeX `tblr` (tabularray) table.\n",
    "\n",
    "Requirements: W&B API access (`wandb.Api()`) and the `openai/` score files, resolved relative to the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "135746cc-454c-41a2-977c-cf633899f002",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "# Order of metrics inside a multi-metric table cell (lower value first). Shared by the\n",
    "# W&B helper and the GPT score loader so both emit values in the same order, which\n",
    "# matches the header row (e.g. F1/Acc., F1a/EM, PCC/rho).\n",
    "METRIC_PRIORITY = {\n",
    "    'matthews_correlation': 0,\n",
    "    'f1': 1,\n",
    "    'f1_a': 1,\n",
    "    'accuracy': 2,\n",
    "    'exact_match': 3,\n",
    "    'pearson': 5,\n",
    "    'spearmanr': 6,\n",
    "}\n",
    "\n",
    "\n",
    "class WandBWrapper:\n",
    "    \"\"\"Helper around the W&B public API (`wandb.Api`) for reading validation scores.\n",
    "\n",
    "    Projects are addressed as `prefix + name`, e.g. prefix 'entity/project_' and\n",
    "    name 'base_glue-cola'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, prefix=''):\n",
    "        # Imported here rather than with the other imports; presumably so wandb is\n",
    "        # only required once a wrapper is actually created.\n",
    "        import wandb\n",
    "        self.api = wandb.Api()\n",
    "        self.prefix = prefix\n",
    "\n",
    "    def get_runs(self, name):\n",
    "        \"\"\"Return the runs of the W&B project `prefix + name`.\"\"\"\n",
    "        return self.api.runs(f\"{self.prefix}{name}\")\n",
    "\n",
    "    def _preprocess_config(self, run):\n",
    "        \"\"\"Return `run.config` without the keys that start with '_'.\"\"\"\n",
    "        return {\n",
    "            k: v for k, v in run.config.items()\n",
    "            if not k.startswith('_')\n",
    "        }\n",
    "\n",
    "    def sort_valid_columns(self, cols):\n",
    "        \"\"\"Order validation-metric columns by METRIC_PRIORITY.\n",
    "\n",
    "        The metric is the part of the column name after the first '_'\n",
    "        ('valid_f1' -> 'f1'); an unknown metric raises KeyError. Columns containing\n",
    "        'f1_m' are dropped. If any column contains 'matched_accuracy' (MNLI), only\n",
    "        ['valid_mean'] is returned so MNLI is reported as a single number.\n",
    "        \"\"\"\n",
    "        if any('matched_accuracy' in col for col in cols):  # MNLI special case\n",
    "            return ['valid_mean']\n",
    "\n",
    "        cols = [col for col in cols if 'f1_m' not in col]\n",
    "\n",
    "        def metric_name(col):\n",
    "            return col[col.find('_') + 1:]\n",
    "\n",
    "        return sorted(cols, key=lambda col: METRIC_PRIORITY[metric_name(col)])\n",
    "\n",
    "    def _best_in_history(self, run, key):\n",
    "        \"\"\"Scores of every validation metric at the logged step where `key` is highest.\n",
    "\n",
    "        Metrics are the history columns containing 'valid' but not 'mean', ordered\n",
    "        by `sort_valid_columns`. Missing `key` values count as 0 when locating the\n",
    "        best step, and negative scores are clipped to 0.\n",
    "        \"\"\"\n",
    "        # NOTE(review): in the wandb API `run.history()` returns a *sampled* history;\n",
    "        # if a run logged many rows the best step may be skipped. Consider\n",
    "        # `run.scan_history()` -- TODO confirm against the logged row counts.\n",
    "        history = run.history()\n",
    "        valid_columns = self.sort_valid_columns(\n",
    "            [col for col in history.columns if 'valid' in col and 'mean' not in col]\n",
    "        )\n",
    "        # `argmax` returns a *position*, so the metrics are read with `.iloc`\n",
    "        # rather than by index label.\n",
    "        best_pos = history[key].astype('float').fillna(0).argmax()\n",
    "        return [max(float(history[col].iloc[best_pos]), 0) for col in valid_columns]\n",
    "\n",
    "    def get_full_history(self, runs, tasks, model_size=''):\n",
    "        \"\"\"Per task project, a DataFrame of the full `valid_mean` history of each run in `runs`.\"\"\"\n",
    "        task_names = [model_size + '_' + task_name for task_name in tasks]\n",
    "        return {\n",
    "            task_name: pd.DataFrame({\n",
    "                run.name: run.history()['valid_mean']\n",
    "                for run in self.get_runs(task_name)\n",
    "                if run.name in runs\n",
    "            })[runs]\n",
    "            for task_name in task_names\n",
    "        }\n",
    "\n",
    "    def get_runs_best(self, name, run_name_filter=None):\n",
    "        \"\"\"Map run name -> best validation scores for the runs of project `name`.\n",
    "\n",
    "        The best step is the one with the highest 'valid_mean'. When\n",
    "        `run_name_filter` is given, only runs whose name is in it are included.\n",
    "        \"\"\"\n",
    "        runs = self.get_runs(name)\n",
    "        return {\n",
    "            run.name: self._best_in_history(run, 'valid_mean')\n",
    "            for run in runs\n",
    "            if run_name_filter is None or run.name in run_name_filter\n",
    "        }\n",
    "\n",
    "    def get_runs_tasks_df(self, runs, tasks, model_size=''):\n",
    "        \"\"\"DataFrame of best scores: one row per run in `runs`, one column per task.\n",
    "\n",
    "        Column labels are `model_size + '_' + task`; each cell holds the list\n",
    "        returned by `_best_in_history`.\n",
    "        \"\"\"\n",
    "        task_names = [model_size + '_' + task_name for task_name in tasks]\n",
    "        results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n",
    "        return pd.DataFrame(results).loc[list(runs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-baselines",
   "metadata": {},
   "source": [
    "## Baseline scores\n",
    "\n",
    "These are hard-coded scores for the ATTEMPT and Residual PT baselines. They are the rows marked $\\star$ in the table and were presumably copied from the respective papers -- TODO cite the source.\n",
    "\n",
    "A list holds a task's metrics in the order of the header row. `'-'` means no reported value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a4ddeace-44eb-4a2d-b215-b3d9af067204",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "attempt = {\n",
    "    'qqp': ['-', 0.903],  # F1/acc\n",
    "    'qnli': [0.930],\n",
    "    'mnli': [0.843],\n",
    "    'sst2': [0.932],\n",
    "    'stsb': [0.897, '-'],  # Pearson / rho\n",
    "    'mrpc': ['-', 0.857],  # F1/acc\n",
    "    'cola': [0.574],\n",
    "    'multirc': [0.744, \"-\"],  # F1a / EM\n",
    "    'rte': [0.734],\n",
    "    'cb': [\"-\", 0.786],  # F1/acc\n",
    "    'copa': '-',\n",
    "    'wic': [0.668],\n",
    "    'boolq': [0.788],\n",
    "}\n",
    "residual = {\n",
    "    'qqp': \"-\",\n",
    "    'qnli': \"-\",\n",
    "    'mnli': \"-\",\n",
    "    'sst2': \"-\",\n",
    "    'stsb': \"-\",\n",
    "    'mrpc': \"-\",\n",
    "    'cola': \"-\",\n",
    "    'multirc': [0.593],\n",
    "    'rte': [0.704],\n",
    "    'cb': [0.792],\n",
    "    'copa': [0.583],\n",
    "    'wic': [0.668],\n",
    "    'boolq': [0.779],\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-gpt",
   "metadata": {},
   "source": [
    "## GPT-3.5-Turbo scores\n",
    "\n",
    "One-shot GPT-3.5-Turbo results are read from `openai/<task>.json`. MNLI is the exception: it uses `openai/mnli_matched.json` and `openai/mnli_mismatched.json`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "28243b98-8fa8-4fc0-a348-b905c126bdd7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_gpt_score(base_path, task_name):\n",
    "    \"\"\"Load the GPT-3.5-Turbo scores of one task from JSON files under `base_path`.\n",
    "\n",
    "    For 'mnli' the result is a one-element list: the mean of all values in\n",
    "    `mnli_matched.json` and `mnli_mismatched.json`. Otherwise `<task_name>.json`\n",
    "    (metric name -> value) is read, 'f1_m' is dropped, and the values are returned\n",
    "    as floats ordered by METRIC_PRIORITY (an unknown metric raises KeyError).\n",
    "    \"\"\"\n",
    "    base_path = Path(base_path)\n",
    "    if task_name == 'mnli':\n",
    "        matched = json.loads((base_path / f'{task_name}_matched.json').read_text())\n",
    "        mismatched = json.loads((base_path / f'{task_name}_mismatched.json').read_text())\n",
    "        return [np.mean([*matched.values(), *mismatched.values()])]\n",
    "\n",
    "    performance = json.loads((base_path / f'{task_name}.json').read_text())\n",
    "    metric_names = sorted(\n",
    "        (metric for metric in performance if metric != 'f1_m'),\n",
    "        key=lambda metric: METRIC_PRIORITY[metric],\n",
    "    )\n",
    "    return [float(performance[metric]) for metric in metric_names]\n",
    "\n",
    "\n",
    "gpt_task_names = [\n",
    "    'qqp',  # new datasets\n",
    "    'qnli',  # new datasets\n",
    "    'mnli',  # new datasets\n",
    "    'sst2',  # new datasets\n",
    "    'stsb',  # new datasets\n",
    "    'mrpc',\n",
    "    'cola',\n",
    "    'multirc',  # new datasets\n",
    "    'rte',\n",
    "    'cb',\n",
    "    'copa',\n",
    "    'wic',\n",
    "    'boolq',\n",
    "]\n",
    "\n",
    "gpt_performances = {task: load_gpt_score('openai', task) for task in gpt_task_names}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-wandb",
   "metadata": {},
   "source": [
    "## SuperPos PT results from W&B\n",
    "\n",
    "This cell reads the best validation scores of run `10_combine_128` in each task's project. The best step is the one with the highest `valid_mean`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5ac2b609-3fb8-4206-a20b-36b2282f3372",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# W&B project suffix -> short task name used in the table.\n",
    "tasks = {\n",
    "    # 'glue-wnli',\n",
    "    # 'glue-rte',\n",
    "    'glue-qqp': 'qqp',  # new datasets\n",
    "    'glue-qnli': 'qnli',  # new datasets\n",
    "    'glue-mnli': 'mnli',  # new datasets\n",
    "    'glue-sst2': 'sst2',  # new datasets\n",
    "    'glue-stsb': 'stsb',  # new datasets\n",
    "    'glue-mrpc': 'mrpc',\n",
    "    'glue-cola': 'cola',\n",
    "    'superglue-multirc': 'multirc',  # new datasets\n",
    "    'superglue-rte': 'rte',\n",
    "    'superglue-cb': 'cb',\n",
    "    'superglue-copa': 'copa',\n",
    "    'superglue-wic': 'wic',\n",
    "    'superglue-boolq': 'boolq',\n",
    "}\n",
    "\n",
    "runs = [\n",
    "    '10_combine_128',\n",
    "]\n",
    "\n",
    "hzi_cluster_wandb = WandBWrapper(\"mohalisad/hzi_cluster_t5_\")\n",
    "base_lmt5_df = hzi_cluster_wandb.get_runs_tasks_df(runs=runs, tasks=tasks.keys(), model_size='base')\n",
    "# Manual override of the CB scores ([F1, accuracy]) for this run.\n",
    "# NOTE(review): where these numbers come from is not recorded here -- TODO document.\n",
    "# Use `.at` so the frame itself is updated; chained `df[col][row] = ...` is a\n",
    "# no-op under pandas Copy-on-Write.\n",
    "base_lmt5_df.at['10_combine_128', 'base_superglue-cb'] = [0.7826, 0.8214]\n",
    "small_lmt5_df = hzi_cluster_wandb.get_runs_tasks_df(runs=runs, tasks=tasks.keys(), model_size='small')\n",
    "small_lmt5_softmax_df = WandBWrapper(\"mohalisad/iclr_softmax_effect_t5_\").get_runs_tasks_df(\n",
    "    runs=runs, tasks=tasks.keys(), model_size='small'\n",
    ")\n",
    "# NOTE(review): unlike the projects above this path has no 'mohalisad/' entity prefix,\n",
    "# so wandb presumably resolves it against the default entity -- confirm intended.\n",
    "base_origt5_df = WandBWrapper(\"iclr_orig_t5_t5_\").get_runs_tasks_df(\n",
    "    runs=runs, tasks=tasks.keys(), model_size='base'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-frames",
   "metadata": {},
   "source": [
    "## Assemble comparison frames\n",
    "\n",
    "There is one frame per table section. Each row is one method and each column is one task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4e6da93-1cad-4310-9e54-f6a5f0c87a58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rename the W&B-style columns ('<size>_glue-qqp', ...) to the short task names.\n",
    "short_task_names = list(tasks.values())\n",
    "for results_df in (base_lmt5_df, small_lmt5_df, small_lmt5_softmax_df, base_origt5_df):\n",
    "    results_df.columns = short_task_names\n",
    "\n",
    "attempt_df = pd.Series(attempt).to_frame().T\n",
    "residual_df = pd.Series(residual).to_frame().T\n",
    "gpt_df = pd.Series(gpt_performances).to_frame().T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a58a4bbc-7b62-4c5a-b69c-27252598232b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def my_concat(**kwargs):\n",
    "    \"\"\"Stack one-row frames vertically, labelling each row with its keyword name.\n",
    "\n",
    "    Rows appear in argument order; the resulting index is named 'name'.\n",
    "    Raises ValueError if any frame does not have exactly one row.\n",
    "    \"\"\"\n",
    "    for label, frame in kwargs.items():\n",
    "        if len(frame) != 1:\n",
    "            raise ValueError(f\"my_concat expects one-row frames, but {label!r} has {len(frame)} rows\")\n",
    "    merged_df = pd.concat(list(kwargs.values()), ignore_index=True)\n",
    "    merged_df.index = pd.Index(list(kwargs), name='name')\n",
    "    return merged_df\n",
    "\n",
    "\n",
    "comp_orig_df = my_concat(\n",
    "    superpos=base_origt5_df,\n",
    "    attempt=attempt_df,\n",
    "    residual=residual_df\n",
    ")\n",
    "comp_softmax_df = my_concat(\n",
    "    superpos=small_lmt5_df,\n",
    "    superpos_softmax=small_lmt5_softmax_df,\n",
    ")\n",
    "comb_base_df = my_concat(\n",
    "    superpos=base_lmt5_df\n",
    ")\n",
    "comp_gpt_df = my_concat(\n",
    "    gpt=gpt_df\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-latex",
   "metadata": {},
   "source": [
    "## LaTeX table helpers\n",
    "\n",
    "These helpers render the rows and the surrounding `tblr` environment (tabularray). The `\\rot`, `\\eztb`, `\\cmark` and `\\xmark` macros are not standard LaTeX; presumably they are defined in the paper's preamble."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b7cbb0bd-0dbe-4f98-9f28-9e1f60d43b1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _tblr_args(rows_count_seq):\n",
    "    \"\"\"Build the option list of the tabularray `tblr` environment.\n",
    "\n",
    "    `rows_count_seq` gives, per section, its number of rows including the\n",
    "    section-title row. Title rows are centred and span all 18 columns, and a 2px\n",
    "    rule is drawn above every section and below the last one.\n",
    "    \"\"\"\n",
    "    # Rows 1-3 are the header, so the first section starts at row 4; the running\n",
    "    # sum gives each section's first row plus the row after the last section.\n",
    "    top_rows = list(np.cumsum([4, *rows_count_seq]))\n",
    "    top_rows_str = ', '.join(map(str, top_rows[:-1]))\n",
    "    bold_line = ', '.join(map(str, top_rows))\n",
    "    return r\"\"\"column{2-18} = {c},\n",
    "    cell{1}{2, 3, 4} = {r=3}{b},\n",
    "    cell{1}{5} = {c=7}{c},\n",
    "    cell{1}{12} = {c=6}{},\n",
    "    vline{2, 3, 4, 5,12,18} = {1-3}{},\n",
    "    hline{2} = {4-17}{},\n",
    "    row{%s} = {c},\n",
    "    cell{%s}{1} = {c=18}{},\n",
    "    hline{%s} = {-}{2px},\"\"\" % (top_rows_str, top_rows_str, bold_line)\n",
    "\n",
    "\n",
    "def _head_rows():\n",
    "    \"\"\"The three header rows: benchmark groups, task names, and metric names.\"\"\"\n",
    "    return [\n",
    "        r\"&\\rot{\\eztb{\\# Prompts}} & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\",\n",
    "        r\"Task→ &&&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\",\n",
    "        r\"Method↓ &&&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\"\n",
    "    ]\n",
    "\n",
    "\n",
    "def _section_row(name):\n",
    "    \"\"\"A section-title row; `_tblr_args` spans it across the whole table.\"\"\"\n",
    "    return name\n",
    "\n",
    "\n",
    "def to_pure_number(item):\n",
    "    \"\"\"Reduce a score to one number.\n",
    "\n",
    "    A list is averaged after dropping '-' entries ('-' if nothing is left); any\n",
    "    other value is returned unchanged.\n",
    "    \"\"\"\n",
    "    if isinstance(item, list):\n",
    "        item = [x for x in item if x != '-']\n",
    "        if len(item) == 0:\n",
    "            return '-'\n",
    "        return sum(item) / len(item)\n",
    "    return item\n",
    "\n",
    "\n",
    "def to_pure_numbers(numbers):\n",
    "    \"\"\"Apply `to_pure_number` to each element and return a numpy array.\"\"\"\n",
    "    return np.array([\n",
    "        to_pure_number(list_item)\n",
    "        for list_item in numbers\n",
    "    ])\n",
    "\n",
    "\n",
    "def _convert_single_number(single_number):\n",
    "    \"\"\"Format a fraction as a percentage with one decimal (0.903 -> '90.3'); '-' passes through.\"\"\"\n",
    "    if single_number == '-':\n",
    "        return '-'\n",
    "    if isinstance(single_number, str):\n",
    "        raise ValueError(f\"expected a numeric score or '-', got {single_number!r}\")\n",
    "    return f\"{100 * single_number:.1f}\"\n",
    "\n",
    "\n",
    "def _convert_number(n):\n",
    "    \"\"\"Format a score or list of scores as 'a/b/...'.\n",
    "\n",
    "    A dagger marker is appended when the (averaged) score is exactly 0.\n",
    "    NOTE(review): presumably this flags runs whose score was clipped to 0 in\n",
    "    `WandBWrapper._best_in_history` -- confirm.\n",
    "    \"\"\"\n",
    "    if not isinstance(n, list):\n",
    "        n = [n]\n",
    "    number_str = \"/\".join([_convert_single_number(n_item) for n_item in n])\n",
    "    if to_pure_number(n) == 0:\n",
    "        return f'{number_str} $\\\\dag$'\n",
    "    return number_str\n",
    "\n",
    "\n",
    "def _get_mark(mark_bool):\n",
    "    \"\"\"Check/cross mark for a yes/no column; an empty cell when `mark_bool` is None.\"\"\"\n",
    "    if mark_bool is None:\n",
    "        return \"\"\n",
    "    return \"\\\\cmark\" if mark_bool else \"\\\\xmark\"\n",
    "\n",
    "\n",
    "def _normal_row(name, prompt_count, is_softmax, is_dropout, numbers, bold_mask=None):\n",
    "    \"\"\"Render one method row; a truthy `bold_mask[i]` sets the i-th score in bold.\"\"\"\n",
    "    numbers_str = [_convert_number(n) for n in numbers]\n",
    "    if bold_mask is not None:\n",
    "        for idx, bold_state in enumerate(bold_mask):\n",
    "            if bold_state:\n",
    "                numbers_str[idx] = \"\\\\textbf{\" + numbers_str[idx] + \"}\"\n",
    "\n",
    "    prompt_count = str(prompt_count) if prompt_count is not None else \"\"\n",
    "    return \" & \".join([name, prompt_count, _get_mark(is_softmax), _get_mark(is_dropout), *numbers_str])\n",
    "\n",
    "\n",
    "def _compute_mean(numbers):\n",
    "    \"\"\"Column vector (object array) holding the average score of each row.\n",
    "\n",
    "    Every task is first reduced with `to_pure_number`; the row average is '-' when\n",
    "    any task has no usable score, otherwise the mean over all tasks.\n",
    "    \"\"\"\n",
    "    row_means = []\n",
    "    for row in numbers:\n",
    "        pure = [to_pure_number(item) for item in row]\n",
    "        # Checking after the reduction also catches lists that are all '-', e.g. ['-'].\n",
    "        if any(isinstance(value, str) and value == '-' for value in pure):\n",
    "            row_means.append('-')\n",
    "        else:\n",
    "            row_means.append(np.mean(pure))\n",
    "    return np.array([row_means], dtype=object).T\n",
    "\n",
    "\n",
    "def generate_rows(names, prompt_counts, softmaxes, dropouts, numbers, first_row_bold=False):\n",
    "    \"\"\"Render one LaTeX row per method, with an extra 'Avg.' column appended.\n",
    "\n",
    "    `numbers` is a 2-D array (methods x tasks) of scores accepted by\n",
    "    `_convert_number`. With `first_row_bold=True` every value in the first row,\n",
    "    including its average, is set in bold.\n",
    "    \"\"\"\n",
    "    numbers = np.concatenate((numbers, _compute_mean(numbers)), axis=1)\n",
    "\n",
    "    if first_row_bold:\n",
    "        bold_masks = np.zeros(numbers.shape, dtype=bool)\n",
    "        bold_masks[0, :] = True\n",
    "    else:\n",
    "        bold_masks = [None] * len(numbers)\n",
    "\n",
    "    return [\n",
    "        _normal_row(name, prompt_count, is_softmax, is_dropout, row_numbers, bold_mask)\n",
    "        for name, prompt_count, is_softmax, is_dropout, row_numbers, bold_mask\n",
    "        in zip(names, prompt_counts, softmaxes, dropouts, numbers, bold_masks)\n",
    "    ]\n",
    "\n",
    "\n",
    "def generate_table(input_dict):\n",
    "    \"\"\"Assemble the full tblr table as a LaTeX string.\n",
    "\n",
    "    `input_dict` maps a section title to that section's rendered rows (as returned\n",
    "    by `generate_rows`); sections appear in dict order after the header rows.\n",
    "    \"\"\"\n",
    "    all_rows = [(_section_row(key), *val) for (key, val) in input_dict.items()]\n",
    "    rows_count_seq = [len(row) for row in all_rows]\n",
    "    all_rows_flatten = itertools.chain.from_iterable(all_rows)\n",
    "    end_line = '\\\\\\\\\\n'  # LaTeX row terminator followed by a newline\n",
    "    rows = [\n",
    "        *_head_rows(),\n",
    "        *all_rows_flatten\n",
    "    ]\n",
    "    return r\"\"\"\\begin{tblr}{\n",
    "    %s\n",
    "}\n",
    "%s\n",
    "\\end{tblr}\n",
    "\"\"\" % (_tblr_args(rows_count_seq), end_line.join(rows + [\"\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-table",
   "metadata": {},
   "source": [
    "## Generate the table\n",
    "\n",
    "Within each section, `first_row_bold=True` sets every score of the first row in bold, including its average."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f760915e-5c07-4aed-b0b8-1d46a5002bd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tblr}{\n",
      "    column{2-18} = {c},\n",
      "    cell{1}{2, 3, 4} = {r=3}{b},\n",
      "    cell{1}{5} = {c=7}{c},\n",
      "    cell{1}{12} = {c=6}{},\n",
      "    vline{2, 3, 4, 5,12,18} = {1-3}{},\n",
      "    hline{2} = {4-17}{},\n",
      "    row{4, 8, 11, 13} = {c},\n",
      "    cell{4, 8, 11, 13}{1} = {c=18}{},\n",
      "    hline{4, 8, 11, 13, 15} = {-}{2px},\n",
      "}\n",
      "&\\rot{\\eztb{\\# Prompts}} & \\rot{\\eztb{Softmax}} & \\rot{\\eztb{Dropout}} & GLUE &&&&&&& SuperGLUE &&&&&&\\\\\n",
      "Task→ &&&& QQP & QNLI & MNLI & SST-2 & STS-B & MRPC & CoLA & MultiRC & RTE & CB & COPA & WiC & BoolQ & Avg.\\\\\n",
      "Method↓ &&&& F1/Acc. & Acc. & Acc. & Acc. & PCC/$\\rho$ & F1/Acc. & MCC & F1a/EM & Acc. & F1/Acc. & Acc. & Acc. & Acc. & -\\\\\n",
      "T5 Base\\\\\n",
      "SuperPos PT & 10 & \\xmark & \\xmark & \\textbf{87.8/90.8} & \\textbf{93.5} & \\textbf{86.0} & \\textbf{94.4} & \\textbf{90.2/90.1} & \\textbf{92.4/89.5} & \\textbf{59.7} & \\textbf{77.7/40.9} & \\textbf{80.1} & \\textbf{97.4/96.4} & \\textbf{66.0} & \\textbf{67.6} & \\textbf{81.3} & \\textbf{81.2}\\\\\n",
      "ATTEMPT $\\star$ & 100 & \\cmark & \\cmark & -/90.3 & 93.0 & 84.3 & 93.2 & 89.7/- & -/85.7 & 57.4 & 74.4/- & 73.4 & -/78.6 & - & 66.8 & 78.8 & -\\\\\n",
      "Residual PT $\\star$ & 10 & \\xmark & \\cmark & - & - & - & - & - & - & - & 59.3 & 70.4 & 79.2 & 58.3 & 66.8 & 77.9 & -\\\\\n",
      "T5v1.1 Small LM-Adapted\\\\\n",
      "SuperPos PT & 10 & \\xmark & \\xmark & \\textbf{79.1/83.3} & \\textbf{85.3} & \\textbf{71.7} & \\textbf{89.8} & \\textbf{84.0/84.0} & \\textbf{89.9/85.8} & \\textbf{38.9} & \\textbf{66.6/16.7} & \\textbf{64.6} & \\textbf{73.6/76.8} & \\textbf{58.0} & \\textbf{65.7} & \\textbf{68.9} & \\textbf{70.2}\\\\\n",
      "SuperPos PT & 10 & \\cmark & \\xmark & 69.6/75.2 & 76.0 & 42.7 & 82.9 & 45.5/43.3 & 82.4/73.0 & 4.6 & 47.5/0.9 & 52.0 & 49.9/71.4 & 57.0 & 56.4 & 62.3 & 54.9\\\\\n",
      "T5v1.1 Base LM-Adapted\\\\\n",
      "SuperPos PT & 10 & \\xmark & \\xmark & 81.9/86.3 & 89.8 & 81.0 & 94.2 & 88.6/88.5 & 89.7/85.5 & 56.5 & 72.9/24.9 & 70.4 & 78.3/82.1 & 62.0 & 67.6 & 74.0 & 75.8\\\\\n",
      "GPT-3.5-Turbo\\\\\n",
      "1 Shot &  &  &  & 76.3/79.2 & 70.9 & 58.5 & 94.0 & 34.6/34.1 & 84.6/77.0 & 46.1 & 77.9/34.1 & 70.8 & 55.6/62.5 & 95.0 & 58.8 & 69.6 & 67.1\\\\\n",
      "\n",
      "\\end{tblr}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "comp_orig_rows = generate_rows(\n",
    "    names=['SuperPos PT', r'ATTEMPT $\\star$', r'Residual PT $\\star$'],\n",
    "    prompt_counts=[10, 100, 10],\n",
    "    softmaxes=[False, True, False],\n",
    "    dropouts=[False, True, True],\n",
    "    numbers=comp_orig_df.to_numpy(),\n",
    "    first_row_bold=True\n",
    ")\n",
    "comp_softmax_rows = generate_rows(\n",
    "    names=['SuperPos PT', 'SuperPos PT'],\n",
    "    prompt_counts=[10, 10],\n",
    "    softmaxes=[False, True],\n",
    "    dropouts=[False, False],\n",
    "    numbers=comp_softmax_df.to_numpy(),\n",
    "    first_row_bold=True\n",
    ")\n",
    "comb_base_rows = generate_rows(\n",
    "    names=['SuperPos PT'],\n",
    "    prompt_counts=[10],\n",
    "    softmaxes=[False],\n",
    "    dropouts=[False],\n",
    "    numbers=comb_base_df.to_numpy()\n",
    ")\n",
    "comp_gpt_rows = generate_rows(\n",
    "    names=['1 Shot'],\n",
    "    prompt_counts=[None],\n",
    "    softmaxes=[None],\n",
    "    dropouts=[None],\n",
    "    numbers=comp_gpt_df.to_numpy()\n",
    ")\n",
    "\n",
    "# print (not a bare expression) so the LaTeX is shown verbatim, ready to copy.\n",
    "print(generate_table({\n",
    "    'T5 Base': comp_orig_rows,\n",
    "    'T5v1.1 Small LM-Adapted': comp_softmax_rows,\n",
    "    'T5v1.1 Base LM-Adapted': comb_base_rows,\n",
    "    'GPT-3.5-Turbo': comp_gpt_rows\n",
    "}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deep]",
   "language": "python",
   "name": "conda-env-deep-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}