You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

orig_t5.ipynb 6.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "cbff7109-365e-42c9-82b1-8e0fa8173d8d",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "import pandas as pd \n",
  13. "import numpy as np\n",
  14. "from latex_table import generate_table, generate_rows\n",
  15. "import matplotlib.pyplot as plt\n",
  16. "from matplotlib.ticker import FormatStrFormatter\n",
  17. "\n",
  18. "class WandBWrapper:\n",
  19. " def __init__(self, prefix=''):\n",
  20. " import wandb\n",
  21. " self.api = wandb.Api()\n",
  22. " self.prefix = prefix\n",
  23. " \n",
  24. " def get_runs(self, name):\n",
  25. " return self.api.runs(f\"{self.prefix}{name}\")\n",
  26. " \n",
  27. " def _preprocess_config(self, run):\n",
  28. " return {\n",
  29. " k: v for k,v in run.config.items()\n",
  30. " if not k.startswith('_')\n",
  31. " }\n",
  32. " \n",
  33. " def _best_in_history(self, run, key):\n",
  34. " out = run.history()[key].astype('float').fillna(0).max()\n",
  35. " return max(out, 0)\n",
  36. " \n",
  37. " def get_full_history(self, runs, tasks, model_size=''):\n",
  38. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  39. " return {\n",
  40. " task_name: pd.DataFrame({\n",
  41. " run.name: run.history()['valid_mean']\n",
  42. " for run in self.get_runs(task_name)\n",
  43. " if run.name in runs\n",
  44. " })[runs]\n",
  45. " for task_name in task_names\n",
  46. " }\n",
  47. " \n",
  48. " def get_runs_best(self, name, run_name_filter=None):\n",
  49. " runs = self.get_runs(name)\n",
  50. " return {\n",
  51. " run.name: self._best_in_history(run, 'valid_mean')\n",
  52. " for run in runs\n",
  53. " if run_name_filter is None or run.name in run_name_filter\n",
  54. " }\n",
  55. " \n",
  56. " def get_runs_tasks_df(self, runs, tasks, model_size=''):\n",
  57. " task_names = [model_size + '_' + task_name for task_name in tasks]\n",
  58. " results = {task_name: self.get_runs_best(task_name, runs) for task_name in task_names}\n",
  59. " return pd.DataFrame(results).T[runs].T"
  60. ]
  61. },
  62. {
  63. "cell_type": "code",
  64. "execution_count": 4,
  65. "id": "2e3239bf-7044-4ffd-93f3-39272dbd82ff",
  66. "metadata": {
  67. "tags": []
  68. },
  69. "outputs": [],
  70. "source": [
  71. "tasks = [\n",
  72. " # 'glue-wnli',\n",
  73. " # 'glue-rte',\n",
  74. " 'glue-qqp', # new datasets\n",
  75. " # 'glue-qnli', # new datasets\n",
  76. " # 'glue-mnli', # new datasets\n",
  77. " # 'glue-sst2', # new datasets\n",
  78. " # 'glue-stsb', # new datasets\n",
  79. " 'glue-mrpc',\n",
  80. " 'glue-cola',\n",
  81. " # 'superglue-multirc', # new datasets\n",
  82. " 'superglue-rte',\n",
  83. " 'superglue-cb',\n",
  84. " # 'superglue-copa', # not in attempt\n",
  85. " 'superglue-wic',\n",
  86. " 'superglue-boolq',\n",
  87. "]\n",
  88. "\n",
  89. "runs = [\n",
  90. " '10_combine_128',\n",
  91. "]\n",
  92. "\n",
  93. "df = WandBWrapper(\"mohalisad/iclr_orig_t5_t5_\").get_runs_tasks_df(\n",
  94. " runs=runs,\n",
  95. " tasks=tasks,\n",
  96. " model_size='base'\n",
  97. ")"
  98. ]
  99. },
  100. {
  101. "cell_type": "code",
  102. "execution_count": 5,
  103. "id": "050389ec-ce24-431f-b1cb-e21f4c942c20",
  104. "metadata": {
  105. "tags": []
  106. },
  107. "outputs": [
  108. {
  109. "data": {
  110. "text/html": [
  111. "<div>\n",
  112. "<style scoped>\n",
  113. " .dataframe tbody tr th:only-of-type {\n",
  114. " vertical-align: middle;\n",
  115. " }\n",
  116. "\n",
  117. " .dataframe tbody tr th {\n",
  118. " vertical-align: top;\n",
  119. " }\n",
  120. "\n",
  121. " .dataframe thead th {\n",
  122. " text-align: right;\n",
  123. " }\n",
  124. "</style>\n",
  125. "<table border=\"1\" class=\"dataframe\">\n",
  126. " <thead>\n",
  127. " <tr style=\"text-align: right;\">\n",
  128. " <th></th>\n",
  129. " <th>base_glue-qqp</th>\n",
  130. " <th>base_glue-mrpc</th>\n",
  131. " <th>base_glue-cola</th>\n",
  132. " <th>base_superglue-rte</th>\n",
  133. " <th>base_superglue-cb</th>\n",
  134. " <th>base_superglue-copa</th>\n",
  135. " <th>base_superglue-wic</th>\n",
  136. " <th>base_superglue-boolq</th>\n",
  137. " </tr>\n",
  138. " </thead>\n",
  139. " <tbody>\n",
  140. " <tr>\n",
  141. " <th>10_combine_128</th>\n",
  142. " <td>0.892432</td>\n",
  143. " <td>0.909251</td>\n",
  144. " <td>0.596682</td>\n",
  145. " <td>0.801444</td>\n",
  146. " <td>0.968944</td>\n",
  147. " <td>0.66</td>\n",
  148. " <td>0.675549</td>\n",
  149. " <td>0.813456</td>\n",
  150. " </tr>\n",
  151. " </tbody>\n",
  152. "</table>\n",
  153. "</div>"
  154. ],
  155. "text/plain": [
  156. " base_glue-qqp base_glue-mrpc base_glue-cola \\\n",
  157. "10_combine_128 0.892432 0.909251 0.596682 \n",
  158. "\n",
  159. " base_superglue-rte base_superglue-cb base_superglue-copa \\\n",
  160. "10_combine_128 0.801444 0.968944 0.66 \n",
  161. "\n",
  162. " base_superglue-wic base_superglue-boolq \n",
  163. "10_combine_128 0.675549 0.813456 "
  164. ]
  165. },
  166. "execution_count": 5,
  167. "metadata": {},
  168. "output_type": "execute_result"
  169. }
  170. ],
  171. "source": [
# Best 'valid_mean' per run (rows, in `runs` order) and "base_<task>" project
# name (columns, in `tasks` order), as built by WandBWrapper.get_runs_tasks_df.
# NOTE(review): the saved output of this cell still shows a base_superglue-copa
# column, which the current task list leaves out, and execution counts are not
# sequential — Restart & Run All to refresh the displayed table.
df
  173. ]
  174. },
  175. {
  176. "cell_type": "code",
  177. "execution_count": null,
  178. "id": "36774895-c1e4-4d26-bfc7-69e4003d2bbb",
  179. "metadata": {},
  180. "outputs": [],
  181. "source": []
  182. }
  183. ],
  184. "metadata": {
  185. "kernelspec": {
  186. "display_name": "Python [conda env:deep]",
  187. "language": "python",
  188. "name": "conda-env-deep-py"
  189. },
  190. "language_info": {
  191. "codemirror_mode": {
  192. "name": "ipython",
  193. "version": 3
  194. },
  195. "file_extension": ".py",
  196. "mimetype": "text/x-python",
  197. "name": "python",
  198. "nbconvert_exporter": "python",
  199. "pygments_lexer": "ipython3",
  200. "version": "3.10.13"
  201. }
  202. },
  203. "nbformat": 4,
  204. "nbformat_minor": 5
  205. }