{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cbff7109-365e-42c9-82b1-8e0fa8173d8d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "from latex_table import generate_table, generate_rows\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "class WandBWrapper:\n",
    "    \"\"\"Read-only helper around ``wandb.Api`` for collecting logged metrics.\n",
    "\n",
    "    Project paths are built as ``prefix + name``, so ``prefix`` carries\n",
    "    whatever leading part all queried projects have in common.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, prefix=''):\n",
    "        # Imported here so that defining the class does not require wandb.\n",
    "        import wandb\n",
    "        self.api = wandb.Api()\n",
    "        self.prefix = prefix\n",
    "\n",
    "    def get_runs(self, name):\n",
    "        \"\"\"Return the runs of the project at ``prefix + name``.\"\"\"\n",
    "        return self.api.runs(f\"{self.prefix}{name}\")\n",
    "\n",
    "    def _preprocess_config(self, run):\n",
    "        \"\"\"Return ``run.config`` without the keys starting with an underscore.\"\"\"\n",
    "        return {\n",
    "            k: v for k, v in run.config.items()\n",
    "            if not k.startswith('_')\n",
    "        }\n",
    "\n",
    "    def _task_names(self, tasks, model_size):\n",
    "        \"\"\"Prefix every task with ``model_size`` followed by an underscore.\"\"\"\n",
    "        # The underscore is added even when model_size is empty (unchanged).\n",
    "        return [model_size + '_' + task_name for task_name in tasks]\n",
    "\n",
    "    def _best_in_history(self, run, key):\n",
    "        \"\"\"Return the largest value logged under ``key`` for ``run``, floored at 0.\n",
    "\n",
    "        NaN entries are ignored. A run that never logged ``key``, or whose\n",
    "        values are all missing, yields ``0.0`` instead of a KeyError / NaN so\n",
    "        that one incomplete run cannot break a whole results table.\n",
    "        \"\"\"\n",
    "        # NOTE(review): history() is called with its defaults, as before;\n",
    "        # confirm it returns every logged step for long runs.\n",
    "        history = run.history()\n",
    "        if key not in history:\n",
    "            return 0.0\n",
    "        values = history[key].astype('float').dropna()\n",
    "        if values.empty:\n",
    "            return 0.0\n",
    "        return max(float(values.max()), 0.0)\n",
    "\n",
    "    def get_full_history(self, runs, tasks, model_size='', metric='valid_mean'):\n",
    "        \"\"\"Return ``{task_name: DataFrame}`` with one ``metric`` column per run.\n",
    "\n",
    "        Columns follow the order of ``runs``. A KeyError is raised when a\n",
    "        requested run is absent from a task's project or did not log ``metric``.\n",
    "        \"\"\"\n",
    "        return {\n",
    "            task_name: pd.DataFrame({\n",
    "                run.name: run.history()[metric]\n",
    "                for run in self.get_runs(task_name)\n",
    "                if run.name in runs\n",
    "            })[runs]\n",
    "            for task_name in self._task_names(tasks, model_size)\n",
    "        }\n",
    "\n",
    "    def get_runs_best(self, name, run_name_filter=None, metric='valid_mean'):\n",
    "        \"\"\"Return ``{run name: best metric value}`` for the project ``name``.\n",
    "\n",
    "        When ``run_name_filter`` is given, only runs whose name is in it are\n",
    "        kept. Runs sharing a name collapse into one entry (the last one seen).\n",
    "        \"\"\"\n",
    "        return {\n",
    "            run.name: self._best_in_history(run, metric)\n",
    "            for run in self.get_runs(name)\n",
    "            if run_name_filter is None or run.name in run_name_filter\n",
    "        }\n",
    "\n",
    "    def get_runs_tasks_df(self, runs, tasks, model_size='', metric='valid_mean'):\n",
    "        \"\"\"Return a DataFrame of best scores: one row per run, one column per task.\"\"\"\n",
    "        results = {\n",
    "            task_name: self.get_runs_best(task_name, runs, metric)\n",
    "            for task_name in self._task_names(tasks, model_size)\n",
    "        }\n",
    "        # Label-based row selection keeps the order of ``runs`` and raises\n",
    "        # KeyError for a run that appears in none of the tasks.\n",
    "        return pd.DataFrame(results).loc[runs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2e3239bf-7044-4ffd-93f3-39272dbd82ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tasks = [\n",
    "    # 'glue-wnli',\n",
    "    # 'glue-rte',\n",
    "    'glue-qqp', # 
new datasets\n", " # 'glue-qnli', # new datasets\n", " # 'glue-mnli', # new datasets\n", " # 'glue-sst2', # new datasets\n", " # 'glue-stsb', # new datasets\n", " 'glue-mrpc',\n", " 'glue-cola',\n", " # 'superglue-multirc', # new datasets\n", " 'superglue-rte',\n", " 'superglue-cb',\n", " # 'superglue-copa', # not in attempt\n", " 'superglue-wic',\n", " 'superglue-boolq',\n", "]\n", "\n", "runs = [\n", " '10_combine_128',\n", "]\n", "\n", "df = WandBWrapper(\"mohalisad/iclr_orig_t5_t5_\").get_runs_tasks_df(\n", " runs=runs,\n", " tasks=tasks,\n", " model_size='base'\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "id": "050389ec-ce24-431f-b1cb-e21f4c942c20", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", " | base_glue-qqp | \n", "base_glue-mrpc | \n", "base_glue-cola | \n", "base_superglue-rte | \n", "base_superglue-cb | \n", "base_superglue-copa | \n", "base_superglue-wic | \n", "base_superglue-boolq | \n", "
---|---|---|---|---|---|---|---|---|
10_combine_128 | \n", "0.892432 | \n", "0.909251 | \n", "0.596682 | \n", "0.801444 | \n", "0.968944 | \n", "0.66 | \n", "0.675549 | \n", "0.813456 | \n", "