class WandBWrapper:
    """Small helper around the public W&B API for collecting validation scores.

    Project paths are built as ``f"{prefix}{name}"``; for the table helpers
    ``name`` is ``model_size + '_' + task``.
    """

    def __init__(self, prefix=''):
        """Create the API client.

        Args:
            prefix: String prepended to every project name passed to
                :meth:`get_runs`.
        """
        # Local import: wandb is only required once a wrapper is instantiated.
        import wandb
        self.api = wandb.Api()
        self.prefix = prefix

    @staticmethod
    def _task_names(tasks, model_size=''):
        """Return ``model_size + '_' + task`` for every task.

        Note: with the default ``model_size=''`` the names start with ``'_'``;
        this matches the behavior callers already rely on.
        """
        return [model_size + '_' + task_name for task_name in tasks]

    def get_runs(self, name):
        """Return the runs of the project ``f"{self.prefix}{name}"``."""
        return self.api.runs(f"{self.prefix}{name}")

    def _preprocess_config(self, run):
        """Return ``run.config`` without the keys that start with ``'_'``."""
        return {
            k: v for k, v in run.config.items()
            if not k.startswith('_')
        }

    def _best_in_history(self, run, key):
        """Return the largest logged value of ``key`` for ``run``, floored at 0.

        Returns ``0.0`` when the run never logged ``key`` or logged no numeric
        value for it, so one broken run cannot abort (KeyError) or poison
        (NaN) a whole results table.
        """
        # NOTE(review): wandb's Run.history() returns a sampled subset of rows
        # by default, so a peak could be missed on long runs — confirm against
        # the wandb docs and consider scan_history() if exact maxima matter.
        history = run.history()
        if key not in history:
            return 0.0
        # Series.max() skips NaN; it is NaN only when there is no value at all.
        best = history[key].astype('float').max()
        if pd.isna(best):
            return 0.0
        return max(float(best), 0.0)

    def get_full_history(self, runs, tasks, model_size='', key='valid_mean'):
        """Return ``{task_name: DataFrame}`` with one ``key`` curve per run.

        Args:
            runs: Run names to keep; also fixes the column order.
            tasks: Task names, combined with ``model_size`` into project names.
            model_size: Prefix joined to each task with ``'_'``.
            key: History column to extract.

        Raises:
            KeyError: If a name in ``runs`` has no matching run in a project.
        """
        return {
            task_name: pd.DataFrame({
                run.name: run.history()[key]
                for run in self.get_runs(task_name)
                if run.name in runs
            })[runs]
            for task_name in self._task_names(tasks, model_size)
        }

    def get_runs_best(self, name, run_name_filter=None, key='valid_mean'):
        """Return ``{run.name: best value of key}`` for the project ``name``.

        Args:
            name: Project name (without the wrapper prefix).
            run_name_filter: Optional container of run names to keep;
                ``None`` keeps every run.
            key: History column whose maximum is reported.
        """
        return {
            run.name: self._best_in_history(run, key)
            for run in self.get_runs(name)
            if run_name_filter is None or run.name in run_name_filter
        }

    def get_runs_tasks_df(self, runs, tasks, model_size='', key='valid_mean'):
        """Return a DataFrame of best scores: one row per run, one column per task.

        Args:
            runs: Run names to include; also fixes the row order.
            tasks: Task names, combined with ``model_size`` into project names.
            model_size: Prefix joined to each task with ``'_'``.
            key: History column whose maximum is reported.

        Raises:
            KeyError: If a name in ``runs`` was not found in any project.
        """
        results = {
            task_name: self.get_runs_best(task_name, runs, key=key)
            for task_name in self._task_names(tasks, model_size)
        }
        # Transpose, select the requested runs as columns, transpose back:
        # keeps rows in the order given by `runs`.
        return pd.DataFrame(results).T[runs].T
# Tasks included in the results table.
# Currently disabled (kept here for reference):
#   'glue-wnli', 'glue-rte',
#   'glue-qnli', 'glue-mnli', 'glue-sst2', 'glue-stsb'   -- new datasets
#   'superglue-multirc'                                   -- new datasets
#   'superglue-copa'                                      -- not in attempt
tasks = [
    'glue-qqp',  # new datasets
    'glue-mrpc',
    'glue-cola',
    'superglue-rte',
    'superglue-cb',
    'superglue-wic',
    'superglue-boolq',
]

# Run names to report; they also fix the row order of the table.
runs = ['10_combine_128']

# Best 'valid_mean' per (run, task) for the 'base' model size.
df = WandBWrapper(prefix="mohalisad/iclr_orig_t5_t5_").get_runs_tasks_df(
    runs,
    tasks,
    model_size='base',
)

# Show the table (rows: runs, columns: tasks).
# NOTE(review): the output saved with this notebook still contains a
# `base_superglue-copa` column although 'superglue-copa' is disabled above,
# so the stored output looks stale — restart the kernel and re-run to refresh.
df