- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "bce6d2a3-c3df-46f9-926e-2dda07dc9a3d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from types import SimpleNamespace\n",
- "from typing import Optional\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "5095bac0-f9ef-4aee-8050-acab81ee0d6f",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
- "MODEL_NAME = 'bert-base-uncased'\n",
- "NAMESPACE = 'sadcl'\n",
- "\n",
- "NTOKENS = 10\n",
- "PROMPT_PLACE = 'post' # pre"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "ad41bd6e-d7f5-4c4b-a4fd-de039bb9b8c7",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def initialize_embedding(\n",
- " emb_dim: int,\n",
- " n_tokens: int, \n",
- " random_range: float,\n",
- " initialize_from: Optional[torch.Tensor] = None\n",
- "):\n",
- " if initialize_from is None:\n",
- " return torch.FloatTensor(n_tokens, emb_dim).uniform_(-random_range, random_range)\n",
- "\n",
- " assert initialize_from.shape == (n_tokens, )\n",
- "\n",
- " return initialize_from.clone().detach().tile(1, emb_dim)\n",
- "\n",
- "class SoftEmbedding(nn.Module):\n",
- " def __init__(\n",
- " self,\n",
- " emb_dim: int,\n",
- " n_tokens: int, \n",
- " random_range: float = 0.5,\n",
- " prompt_place: str = 'post',\n",
- " mode: str = 'cat',\n",
- " initialize_from: Optional[torch.Tensor] = None\n",
- " ):\n",
- " super().__init__()\n",
- " assert mode in ['cat', 'add']\n",
- " assert prompt_place in ['pre', 'post']\n",
- " \n",
- " self.post_tokenizer_map = {\n",
- " 'input_ids': 0,\n",
- " 'attention_mask': 1,\n",
- " 'token_type_ids': 0\n",
- " }\n",
- " self.n_tokens = n_tokens\n",
- " self.mode = mode\n",
- " self.prompt_place = prompt_place\n",
- " \n",
- " self.sadcl_learned_embedding = nn.parameter.Parameter(\n",
- " initialize_embedding(\n",
- " emb_dim,\n",
- " n_tokens,\n",
- " random_range,\n",
- " initialize_from\n",
- " )\n",
- " )\n",
- "\n",
- " assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)\n",
- " \n",
- " def forward(self, input_embedding):\n",
- " # input_embedding.shape = (batch_size, num_of_input_tokens, emb_dim)\n",
- " batch_size = input_embedding.size(0)\n",
- " if self.mode == 'cat':\n",
- " learned_embedding = self.sadcl_learned_embedding.repeat(batch_size, 1, 1) # (batch_size, n_tokens, emb_dim)\n",
- " return self.concat_batch(input_embedding[self.get_slice_for_cat()], learned_embedding)\n",
- " else: # mode == add\n",
- " input_embedding[self.get_slice_for_add()] += self.sadcl_learned_embedding[None, :, :]\n",
- " return input_embedding\n",
- " \n",
- " def get_weights(self):\n",
- " return self.sadcl_learned_embedding.detach().clone()\n",
- " \n",
- " def set_weights(self, new_weights: torch.Tensor):\n",
- " self.sadcl_learned_embedding.data = new_weights\n",
- " \n",
- " def get_slice_for_add(self):\n",
- " if self.prompt_place == 'pre':\n",
- " return slice(None), slice(None, self.n_tokens), slice(None)\n",
- " else: # prompt_place == post\n",
- " return slice(None), slice(-self.n_tokens, None), slice(None)\n",
- " \n",
- " def get_slice_for_cat(self):\n",
- " if self.prompt_place == 'pre':\n",
- " return slice(None), slice(self.n_tokens, None), slice(None)\n",
- " else: # prompt_place == post\n",
- " return slice(None), slice(None, -self.n_tokens), slice(None)\n",
- " \n",
- " def concat_batch(self, orig_vals, new_vals):\n",
- " if self.prompt_place == 'pre':\n",
- " return torch.cat([new_vals, orig_vals], axis=1)\n",
- " else: # prompt_place == post\n",
- " return torch.cat([orig_vals, new_vals], axis=1)\n",
- " \n",
- " def post_tokenizer(self, **kwargs):\n",
- " for special_key, pad_val in self.post_tokenizer_map.items():\n",
- " if special_key in kwargs:\n",
- " orig_tokens = kwargs[special_key]\n",
- " batch_size = kwargs[special_key].size(0)\n",
- " new_vals = torch.full(\n",
- " size=(batch_size, self.n_tokens),\n",
- " fill_value=pad_val,\n",
- " dtype=orig_tokens.dtype,\n",
- " device=orig_tokens.device\n",
- " )\n",
- " kwargs[special_key].data = self.concat_batch(orig_tokens, new_vals)\n",
- " return kwargs\n",
- "\n",
- "class TransformerInjector(nn.Module):\n",
- " def __init__(self, module):\n",
- " super().__init__()\n",
- " self.original_module = module\n",
- " self.add_prompt = SoftEmbedding(\n",
- " emb_dim=module.output.dense.out_features,\n",
- " n_tokens=NTOKENS,\n",
- " prompt_place=PROMPT_PLACE,\n",
- " mode='add'\n",
- " )\n",
- " \n",
- " def forward(self, hidden_states, *args, **kwargs):\n",
- " hidden_states = self.add_prompt(hidden_states)\n",
- " return self.original_module(hidden_states, *args, **kwargs)\n",
- " \n",
- " @classmethod\n",
- " def muatate_list(cls, module_list):\n",
- " for idx, module in enumerate(module_list):\n",
- " module_list[idx] = cls(module)\n",
- " return module_list\n",
- " \n",
- "class NewEmbeddingLayer(nn.Module):\n",
- " def __init__(self, emb_layer=nn.Embedding):\n",
- " super().__init__()\n",
- " self.emb_layer = emb_layer\n",
- " self.soft_prompt = SoftEmbedding(\n",
- " emb_dim=emb_layer.weight.size(1),\n",
- " n_tokens=NTOKENS,\n",
- " prompt_place=PROMPT_PLACE\n",
- " )\n",
- " \n",
- " def forward(self, tokens):\n",
- " out = self.emb_layer(tokens)\n",
- " out = self.soft_prompt(out)\n",
- " return out\n",
- " \n",
- " def get_weights(self):\n",
- " return self.soft_prompt.get_weights()\n",
- " \n",
- " def set_weights(self, new_weights):\n",
- " self.soft_prompt.set_weights(new_weights)\n",
- " \n",
- " @classmethod\n",
- " def mutate(cls, model):\n",
- " emb_layer = model.get_input_embeddings()\n",
- " new_emb_layer = cls(emb_layer)\n",
- " model.set_input_embeddings(new_emb_layer)\n",
- " \n",
- " orig_forward = model.forward\n",
- " \n",
- " def new_forward(**kwargs):\n",
- " new_kwargs = new_emb_layer.soft_prompt.post_tokenizer(**kwargs)\n",
- " return orig_forward(**new_kwargs)\n",
- " \n",
- " model.forward = new_forward\n",
- " return new_emb_layer"
- ]
- },
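- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a1b2c3d4-0000-4000-8000-000000000001",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# Quick shape sanity check for SoftEmbedding (a minimal sketch on dummy tensors;\n",
- "# the sizes below are arbitrary and not part of the training pipeline).\n",
- "_emb_dim, _batch, _seq = 16, 2, NTOKENS + 4\n",
- "_dummy = torch.zeros(_batch, _seq, _emb_dim)\n",
- "\n",
- "_cat = SoftEmbedding(emb_dim=_emb_dim, n_tokens=NTOKENS, prompt_place=PROMPT_PLACE, mode='cat')\n",
- "_add = SoftEmbedding(emb_dim=_emb_dim, n_tokens=NTOKENS, prompt_place=PROMPT_PLACE, mode='add')\n",
- "\n",
- "# 'cat' drops n_tokens positions and concatenates the learned prompt in their place\n",
- "# (post_tokenizer pads the ids beforehand, so no real tokens are lost).\n",
- "assert _cat(_dummy.clone()).shape == (_batch, _seq, _emb_dim)\n",
- "# 'add' sums the learned prompt onto the last n_tokens positions in place.\n",
- "assert _add(_dummy.clone()).shape == (_batch, _seq, _emb_dim)"
- ]
- },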
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "79bf6687-5a88-4181-88dc-740d11dd89ac",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
- ]
- }
- ],
- "source": [
- "from transformers import BertForSequenceClassification, BertTokenizerFast\n",
- "\n",
- "model = BertForSequenceClassification.from_pretrained(MODEL_NAME)\n",
- "tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)\n",
- "\n",
- "peft_module = NewEmbeddingLayer.mutate(model)\n",
- "peft_bert_layers = TransformerInjector.muatate_list(model.bert.encoder.layer)\n",
- "\n",
- "model.to(DEVICE);"
- ]
- },
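- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a1b2c3d4-0000-4000-8000-000000000002",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# Sanity check (a sketch): confirm the mutation swapped in the soft-prompt modules.\n",
- "# It relies only on the objects created in the previous cells.\n",
- "assert isinstance(model.get_input_embeddings(), NewEmbeddingLayer)\n",
- "assert all(isinstance(layer, TransformerInjector) for layer in model.bert.encoder.layer)\n",
- "print(f'{len(peft_bert_layers)} injected encoder layers')"
- ]
- },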
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "c0b0a48e-0b0b-43de-ae78-d2521cfee69e",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[-0.2546, -0.0352, -0.4110, ..., 0.0189, 0.4121, 0.2206],\n",
- " [ 0.0670, 0.0600, 0.4493, ..., -0.4346, 0.4130, -0.3507],\n",
- " [ 0.0827, 0.3569, 0.0943, ..., -0.3451, -0.1879, 0.0831],\n",
- " ...,\n",
- " [-0.0489, -0.2570, -0.3328, ..., -0.4109, 0.0884, -0.0290],\n",
- " [-0.2705, -0.3854, 0.4559, ..., -0.0480, -0.4039, 0.4245],\n",
- " [-0.1941, 0.2237, 0.3494, ..., -0.1199, -0.3030, -0.1530]],\n",
- " device='cuda:0')"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "old_w = peft_module.get_weights()\n",
- "old_w"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "id": "d3753569-c95f-4f8e-99ec-e6f990ec55a8",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# tokens = tokenizer(\"Hi bye\", return_tensors='pt').to(DEVICE)\n",
- "\n",
- "# model.eval()\n",
- "# with torch.no_grad():\n",
- "# out = model(**tokens)\n",
- "# out"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "23b6f5b1-bbb7-43b9-b5a9-e62d313f4244",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "e2bc8f1df0934619941ae8e37e2be807",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/3 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "from _datasets import AutoLoad\n",
- "autoload = AutoLoad()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "id": "45cb37be-9aee-45f6-8a8a-bf859197a7d4",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "bert.embeddings.word_embeddings.soft_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.0.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.1.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.2.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.3.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.4.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.5.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.6.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.7.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.8.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.9.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.10.add_prompt.sadcl_learned_embedding\n",
- "bert.encoder.layer.11.add_prompt.sadcl_learned_embedding\n",
- "classifier.weight\n",
- "classifier.bias\n"
- ]
- }
- ],
- "source": [
- "for param_name, weights in model.named_parameters():\n",
- " if 'classifier' in param_name or NAMESPACE in param_name:\n",
- " weights.requires_grad = True\n",
- " print(param_name)\n",
- " else:\n",
- " weights.requires_grad = False"
- ]
- },
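- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a1b2c3d4-0000-4000-8000-000000000003",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# Parameter budget after freezing (a quick check, not required for training):\n",
- "# only the soft prompts and the classifier head should remain trainable.\n",
- "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
- "total = sum(p.numel() for p in model.parameters())\n",
- "print(f'trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)')"
- ]
- },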
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "47f78a61-710f-410d-8f49-19da12eef09a",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/8551 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/1043 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Map: 0%| | 0/1063 [00:00<?, ? examples/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "loader_out = autoload.get_and_map(tokenizer, \"glue:cola\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "8d75737f-e5c6-4dc1-94b9-8aaa507648e2",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'train': Dataset({\n",
- " features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
- " num_rows: 8551\n",
- " }),\n",
- " 'valid': Dataset({\n",
- " features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
- " num_rows: 1043\n",
- " }),\n",
- " 'output': {'kind': 'classification', 'range': {0, 1}}}"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "loader_out"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "id": "2489364c-4d8d-4d69-8d52-7ac88d66e7f8",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from config import load_config\n",
- "config = load_config('config.yaml')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "id": "67e68e28-b4d0-42fd-a7e7-b1321485fc78",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-41a6799222324b5f.arrow\n",
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fc7d7deaf3161a2.arrow\n",
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0eb862d54758b38d.arrow\n",
- "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " <div>\n",
- " \n",
- " <progress value='21440' max='21440' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- " [21440/21440 07:01, Epoch 80/80]\n",
- " </div>\n",
- " <table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- " <th>Epoch</th>\n",
- " <th>Training Loss</th>\n",
- " <th>Validation Loss</th>\n",
- " <th>Accuracy</th>\n",
- " <th>F1-score-1</th>\n",
- " <th>F1-score-ma</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>No log</td>\n",
- " <td>0.655867</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817460</td>\n",
- " <td>0.408730</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>0.577800</td>\n",
- " <td>0.639771</td>\n",
- " <td>0.763183</td>\n",
- " <td>0.848930</td>\n",
- " <td>0.650629</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>0.577800</td>\n",
- " <td>0.507809</td>\n",
- " <td>0.766059</td>\n",
- " <td>0.849197</td>\n",
- " <td>0.663915</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>4</td>\n",
- " <td>0.528700</td>\n",
- " <td>0.523820</td>\n",
- " <td>0.770853</td>\n",
- " <td>0.852195</td>\n",
- " <td>0.671300</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>5</td>\n",
- " <td>0.528700</td>\n",
- " <td>0.480276</td>\n",
- " <td>0.794823</td>\n",
- " <td>0.861757</td>\n",
- " <td>0.731994</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>6</td>\n",
- " <td>0.499800</td>\n",
- " <td>0.506056</td>\n",
- " <td>0.776606</td>\n",
- " <td>0.855906</td>\n",
- " <td>0.679552</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>7</td>\n",
- " <td>0.499800</td>\n",
- " <td>0.475724</td>\n",
- " <td>0.795781</td>\n",
- " <td>0.863198</td>\n",
- " <td>0.730276</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>8</td>\n",
- " <td>0.482900</td>\n",
- " <td>0.494971</td>\n",
- " <td>0.790988</td>\n",
- " <td>0.860614</td>\n",
- " <td>0.721495</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>9</td>\n",
- " <td>0.482900</td>\n",
- " <td>0.478771</td>\n",
- " <td>0.786194</td>\n",
- " <td>0.858592</td>\n",
- " <td>0.710239</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>10</td>\n",
- " <td>0.465700</td>\n",
- " <td>0.502414</td>\n",
- " <td>0.780441</td>\n",
- " <td>0.858903</td>\n",
- " <td>0.682151</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>11</td>\n",
- " <td>0.465700</td>\n",
- " <td>0.498116</td>\n",
- " <td>0.794823</td>\n",
- " <td>0.866584</td>\n",
- " <td>0.711300</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>12</td>\n",
- " <td>0.461800</td>\n",
- " <td>0.537117</td>\n",
- " <td>0.780441</td>\n",
- " <td>0.860281</td>\n",
- " <td>0.673988</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>13</td>\n",
- " <td>0.461800</td>\n",
- " <td>0.465851</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.868286</td>\n",
- " <td>0.736825</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>14</td>\n",
- " <td>0.445000</td>\n",
- " <td>0.487390</td>\n",
- " <td>0.795781</td>\n",
- " <td>0.865953</td>\n",
- " <td>0.718691</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>15</td>\n",
- " <td>0.435600</td>\n",
- " <td>0.440423</td>\n",
- " <td>0.801534</td>\n",
- " <td>0.864440</td>\n",
- " <td>0.747068</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>16</td>\n",
- " <td>0.435600</td>\n",
- " <td>0.483897</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.869344</td>\n",
- " <td>0.736413</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>17</td>\n",
- " <td>0.423500</td>\n",
- " <td>0.461727</td>\n",
- " <td>0.806328</td>\n",
- " <td>0.872152</td>\n",
- " <td>0.736471</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>18</td>\n",
- " <td>0.423500</td>\n",
- " <td>0.491034</td>\n",
- " <td>0.794823</td>\n",
- " <td>0.865915</td>\n",
- " <td>0.714590</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>19</td>\n",
- " <td>0.410400</td>\n",
- " <td>0.451404</td>\n",
- " <td>0.806328</td>\n",
- " <td>0.868490</td>\n",
- " <td>0.750608</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>20</td>\n",
- " <td>0.410400</td>\n",
- " <td>0.439862</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.872611</td>\n",
- " <td>0.742507</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>21</td>\n",
- " <td>0.408100</td>\n",
- " <td>0.443258</td>\n",
- " <td>0.794823</td>\n",
- " <td>0.865915</td>\n",
- " <td>0.714590</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>22</td>\n",
- " <td>0.408100</td>\n",
- " <td>0.450756</td>\n",
- " <td>0.805369</td>\n",
- " <td>0.871438</td>\n",
- " <td>0.735522</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>23</td>\n",
- " <td>0.404600</td>\n",
- " <td>0.483001</td>\n",
- " <td>0.797699</td>\n",
- " <td>0.867379</td>\n",
- " <td>0.720558</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>24</td>\n",
- " <td>0.404600</td>\n",
- " <td>0.481094</td>\n",
- " <td>0.794823</td>\n",
- " <td>0.866417</td>\n",
- " <td>0.712134</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>25</td>\n",
- " <td>0.397200</td>\n",
- " <td>0.509731</td>\n",
- " <td>0.798658</td>\n",
- " <td>0.867925</td>\n",
- " <td>0.722269</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>26</td>\n",
- " <td>0.397200</td>\n",
- " <td>0.468457</td>\n",
- " <td>0.813998</td>\n",
- " <td>0.872870</td>\n",
- " <td>0.763221</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>27</td>\n",
- " <td>0.388100</td>\n",
- " <td>0.450646</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.869785</td>\n",
- " <td>0.730527</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>28</td>\n",
- " <td>0.379900</td>\n",
- " <td>0.518912</td>\n",
- " <td>0.800575</td>\n",
- " <td>0.868852</td>\n",
- " <td>0.726426</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>29</td>\n",
- " <td>0.379900</td>\n",
- " <td>0.474939</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.870988</td>\n",
- " <td>0.729257</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>30</td>\n",
- " <td>0.375800</td>\n",
- " <td>0.468194</td>\n",
- " <td>0.799616</td>\n",
- " <td>0.868636</td>\n",
- " <td>0.723207</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>31</td>\n",
- " <td>0.375800</td>\n",
- " <td>0.447116</td>\n",
- " <td>0.810163</td>\n",
- " <td>0.872423</td>\n",
- " <td>0.750818</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>32</td>\n",
- " <td>0.370700</td>\n",
- " <td>0.537091</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.870113</td>\n",
- " <td>0.729057</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>33</td>\n",
- " <td>0.370700</td>\n",
- " <td>0.475261</td>\n",
- " <td>0.807287</td>\n",
- " <td>0.871071</td>\n",
- " <td>0.744834</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>34</td>\n",
- " <td>0.367900</td>\n",
- " <td>0.487207</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.870603</td>\n",
- " <td>0.726799</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>35</td>\n",
- " <td>0.367900</td>\n",
- " <td>0.437785</td>\n",
- " <td>0.806328</td>\n",
- " <td>0.871338</td>\n",
- " <td>0.739932</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>36</td>\n",
- " <td>0.358800</td>\n",
- " <td>0.508899</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.872774</td>\n",
- " <td>0.741834</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>37</td>\n",
- " <td>0.358800</td>\n",
- " <td>0.552409</td>\n",
- " <td>0.800575</td>\n",
- " <td>0.869347</td>\n",
- " <td>0.724147</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>38</td>\n",
- " <td>0.355700</td>\n",
- " <td>0.496687</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.871571</td>\n",
- " <td>0.722093</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>39</td>\n",
- " <td>0.355700</td>\n",
- " <td>0.504841</td>\n",
- " <td>0.816874</td>\n",
- " <td>0.875570</td>\n",
- " <td>0.764464</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>40</td>\n",
- " <td>0.345500</td>\n",
- " <td>0.483254</td>\n",
- " <td>0.790988</td>\n",
- " <td>0.865929</td>\n",
- " <td>0.696008</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>41</td>\n",
- " <td>0.345500</td>\n",
- " <td>0.512504</td>\n",
- " <td>0.796740</td>\n",
- " <td>0.868323</td>\n",
- " <td>0.711472</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>42</td>\n",
- " <td>0.351700</td>\n",
- " <td>0.497110</td>\n",
- " <td>0.800575</td>\n",
- " <td>0.870486</td>\n",
- " <td>0.718576</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>43</td>\n",
- " <td>0.339900</td>\n",
- " <td>0.471216</td>\n",
- " <td>0.798658</td>\n",
- " <td>0.867758</td>\n",
- " <td>0.723036</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>44</td>\n",
- " <td>0.339900</td>\n",
- " <td>0.531487</td>\n",
- " <td>0.805369</td>\n",
- " <td>0.870783</td>\n",
- " <td>0.738304</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>45</td>\n",
- " <td>0.341300</td>\n",
- " <td>0.540843</td>\n",
- " <td>0.807287</td>\n",
- " <td>0.870740</td>\n",
- " <td>0.746104</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>46</td>\n",
- " <td>0.341300</td>\n",
- " <td>0.476809</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.869841</td>\n",
- " <td>0.734334</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>47</td>\n",
- " <td>0.337400</td>\n",
- " <td>0.479455</td>\n",
- " <td>0.819751</td>\n",
- " <td>0.877124</td>\n",
- " <td>0.769497</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>48</td>\n",
- " <td>0.337400</td>\n",
- " <td>0.446018</td>\n",
- " <td>0.815916</td>\n",
- " <td>0.875163</td>\n",
- " <td>0.762399</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>49</td>\n",
- " <td>0.334200</td>\n",
- " <td>0.548959</td>\n",
- " <td>0.813039</td>\n",
- " <td>0.875080</td>\n",
- " <td>0.751826</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>50</td>\n",
- " <td>0.334200</td>\n",
- " <td>0.500371</td>\n",
- " <td>0.797699</td>\n",
- " <td>0.867379</td>\n",
- " <td>0.720558</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>51</td>\n",
- " <td>0.331700</td>\n",
- " <td>0.503151</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.871134</td>\n",
- " <td>0.748301</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>52</td>\n",
- " <td>0.331700</td>\n",
- " <td>0.556216</td>\n",
- " <td>0.798658</td>\n",
- " <td>0.868586</td>\n",
- " <td>0.719129</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>53</td>\n",
- " <td>0.326700</td>\n",
- " <td>0.478857</td>\n",
- " <td>0.816874</td>\n",
- " <td>0.875245</td>\n",
- " <td>0.765550</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>54</td>\n",
- " <td>0.326700</td>\n",
- " <td>0.508674</td>\n",
- " <td>0.806328</td>\n",
- " <td>0.870347</td>\n",
- " <td>0.743885</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>55</td>\n",
- " <td>0.326700</td>\n",
- " <td>0.510241</td>\n",
- " <td>0.807287</td>\n",
- " <td>0.870740</td>\n",
- " <td>0.746104</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>56</td>\n",
- " <td>0.327800</td>\n",
- " <td>0.510437</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.870335</td>\n",
- " <td>0.732197</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>57</td>\n",
- " <td>0.327800</td>\n",
- " <td>0.516560</td>\n",
- " <td>0.804410</td>\n",
- " <td>0.871212</td>\n",
- " <td>0.732419</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>58</td>\n",
- " <td>0.320600</td>\n",
- " <td>0.482175</td>\n",
- " <td>0.810163</td>\n",
- " <td>0.872258</td>\n",
- " <td>0.751428</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>59</td>\n",
- " <td>0.320600</td>\n",
- " <td>0.534551</td>\n",
- " <td>0.809204</td>\n",
- " <td>0.870527</td>\n",
- " <td>0.754025</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>60</td>\n",
- " <td>0.311600</td>\n",
- " <td>0.529513</td>\n",
- " <td>0.804410</td>\n",
- " <td>0.869063</td>\n",
- " <td>0.741350</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>61</td>\n",
- " <td>0.311600</td>\n",
- " <td>0.529038</td>\n",
- " <td>0.812081</td>\n",
- " <td>0.872892</td>\n",
- " <td>0.756299</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>62</td>\n",
- " <td>0.317900</td>\n",
- " <td>0.551885</td>\n",
- " <td>0.797699</td>\n",
- " <td>0.866032</td>\n",
- " <td>0.726558</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>63</td>\n",
- " <td>0.317900</td>\n",
- " <td>0.500419</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.870968</td>\n",
- " <td>0.748917</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>64</td>\n",
- " <td>0.315100</td>\n",
- " <td>0.466086</td>\n",
- " <td>0.809204</td>\n",
- " <td>0.871861</td>\n",
- " <td>0.749251</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>65</td>\n",
- " <td>0.315100</td>\n",
- " <td>0.492729</td>\n",
- " <td>0.811122</td>\n",
- " <td>0.872821</td>\n",
- " <td>0.752984</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>66</td>\n",
- " <td>0.306300</td>\n",
- " <td>0.463267</td>\n",
- " <td>0.813998</td>\n",
- " <td>0.874352</td>\n",
- " <td>0.758209</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>67</td>\n",
- " <td>0.306300</td>\n",
- " <td>0.568536</td>\n",
- " <td>0.811122</td>\n",
- " <td>0.872821</td>\n",
- " <td>0.752984</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>68</td>\n",
- " <td>0.308500</td>\n",
- " <td>0.539011</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.868167</td>\n",
- " <td>0.741052</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>69</td>\n",
- " <td>0.308500</td>\n",
- " <td>0.526197</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.871300</td>\n",
- " <td>0.747680</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>70</td>\n",
- " <td>0.304900</td>\n",
- " <td>0.506041</td>\n",
- " <td>0.811122</td>\n",
- " <td>0.872657</td>\n",
- " <td>0.753583</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>71</td>\n",
- " <td>0.302700</td>\n",
- " <td>0.581929</td>\n",
- " <td>0.798658</td>\n",
- " <td>0.866751</td>\n",
- " <td>0.727493</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>72</td>\n",
- " <td>0.302700</td>\n",
- " <td>0.516497</td>\n",
- " <td>0.810163</td>\n",
- " <td>0.872258</td>\n",
- " <td>0.751428</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>73</td>\n",
- " <td>0.308000</td>\n",
- " <td>0.507128</td>\n",
- " <td>0.807287</td>\n",
- " <td>0.870239</td>\n",
- " <td>0.747969</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>74</td>\n",
- " <td>0.308000</td>\n",
- " <td>0.520996</td>\n",
- " <td>0.803452</td>\n",
- " <td>0.868167</td>\n",
- " <td>0.741052</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>75</td>\n",
- " <td>0.304900</td>\n",
- " <td>0.517548</td>\n",
- " <td>0.806328</td>\n",
- " <td>0.869677</td>\n",
- " <td>0.746406</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>76</td>\n",
- " <td>0.304900</td>\n",
- " <td>0.503817</td>\n",
- " <td>0.804410</td>\n",
- " <td>0.868726</td>\n",
- " <td>0.742634</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>77</td>\n",
- " <td>0.298100</td>\n",
- " <td>0.508880</td>\n",
- " <td>0.809204</td>\n",
- " <td>0.871530</td>\n",
- " <td>0.750476</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>78</td>\n",
- " <td>0.298100</td>\n",
- " <td>0.505606</td>\n",
- " <td>0.808245</td>\n",
- " <td>0.870801</td>\n",
- " <td>0.749527</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>79</td>\n",
- " <td>0.304900</td>\n",
- " <td>0.526573</td>\n",
- " <td>0.802493</td>\n",
- " <td>0.867609</td>\n",
- " <td>0.739465</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>80</td>\n",
- " <td>0.304900</td>\n",
- " <td>0.523581</td>\n",
- " <td>0.804410</td>\n",
- " <td>0.868726</td>\n",
- " <td>0.742634</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table><p>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(average, modifier, msg_start, len(result))\n",
- "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(average, modifier, msg_start, len(result))\n",
- "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(average, modifier, msg_start, len(result))\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "TrainOutput(global_step=21440, training_loss=0.37007259682043275, metrics={'train_runtime': 421.6464, 'train_samples_per_second': 1622.402, 'train_steps_per_second': 50.848, 'total_flos': 8141300538608160.0, 'train_loss': 0.37007259682043275, 'epoch': 80.0})"
- ]
- },
- "execution_count": 29,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
- "from sklearn.metrics import classification_report\n",
- "\n",
- "\n",
- "def compute_metrics(pred):\n",
- " true_labels = pred.label_ids.ravel()\n",
- " pred_labels = pred.predictions.argmax(-1).ravel()\n",
- " report = classification_report(true_labels, pred_labels, output_dict=True)\n",
- " return {\n",
- " 'accuracy': report['accuracy'],\n",
- " 'f1-score-1': report['1']['f1-score'],\n",
- " 'f1-score-ma': report['macro avg']['f1-score']\n",
- " }\n",
- "\n",
- "\n",
- "# def train_model(input_model, task_name, train_dataset, eval_dataset, col_fn):\n",
- "# training_args = TrainingArguments(\n",
- "# evaluation_strategy=\"epoch\",\n",
- "# save_strategy=\"epoch\",\n",
- "# # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
- "# remove_unused_columns=False,\n",
- "# **config.hf_trainer_params.to_dict()\n",
- "# )\n",
- "\n",
- "# trainer = Trainer(\n",
- "# model=input_model,\n",
- "# args=training_args,\n",
- "# train_dataset=train_dataset,\n",
- "# eval_dataset=eval_dataset,\n",
- "# data_collator=col_fn,\n",
- "# compute_metrics=compute_metrics\n",
- "# )\n",
- "# trainer.train()\n",
- "\n",
- "col_fn = DataCollatorWithPadding(\n",
- " tokenizer, return_tensors='pt', padding='longest'\n",
- ")\n",
- "\n",
- "loader_out = autoload.get_and_map(tokenizer, \"glue:cola\")\n",
- "num_labels = len(loader_out['output']['range'])\n",
- "\n",
- "training_args = TrainingArguments(\n",
- " evaluation_strategy=\"epoch\",\n",
- " save_strategy=\"epoch\",\n",
- " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
- " remove_unused_columns=False,\n",
- " **config.hf_trainer_params.to_dict()\n",
- ")\n",
- "\n",
- "trainer = Trainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=loader_out['train'],\n",
- " eval_dataset=loader_out['valid'],\n",
- " data_collator=col_fn,\n",
- " compute_metrics=compute_metrics\n",
- ")\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 72,
- "id": "00bc804c-6133-4ccc-b6c4-697d859f94cf",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "BertLayer(\n",
- " (attention): BertAttention(\n",
- " (self): BertSelfAttention(\n",
- " (query): Linear(in_features=768, out_features=768, bias=True)\n",
- " (key): Linear(in_features=768, out_features=768, bias=True)\n",
- " (value): Linear(in_features=768, out_features=768, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (output): BertSelfOutput(\n",
- " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (intermediate): BertIntermediate(\n",
- " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
- " (intermediate_act_fn): GELUActivation()\n",
- " )\n",
- " (output): BertOutput(\n",
- " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 72,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "model.bert.encoder.layer[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ad825227-f073-4d58-9b06-15dc882e2f74",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "tokenizer(\"hi bye\", return_tensors='pt')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a015bcd2-6768-4289-a189-74ae9c7b08de",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "for param_name, weights in model.named_parameters():\n",
- " if weights.requires_grad:\n",
- " print(param_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a009912e-7366-4deb-aae0-e78331b7e160",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "new_w = peft_module.get_weights()\n",
- "new_w\n",
- "(new_w - old_w).sum()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "98004c44-7a88-4f58-b956-d2ce6c3c56a8",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "new_w"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4301e6f1-8212-4a67-b93d-589717098b15",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "new_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)\n",
- "new_model.to(DEVICE)\n",
- "nm_sd = new_model.state_dict()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4c001d2d-e862-4926-8a31-e218489654e1",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "old_sd = model.state_dict()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "db972023-2457-40fa-8b24-c9df81e9cb51",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "for key, val in old_sd.items():\n",
- " if key not in nm_sd:\n",
- " print(key)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0f3ba7b5-441e-411d-a848-f907ec27496e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "model.bert.embeddings.word_embeddings.sadcl_soft_prompt.learned_embedding.requires_grad"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "10c1e258-df79-46ce-bc99-cbd024948379",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "(old_sd['bert.embeddings.word_embeddings.emb_layer.weight'] - nm_sd['bert.embeddings.word_embeddings.weight']).sum()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8c8ba1bf-dc9d-4864-8aa6-57140c6eb116",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ed628433-610c-4a45-9780-42927e84198e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7247d9a1-5b2a-4379-8983-22aaf2d20b7b",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "61c894cd-d601-43e0-b7b1-c3c67b5cd1f7",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "a = \"hi bye\"\n",
- "model.eval()\n",
- "with torch.no_grad():\n",
- " tokens = tokenizer(a, return_tensors='pt').to(DEVICE)\n",
- " o = model.bert.embeddings.word_embeddings(peft_module.post_tokenizer(input_ids=tokens['input_ids'])['input_ids'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e1dad24f",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 10, 768])"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "o.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "632a6a0b-e2e3-4d3b-9d12-021f08e39b6e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "o"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "811ee37b-61fb-4103-8900-64a5ffd2e7c7",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "len(out.hidden_states)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3d768fa9-2d1d-4598-8fb1-3681a8f53897",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "setattr(c, 'a', 3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3293b793-32b7-4cee-986b-1286604e361b",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "class c:\n",
- " def a(self, i):\n",
- " print(i + 1)\n",
- " \n",
- " def b(self):\n",
- " c = self.a\n",
- " self.a = lambda i: c(i + 1)\n",
- "\n",
- "o = c()\n",
- "o.a(3)\n",
- "o.b()\n",
- "o.a(3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "41adc585-a890-4971-95b9-df994adaada2",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:deep]",
- "language": "python",
- "name": "conda-env-deep-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }