- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "pycharm": {
- "name": "#%% md\n"
- },
- "tags": []
- },
- "source": [
- "# Intro"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from abc import abstractmethod, ABC\n",
- "from os import PathLike\n",
- "from typing import Dict, Union, Optional, Iterable\n",
- "\n",
- "\n",
- "class base_peft(ABC):\n",
- " def __init__(self, base_model_name: Union[str, PathLike[str]], mask_token_id: int):\n",
- " self.base_model_name = base_model_name\n",
- " self.mask_token_id = mask_token_id\n",
- "\n",
-     "    @abstractmethod\n",
-     "    def activate_task_for_training(self, peft_name: str):\n",
-     "        pass\n",
- "\n",
- " @abstractmethod\n",
- " def finetune_task(self, peft_name: str, train_dataset, validation_dataset):\n",
- " pass"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-15T13:16:40.910406Z",
- "start_time": "2023-08-15T13:16:40.860981Z"
- },
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/mohalisad/Developer/ProgressivePrompts\n"
- ]
- }
- ],
- "source": [
- "cd /home/mohalisad/Developer/ProgressivePrompts"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-15T13:16:42.467311Z",
- "start_time": "2023-08-15T13:16:42.313951Z"
- },
- "pycharm": {
- "is_executing": true,
- "name": "#%%\n"
- },
- "scrolled": true,
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Python version is: 3.9.17\n",
- "Torch version is: 1.13.1+cu117\n",
- "Nvidia device is: NVIDIA GeForce RTX 4090\n",
- "Transformers version is: 4.26.1\n",
- "Adapterhub version is: 3.2.1\n"
- ]
- }
- ],
- "source": [
- "from utils import print_system_info\n",
- "print_system_info()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from _datasets import AutoLoad\n",
- "from config import load_config\n",
- "from _models import BertAdapterModelWrapper, TokenizerMan\n",
- "\n",
- "\n",
- "config = load_config('config.yaml')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 39,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
- "Model config BertConfig {\n",
- " \"architectures\": [\n",
- " \"BertForMaskedLM\"\n",
- " ],\n",
- " \"attention_probs_dropout_prob\": 0.1,\n",
- " \"classifier_dropout\": null,\n",
- " \"gradient_checkpointing\": false,\n",
- " \"hidden_act\": \"gelu\",\n",
- " \"hidden_dropout_prob\": 0.1,\n",
- " \"hidden_size\": 768,\n",
- " \"initializer_range\": 0.02,\n",
- " \"intermediate_size\": 3072,\n",
- " \"layer_norm_eps\": 1e-12,\n",
- " \"max_position_embeddings\": 512,\n",
- " \"model_type\": \"bert\",\n",
- " \"num_attention_heads\": 12,\n",
- " \"num_hidden_layers\": 12,\n",
- " \"pad_token_id\": 0,\n",
- " \"position_embedding_type\": \"absolute\",\n",
- " \"transformers_version\": \"4.26.1\",\n",
- " \"type_vocab_size\": 2,\n",
- " \"use_cache\": true,\n",
- " \"vocab_size\": 30522\n",
- "}\n",
- "\n",
- "loading weights file model.safetensors from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/model.safetensors\n",
- "Generate config GenerationConfig {\n",
- " \"pad_token_id\": 0,\n",
- " \"transformers_version\": \"4.26.1\"\n",
- "}\n",
- "\n",
- "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertAdapterModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
- "- This IS expected if you are initializing BertAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing BertAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "All the weights of BertAdapterModel were initialized from the model checkpoint at bert-base-uncased.\n",
- "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertAdapterModel for predictions without further training.\n",
- "Generation config file not found, using a generation config created from the model config.\n",
- "loading file vocab.txt from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/vocab.txt\n",
- "loading file tokenizer.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/tokenizer.json\n",
- "loading file added_tokens.json from cache at None\n",
- "loading file special_tokens_map.json from cache at None\n",
- "loading file tokenizer_config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/tokenizer_config.json\n",
- "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
- "Model config BertConfig {\n",
- " \"_name_or_path\": \"bert-base-uncased\",\n",
- " \"architectures\": [\n",
- " \"BertForMaskedLM\"\n",
- " ],\n",
- " \"attention_probs_dropout_prob\": 0.1,\n",
- " \"classifier_dropout\": null,\n",
- " \"gradient_checkpointing\": false,\n",
- " \"hidden_act\": \"gelu\",\n",
- " \"hidden_dropout_prob\": 0.1,\n",
- " \"hidden_size\": 768,\n",
- " \"initializer_range\": 0.02,\n",
- " \"intermediate_size\": 3072,\n",
- " \"layer_norm_eps\": 1e-12,\n",
- " \"max_position_embeddings\": 512,\n",
- " \"model_type\": \"bert\",\n",
- " \"num_attention_heads\": 12,\n",
- " \"num_hidden_layers\": 12,\n",
- " \"pad_token_id\": 0,\n",
- " \"position_embedding_type\": \"absolute\",\n",
- " \"transformers_version\": \"4.26.1\",\n",
- " \"type_vocab_size\": 2,\n",
- " \"use_cache\": true,\n",
- " \"vocab_size\": 30522\n",
- "}\n",
- "\n"
- ]
- }
- ],
- "source": [
- "# import transformers\n",
- "# transformers.logging.set_verbosity_debug()\n",
- "adapter_wrapper = BertAdapterModelWrapper(\n",
- " base_model_name=config.base_model.name,\n",
- " mask_token_id=config.base_model.mask_token_id\n",
- ")\n",
- "tokenizer_man = TokenizerMan(config.base_model.kind, config.base_model.name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 40,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "auto_loader = AutoLoad()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 41,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "f983a58646a54aa6841312408f00f491",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/8551 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "99ea0309b4384a0ab7a458710ae2e443",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/1043 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "d041fd8948044b5e8b0f761079a04894",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/1063 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Adding adapter 'glue:cola'.\n",
- "Adding head 'glue:cola' with config {'head_type': 'classification', 'num_labels': 2, 'layers': 2, 'activation_function': 'tanh', 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'use_pooler': False, 'bias': True}.\n",
- "PyTorch: setting up devices\n",
- "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
- "/home/mohalisad/anaconda3/envs/lll/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
- " warnings.warn(\n",
- "***** Running training *****\n",
- " Num examples = 8551\n",
- " Num Epochs = 15\n",
- " Instantaneous batch size per device = 32\n",
- " Total train batch size (w. parallel, distributed & accumulation) = 32\n",
- " Gradient Accumulation steps = 1\n",
- " Total optimization steps = 4020\n",
- " Number of trainable parameters = 1486658\n",
- "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " <div>\n",
- " \n",
- " <progress value='4020' max='4020' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- " [4020/4020 01:08, Epoch 15/15]\n",
- " </div>\n",
- " <table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- " <th>Epoch</th>\n",
- " <th>Training Loss</th>\n",
- " <th>Validation Loss</th>\n",
- " <th>Accuracy</th>\n",
- " <th>F1-score-1</th>\n",
- " <th>F1-score-ma</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>No log</td>\n",
- " <td>0.521243</td>\n",
- " <td>0.772771</td>\n",
- " <td>0.854512</td>\n",
- " <td>0.667956</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>0.484900</td>\n",
- " <td>0.475989</td>\n",
- " <td>0.795781</td>\n",
- " <td>0.866290</td>\n",
- " <td>0.717121</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>0.484900</td>\n",
- " <td>0.473902</td>\n",
- " <td>0.799616</td>\n",
- " <td>0.868471</td>\n",
- " <td>0.723974</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>4</td>\n",
- " <td>0.390000</td>\n",
- " <td>0.454408</td>\n",
- " <td>0.815916</td>\n",
- " <td>0.877707</td>\n",
- " <td>0.752807</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>5</td>\n",
- " <td>0.390000</td>\n",
- " <td>0.460564</td>\n",
- " <td>0.822627</td>\n",
- " <td>0.880414</td>\n",
- " <td>0.768593</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>6</td>\n",
- " <td>0.330900</td>\n",
- " <td>0.421414</td>\n",
- " <td>0.831256</td>\n",
- " <td>0.883752</td>\n",
- " <td>0.788030</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>7</td>\n",
- " <td>0.330900</td>\n",
- " <td>0.452820</td>\n",
- " <td>0.833174</td>\n",
- " <td>0.885375</td>\n",
- " <td>0.789519</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>8</td>\n",
- " <td>0.292000</td>\n",
- " <td>0.465746</td>\n",
- " <td>0.826462</td>\n",
- " <td>0.881777</td>\n",
- " <td>0.777825</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>9</td>\n",
- " <td>0.292000</td>\n",
- " <td>0.491992</td>\n",
- " <td>0.832215</td>\n",
- " <td>0.885396</td>\n",
- " <td>0.786169</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>10</td>\n",
- " <td>0.255500</td>\n",
- " <td>0.508437</td>\n",
- " <td>0.827421</td>\n",
- " <td>0.883117</td>\n",
- " <td>0.776723</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>11</td>\n",
- " <td>0.255500</td>\n",
- " <td>0.519635</td>\n",
- " <td>0.837009</td>\n",
- " <td>0.888889</td>\n",
- " <td>0.791567</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>12</td>\n",
- " <td>0.232300</td>\n",
- " <td>0.522434</td>\n",
- " <td>0.828380</td>\n",
- " <td>0.883388</td>\n",
- " <td>0.779262</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>13</td>\n",
- " <td>0.232300</td>\n",
- " <td>0.532363</td>\n",
- " <td>0.835091</td>\n",
- " <td>0.886991</td>\n",
- " <td>0.791013</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>14</td>\n",
- " <td>0.219900</td>\n",
- " <td>0.557935</td>\n",
- " <td>0.831256</td>\n",
- " <td>0.885566</td>\n",
- " <td>0.782199</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>15</td>\n",
- " <td>0.202800</td>\n",
- " <td>0.547973</td>\n",
- " <td>0.832215</td>\n",
- " <td>0.885845</td>\n",
- " <td>0.784695</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table><p>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-268/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-536/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-804/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1072/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1340/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1608/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-1876/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2144/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2412/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2680/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-2948/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3216/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3484/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-3752/glue:cola/pytorch_model_head.bin\n",
- "***** Running Evaluation *****\n",
- " Num examples = 1043\n",
- " Batch size = 32\n",
- "Saving model checkpoint to /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/adapter_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_adapter.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_model_head.bin\n",
- "Configuration saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/head_config.json\n",
- "Module weights saved in /home/mohalisad/Developer/ProgressivePrompts/cp3/checkpoint-4020/glue:cola/pytorch_model_head.bin\n",
- "\n",
- "\n",
- "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
- "\n",
- "\n"
- ]
- }
- ],
- "source": [
- "for task_name in config.tasks:\n",
- " loader_out = auto_loader.get_and_map(tokenizer_man.tokenizer, task_name)\n",
- " num_labels = len(loader_out['output']['range'])\n",
- " adapter_wrapper.add_classification_adapter(task_name, num_labels=num_labels)\n",
- " adapter_wrapper.finetune_adapter(\n",
- " task_name,\n",
- " loader_out['train'],\n",
- " loader_out['valid'],\n",
- " tokenizer_man.get_col_fn(),\n",
- " config.hf_trainer_params.to_dict()\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Opendelta"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from bigmodelvis import Visualization\n",
- "from transformers import BertForSequenceClassification\n",
- "from opendelta import AdapterModel"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 42,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "loading configuration file config.json from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json\n",
- "Model config BertConfig {\n",
- " \"architectures\": [\n",
- " \"BertForMaskedLM\"\n",
- " ],\n",
- " \"attention_probs_dropout_prob\": 0.1,\n",
- " \"classifier_dropout\": null,\n",
- " \"gradient_checkpointing\": false,\n",
- " \"hidden_act\": \"gelu\",\n",
- " \"hidden_dropout_prob\": 0.1,\n",
- " \"hidden_size\": 768,\n",
- " \"initializer_range\": 0.02,\n",
- " \"intermediate_size\": 3072,\n",
- " \"layer_norm_eps\": 1e-12,\n",
- " \"max_position_embeddings\": 512,\n",
- " \"model_type\": \"bert\",\n",
- " \"num_attention_heads\": 12,\n",
- " \"num_hidden_layers\": 12,\n",
- " \"pad_token_id\": 0,\n",
- " \"position_embedding_type\": \"absolute\",\n",
- " \"transformers_version\": \"4.26.1\",\n",
- " \"type_vocab_size\": 2,\n",
- " \"use_cache\": true,\n",
- " \"vocab_size\": 30522\n",
- "}\n",
- "\n",
- "loading weights file model.safetensors from cache at /home/mohalisad/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/model.safetensors\n",
- "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
- "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
- ]
- }
- ],
- "source": [
- "base_model = BertForSequenceClassification.from_pretrained(config.base_model.name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 43,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
- "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[30522, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[512, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
- "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
- "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[3072, 768] bias:[3072]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 3072] bias:[768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
- "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">classifier </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[37mroot\u001b[0m\n",
- "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
- "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
- "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[30522, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[512, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[36mweight:[2, 768]\u001b[0m\n",
- "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
- "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
- "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
- "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
- "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
- "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[3072, 768] \u001b[0m\u001b[36mbias:[3072]\u001b[0m\n",
- "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
- "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 3072] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[36mweight:[768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
- "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "└── \u001b[37mclassifier \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "Visualization(base_model).structure_graph();"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 44,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "delta_model = AdapterModel(base_model, bottleneck_dim=48)\n",
- "# leave the delta tuning modules and the newly initialized classification head tunable.\n",
- "delta_model.freeze_module(exclude=[\"deltas\", \"classifier\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
- "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[30522, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[512, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[2, 768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
- "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
- "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter </span><span style=\"color: #008000; text-decoration-color: #008000\">(AdapterLayer)</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">modulelist </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
- "│ │ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">down_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[48, 768] bias:[48]</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">up_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[768, 48] bias:[768]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[3072, 768] bias:[3072]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 3072] bias:[768]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter </span><span style=\"color: #008000; text-decoration-color: #008000\">(AdapterLayer)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">modulelist </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">down_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[48, 768] bias:[48]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">up_proj </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #af00ff; text-decoration-color: #af00ff\">weight:[768, 48] bias:[768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">classifier </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[37mroot\u001b[0m\n",
- "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
- "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
- "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[30522, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[512, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[2, 768]\u001b[0m\n",
- "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
- "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
- "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
- "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
- "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ │ │ └── \u001b[37madapter \u001b[0m\u001b[32m(AdapterLayer)\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mmodulelist \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
- "│ │ │ │ ├── \u001b[37mdown_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[48, 768] \u001b[0m\u001b[38;2;175;0;255mbias:[48]\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mup_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[768, 48] \u001b[0m\u001b[38;2;175;0;255mbias:[768]\u001b[0m\n",
- "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
- "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[3072, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[3072]\u001b[0m\n",
- "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
- "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 3072] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ │ └── \u001b[37madapter \u001b[0m\u001b[32m(AdapterLayer)\u001b[0m\n",
- "│ │ │ └── \u001b[37mmodulelist \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mdown_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[48, 768] \u001b[0m\u001b[38;2;175;0;255mbias:[48]\u001b[0m\n",
- "│ │ │ └── \u001b[37mup_proj \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;175;0;255mweight:[768, 48] \u001b[0m\u001b[38;2;175;0;255mbias:[768]\u001b[0m\n",
- "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
- "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "└── \u001b[37mclassifier \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "Visualization(base_model).structure_graph();"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T16:06:44.674950Z",
- "start_time": "2023-08-13T16:06:42.233454Z"
- }
- },
- "outputs": [],
- "source": [
- "from transformers import TrainingArguments, Trainer\n",
- "from sklearn.metrics import classification_report\n",
- "\n",
- "\n",
- "def compute_metrics(pred):\n",
- " true_labels = pred.label_ids.ravel()\n",
- " pred_labels = pred.predictions.argmax(-1).ravel()\n",
- " report = classification_report(true_labels, pred_labels, output_dict=True)\n",
- " return {\n",
- " 'accuracy': report['accuracy'],\n",
- " 'f1-score-1': report['1']['f1-score'],\n",
- " 'f1-score-ma': report['macro avg']['f1-score']\n",
- " }\n",
- "\n",
- "\n",
- "def train_model(input_model, task_name, train_dataset, eval_dataset, col_fn):\n",
- " training_args = TrainingArguments(\n",
- " evaluation_strategy=\"epoch\",\n",
- " save_strategy=\"epoch\",\n",
- "        # The next line is important to ensure the dataset labels are properly passed to the model\n",
- " remove_unused_columns=False,\n",
- " **config.hf_trainer_params.to_dict()\n",
- " )\n",
- "\n",
- " trainer = Trainer(\n",
- " model=input_model,\n",
- " args=training_args,\n",
- " train_dataset=train_dataset,\n",
- " eval_dataset=eval_dataset,\n",
- " data_collator=col_fn,\n",
- " compute_metrics=compute_metrics\n",
- " )\n",
- " trainer.train()\n",
- "\n",
- "\n",
- "for task_name in config.tasks:\n",
- " loader_out = auto_loader.get_and_map(tokenizer_man.tokenizer, task_name)\n",
- " num_labels = len(loader_out['output']['range'])\n",
- " train_model(\n",
- " base_model,\n",
- " task_name,\n",
- " loader_out['train'],\n",
- " loader_out['valid'],\n",
- " tokenizer_man.get_col_fn()\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 47,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">root</span>\n",
- "├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">bert </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertModel)</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEmbeddings)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">word_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[30522, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">position_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[512, 768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">token_type_embeddings </span><span style=\"color: #008000; text-decoration-color: #008000\">(Embedding) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[2, 768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">encoder </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertEncoder)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">layer </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleList)</span>\n",
- "│ │ └── <span style=\"color: #800000; text-decoration-color: #800000\">0-11</span><span style=\"color: #008000; text-decoration-color: #008000\">(BertLayer)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">attention </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertAttention)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">self </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfAttention)</span>\n",
- "│ │ │ │ ├── <span style=\"color: #800000; text-decoration-color: #800000\">query,key,value</span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningShim)</span>\n",
- "│ │ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pool </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertSelfOutput)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">intermediate </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertIntermediate)</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[3072, 768] bias:[3072]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">output </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertOutput)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 3072] bias:[768]</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">LayerNorm </span><span style=\"color: #008000; text-decoration-color: #008000\">(LayerNorm) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768] bias:[768]</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapters </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleDict)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">glue:cola </span><span style=\"color: #008000; text-decoration-color: #008000\">(Adapter)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">non_linearity </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
- "│ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter_down </span><span style=\"color: #008000; text-decoration-color: #008000\">(Sequential)</span>\n",
- "│ │ │ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">0 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[48, 768] bias:[48]</span>\n",
- "│ │ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">1 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">adapter_up </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 48] bias:[768]</span>\n",
- "│ ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">pooler </span><span style=\"color: #008000; text-decoration-color: #008000\">(BertPooler)</span>\n",
- "│ │ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">dense </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #004664; text-decoration-color: #004664\">weight:[768, 768] bias:[768]</span>\n",
- "│ └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">prefix_tuning </span><span style=\"color: #008000; text-decoration-color: #008000\">(PrefixTuningPool)</span>\n",
- "└── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">heads </span><span style=\"color: #008000; text-decoration-color: #008000\">(ModuleDict)</span>\n",
- " └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">glue:cola </span><span style=\"color: #008000; text-decoration-color: #008000\">(ClassificationHead)</span>\n",
- " ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">1 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[768, 768] bias:[768]</span>\n",
- " ├── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">2 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Activation_Function_Class)</span>\n",
- " └── <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">4 </span><span style=\"color: #008000; text-decoration-color: #008000\">(Linear) </span><span style=\"color: #008080; text-decoration-color: #008080\">weight:[2, 768] bias:[2]</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[37mroot\u001b[0m\n",
- "├── \u001b[37mbert \u001b[0m\u001b[32m(BertModel)\u001b[0m\n",
- "│ ├── \u001b[37membeddings \u001b[0m\u001b[32m(BertEmbeddings)\u001b[0m\n",
- "│ │ ├── \u001b[37mword_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[30522, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mposition_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[512, 768]\u001b[0m\n",
- "│ │ ├── \u001b[37mtoken_type_embeddings \u001b[0m\u001b[32m(Embedding) \u001b[0m\u001b[38;2;0;70;100mweight:[2, 768]\u001b[0m\n",
- "│ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mencoder \u001b[0m\u001b[32m(BertEncoder)\u001b[0m\n",
- "│ │ └── \u001b[37mlayer \u001b[0m\u001b[32m(ModuleList)\u001b[0m\n",
- "│ │ └── \u001b[31m0-11\u001b[0m\u001b[32m(BertLayer)\u001b[0m\n",
- "│ │ ├── \u001b[37mattention \u001b[0m\u001b[32m(BertAttention)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mself \u001b[0m\u001b[32m(BertSelfAttention)\u001b[0m\n",
- "│ │ │ │ ├── \u001b[31mquery,key,value\u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningShim)\u001b[0m\n",
- "│ │ │ │ └── \u001b[37mpool \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "│ │ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertSelfOutput)\u001b[0m\n",
- "│ │ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ │ └── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ ├── \u001b[37mintermediate \u001b[0m\u001b[32m(BertIntermediate)\u001b[0m\n",
- "│ │ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[3072, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[3072]\u001b[0m\n",
- "│ │ └── \u001b[37moutput \u001b[0m\u001b[32m(BertOutput)\u001b[0m\n",
- "│ │ ├── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 3072] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ ├── \u001b[37mLayerNorm \u001b[0m\u001b[32m(LayerNorm) \u001b[0m\u001b[38;2;0;70;100mweight:[768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ │ └── \u001b[37madapters \u001b[0m\u001b[32m(ModuleDict)\u001b[0m\n",
- "│ │ └── \u001b[37mglue:cola \u001b[0m\u001b[32m(Adapter)\u001b[0m\n",
- "│ │ ├── \u001b[37mnon_linearity \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
- "│ │ ├── \u001b[37madapter_down \u001b[0m\u001b[32m(Sequential)\u001b[0m\n",
- "│ │ │ ├── \u001b[37m0 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[48, 768] \u001b[0m\u001b[36mbias:[48]\u001b[0m\n",
- "│ │ │ └── \u001b[37m1 \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
- "│ │ └── \u001b[37madapter_up \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 48] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- "│ ├── \u001b[37mpooler \u001b[0m\u001b[32m(BertPooler)\u001b[0m\n",
- "│ │ └── \u001b[37mdense \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[38;2;0;70;100mweight:[768, 768] \u001b[0m\u001b[38;2;0;70;100mbias:[768]\u001b[0m\n",
- "│ └── \u001b[37mprefix_tuning \u001b[0m\u001b[32m(PrefixTuningPool)\u001b[0m\n",
- "└── \u001b[37mheads \u001b[0m\u001b[32m(ModuleDict)\u001b[0m\n",
- " └── \u001b[37mglue:cola \u001b[0m\u001b[32m(ClassificationHead)\u001b[0m\n",
- " ├── \u001b[37m1 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[768, 768] \u001b[0m\u001b[36mbias:[768]\u001b[0m\n",
- " ├── \u001b[37m2 \u001b[0m\u001b[32m(Activation_Function_Class)\u001b[0m\n",
- " └── \u001b[37m4 \u001b[0m\u001b[32m(Linear) \u001b[0m\u001b[36mweight:[2, 768] \u001b[0m\u001b[36mbias:[2]\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "Visualization(adapter_wrapper.model).structure_graph();"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-15T13:11:54.968862Z",
- "start_time": "2023-08-15T13:11:54.946870Z"
- }
- },
- "outputs": [],
- "source": [
- "results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-15T13:23:50.492273Z",
- "start_time": "2023-08-15T13:22:40.985364Z"
- }
- },
- "outputs": [],
- "source": [
- "from _datasets import GLUEHelper\n",
- " \n",
- "gl_helper = GLUEHelper()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-15T13:46:17.380290Z",
- "start_time": "2023-08-15T13:46:17.346993Z"
- }
- },
- "outputs": [],
- "source": [
- "for n in range(0, 1000):\n",
- " out = gl_helper.datasets['stsb']['train'][n]\n",
- " if out['label'] == 0.:\n",
- " print(out)\n",
- " break"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from evaluate import load\n",
- "glue_metric = load('glue', 'stsb')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "results = glue_metric.compute(predictions=[-0.5, -0.3], references=[-0.5, 1])\n",
- "results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T18:17:59.084998Z",
- "start_time": "2023-08-13T18:17:59.050653Z"
- }
- },
- "outputs": [],
- "source": [
- "gl_helper.datasets['mnli']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T18:17:59.157406Z",
- "start_time": "2023-08-13T18:17:59.081370Z"
- }
- },
- "outputs": [],
- "source": [
- "gl_helper.datasets['mnli_matched']\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T18:18:01.203910Z",
- "start_time": "2023-08-13T18:18:01.171842Z"
- }
- },
- "outputs": [],
- "source": [
- "gl_helper.datasets['mnli_mismatched']\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T18:30:16.905587Z",
- "start_time": "2023-08-13T18:30:16.775197Z"
- }
- },
- "outputs": [],
- "source": [
- "import transformers\n",
- "\n",
- "\n",
- "print(transformers.__version__)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2023-08-13T18:29:49.383120Z",
- "start_time": "2023-08-13T18:29:40.017083Z"
- }
- },
- "outputs": [],
- "source": [
- "pip install adapter-transformers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:lll]",
- "language": "python",
- "name": "conda-env-lll-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.17"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
- }
|