123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "cbff7109-365e-42c9-82b1-8e0fa8173d8d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "import pandas as pd \n",
- "import numpy as np\n",
- "from latex_table import generate_table, generate_rows\n",
- "import matplotlib.pyplot as plt\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "\n",
- "class WandBWrapper:\n",
- " def __init__(self, prefix=''):\n",
- " import wandb\n",
- " self.api = wandb.Api()\n",
- " self.prefix = prefix\n",
- " \n",
- " def get_runs(self, name):\n",
- " return self.api.runs(f\"{self.prefix}{name}\")\n",
- " \n",
- " def _preprocess_config(self, run):\n",
- " return {\n",
- " k: v for k,v in run.config.items()\n",
- " if not k.startswith('_')\n",
- " }\n",
- " \n",
- " def _best_in_history(self, run, key):\n",
- " out = run.history()[key].astype('float').fillna(0).max()\n",
- " return max(out, 0)\n",
- " \n",
- " def get_full_history(self, runs, tasks, model_size=''):\n",
- " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
- " return {\n",
- " task_name: pd.DataFrame({\n",
- " run.name: run.history()['valid_mean']\n",
- " for run in self.get_runs(task_name)\n",
- " if run.name in runs\n",
- " })[runs]\n",
- " for task_name in task_names\n",
- " }\n",
- " \n",
- " def get_runs_best(self, name, run_name_filter=None):\n",
- " runs = self.get_runs(name)\n",
- " return {\n",
- " run.name: self._best_in_history(run, 'valid_mean')\n",
- " for run in runs\n",
- " if run_name_filter is None or run.name in run_name_filter\n",
- " }\n",
- " \n",
- " def get_runs_tasks_df(self, runs, tasks, model_size=''):\n",
- " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
- " results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n",
- " return pd.DataFrame(results).T[runs].T"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "2e3239bf-7044-4ffd-93f3-39272dbd82ff",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "tasks = [\n",
- " # 'glue-wnli',\n",
- " # 'glue-rte',\n",
- " 'glue-qqp', # new datasets\n",
- " # 'glue-qnli', # new datasets\n",
- " # 'glue-mnli', # new datasets\n",
- " # 'glue-sst2', # new datasets\n",
- " # 'glue-stsb', # new datasets\n",
- " 'glue-mrpc',\n",
- " 'glue-cola',\n",
- " # 'superglue-multirc', # new datasets\n",
- " 'superglue-rte',\n",
- " 'superglue-cb',\n",
- " # 'superglue-copa', # not in attempt\n",
- " 'superglue-wic',\n",
- " 'superglue-boolq',\n",
- "]\n",
- "\n",
- "runs = [\n",
- " '10_combine_128',\n",
- "]\n",
- "\n",
- "df = WandBWrapper(\"mohalisad/iclr_orig_t5_t5_\").get_runs_tasks_df(\n",
- " runs=runs,\n",
- " tasks=tasks,\n",
- " model_size='base'\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "050389ec-ce24-431f-b1cb-e21f4c942c20",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<style scoped>\n",
- " .dataframe tbody tr th:only-of-type {\n",
- " vertical-align: middle;\n",
- " }\n",
- "\n",
- " .dataframe tbody tr th {\n",
- " vertical-align: top;\n",
- " }\n",
- "\n",
- " .dataframe thead th {\n",
- " text-align: right;\n",
- " }\n",
- "</style>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>base_glue-qqp</th>\n",
- " <th>base_glue-mrpc</th>\n",
- " <th>base_glue-cola</th>\n",
- " <th>base_superglue-rte</th>\n",
- " <th>base_superglue-cb</th>\n",
- " <th>base_superglue-copa</th>\n",
- " <th>base_superglue-wic</th>\n",
- " <th>base_superglue-boolq</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>10_combine_128</th>\n",
- " <td>0.892432</td>\n",
- " <td>0.909251</td>\n",
- " <td>0.596682</td>\n",
- " <td>0.801444</td>\n",
- " <td>0.968944</td>\n",
- " <td>0.66</td>\n",
- " <td>0.675549</td>\n",
- " <td>0.813456</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " base_glue-qqp base_glue-mrpc base_glue-cola \\\n",
- "10_combine_128 0.892432 0.909251 0.596682 \n",
- "\n",
- " base_superglue-rte base_superglue-cb base_superglue-copa \\\n",
- "10_combine_128 0.801444 0.968944 0.66 \n",
- "\n",
- " base_superglue-wic base_superglue-boolq \n",
- "10_combine_128 0.675549 0.813456 "
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "36774895-c1e4-4d26-bfc7-69e4003d2bbb",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:deep]",
- "language": "python",
- "name": "conda-env-deep-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.13"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
|