{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "bce6d2a3-c3df-46f9-926e-2dda07dc9a3d", "metadata": { "tags": [] }, "outputs": [], "source": [ "from types import SimpleNamespace\n", "from typing import Optional\n", "\n", "import torch\n", "import torch.nn as nn" ] }, { "cell_type": "code", "execution_count": 2, "id": "5095bac0-f9ef-4aee-8050-acab81ee0d6f", "metadata": { "tags": [] }, "outputs": [], "source": [ "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "MODEL_NAME = 'bert-base-uncased'\n", "NAMESPACE = 'sadcl'\n", "\n", "NTOKENS = 10\n", "PROMPT_PLACE = 'post' # pre" ] }, { "cell_type": "code", "execution_count": 6, "id": "ad41bd6e-d7f5-4c4b-a4fd-de039bb9b8c7", "metadata": { "tags": [] }, "outputs": [], "source": [ "def initialize_embedding(\n", " emb_dim: int,\n", " n_tokens: int, \n", " random_range: float,\n", " initialize_from: Optional[torch.Tensor] = None\n", "):\n", " if initialize_from is None:\n", " return torch.FloatTensor(n_tokens, emb_dim).uniform_(-random_range, random_range)\n", "\n", " assert initialize_from.shape == (n_tokens, )\n", "\n", " return initialize_from.clone().detach().tile(1, emb_dim)\n", "\n", "class SoftEmbedding(nn.Module):\n", " def __init__(\n", " self,\n", " emb_dim: int,\n", " n_tokens: int, \n", " random_range: float = 0.5,\n", " prompt_place: str = 'post',\n", " mode: str = 'cat',\n", " initialize_from: Optional[torch.Tensor] = None\n", " ):\n", " super().__init__()\n", " assert mode in ['cat', 'add']\n", " assert prompt_place in ['pre', 'post']\n", " \n", " self.post_tokenizer_map = {\n", " 'input_ids': 0,\n", " 'attention_mask': 1,\n", " 'token_type_ids': 0\n", " }\n", " self.n_tokens = n_tokens\n", " self.mode = mode\n", " self.prompt_place = prompt_place\n", " \n", " self.sadcl_learned_embedding = nn.parameter.Parameter(\n", " initialize_embedding(\n", " emb_dim,\n", " n_tokens,\n", " random_range,\n", " initialize_from\n", " )\n", " )\n", "\n", " assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)\n", " \n", " def forward(self, input_embedding):\n", " # input_embedding.shape = (batch_size, num_of_input_tokens, emb_dim)\n", " batch_size = input_embedding.size(0)\n", " if self.mode == 'cat':\n", " learned_embedding = self.sadcl_learned_embedding.repeat(batch_size, 1, 1) # (batch_size, n_tokens, emb_dim)\n", " return self.concat_batch(input_embedding[self.get_slice_for_cat()], learned_embedding)\n", " else: # mode == add\n", " input_embedding[self.get_slice_for_add()] += self.sadcl_learned_embedding[None, :, :]\n", " return input_embedding\n", " \n", " def get_weights(self):\n", " return self.sadcl_learned_embedding.detach().clone()\n", " \n", " def set_weights(self, new_weights: torch.Tensor):\n", " self.sadcl_learned_embedding.data = new_weights\n", " \n", " def get_slice_for_add(self):\n", " if self.prompt_place == 'pre':\n", " return slice(None), slice(None, self.n_tokens), slice(None)\n", " else: # prompt_place == post\n", " return slice(None), slice(-self.n_tokens, None), slice(None)\n", " \n", " def get_slice_for_cat(self):\n", " if self.prompt_place == 'pre':\n", " return slice(None), slice(self.n_tokens, None), slice(None)\n", " else: # prompt_place == post\n", " return slice(None), slice(None, -self.n_tokens), slice(None)\n", " \n", " def concat_batch(self, orig_vals, new_vals):\n", " if self.prompt_place == 'pre':\n", " return torch.cat([new_vals, orig_vals], axis=1)\n", " else: # prompt_place == post\n", " return torch.cat([orig_vals, new_vals], axis=1)\n", " \n", " def 
post_tokenizer(self, **kwargs):\n",
    "        # Extend each tokenizer output by n_tokens placeholder positions so\n",
    "        # the inputs stay aligned with the prompt injected at embedding time.\n",
    "        for special_key, pad_val in self.post_tokenizer_map.items():\n",
    "            if special_key in kwargs:\n",
    "                orig_tokens = kwargs[special_key]\n",
    "                batch_size = kwargs[special_key].size(0)\n",
    "                new_vals = torch.full(\n",
    "                    size=(batch_size, self.n_tokens),\n",
    "                    fill_value=pad_val,\n",
    "                    dtype=orig_tokens.dtype,\n",
    "                    device=orig_tokens.device\n",
    "                )\n",
    "                kwargs[special_key] = self.concat_batch(orig_tokens, new_vals)\n",
    "        return kwargs\n",
    "\n",
    "class TransformerInjector(nn.Module):\n",
    "    def __init__(self, module):\n",
    "        super().__init__()\n",
    "        self.original_module = module\n",
    "        self.add_prompt = SoftEmbedding(\n",
    "            emb_dim=module.output.dense.out_features,\n",
    "            n_tokens=NTOKENS,\n",
    "            prompt_place=PROMPT_PLACE,\n",
    "            mode='add'\n",
    "        )\n",
    "\n",
    "    def forward(self, hidden_states, *args, **kwargs):\n",
    "        hidden_states = self.add_prompt(hidden_states)\n",
    "        return self.original_module(hidden_states, *args, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def mutate_list(cls, module_list):\n",
    "        # Wrap every encoder block in place.\n",
    "        for idx, module in enumerate(module_list):\n",
    "            module_list[idx] = cls(module)\n",
    "        return module_list\n",
    "\n",
    "class NewEmbeddingLayer(nn.Module):\n",
    "    def __init__(self, emb_layer: nn.Embedding):\n",
    "        super().__init__()\n",
    "        self.emb_layer = emb_layer\n",
    "        self.soft_prompt = SoftEmbedding(\n",
    "            emb_dim=emb_layer.weight.size(1),\n",
    "            n_tokens=NTOKENS,\n",
    "            prompt_place=PROMPT_PLACE\n",
    "        )\n",
    "\n",
    "    def forward(self, tokens):\n",
    "        out = self.emb_layer(tokens)\n",
    "        out = self.soft_prompt(out)\n",
    "        return out\n",
    "\n",
    "    def get_weights(self):\n",
    "        return self.soft_prompt.get_weights()\n",
    "\n",
    "    def set_weights(self, new_weights):\n",
    "        self.soft_prompt.set_weights(new_weights)\n",
    "\n",
    "    @classmethod\n",
    "    def mutate(cls, model):\n",
    "        emb_layer = model.get_input_embeddings()\n",
    "        new_emb_layer = cls(emb_layer)\n",
    "        model.set_input_embeddings(new_emb_layer)\n",
    "\n",
    "        orig_forward = model.forward\n",
    "\n",
    "        # Monkey-patch forward so tokenizer outputs are padded before the\n",
    "        # embedding layer injects the prompt.\n",
    "        def new_forward(**kwargs):\n",
    "            new_kwargs = new_emb_layer.soft_prompt.post_tokenizer(**kwargs)\n",
    "            return orig_forward(**new_kwargs)\n",
    "\n",
    "        model.forward = new_forward\n",
    "        return new_emb_layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "79bf6687-5a88-4181-88dc-740d11dd89ac",
   "metadata": { "tags": [] },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertForSequenceClassification, BertTokenizerFast\n",
    "\n",
    "model = BertForSequenceClassification.from_pretrained(MODEL_NAME)\n",
    "tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)\n",
    "\n",
    "peft_module = NewEmbeddingLayer.mutate(model)\n",
    "peft_bert_layers = TransformerInjector.mutate_list(model.bert.encoder.layer)\n",
    "\n",
    "model.to(DEVICE);"
   ]
  },
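  {
   "cell_type": "markdown",
   "id": "8d1f0a2b-7c6e-4b5a-9e3d-0c1b2a3f4e5d",
   "metadata": {},
   "source": [
    "The mutation above adds the prompt modules but does not freeze anything, so the whole backbone would still receive gradients. A minimal sketch of the usual prompt-tuning setup, keyed off the `sadcl` namespace used for the new parameters (keeping the classifier head trainable is an assumption here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e5-6f70-4a81-92b3-c4d5e6f70812",
   "metadata": { "tags": [] },
   "outputs": [],
   "source": [
    "# Keep only the soft-prompt parameters (named under the 'sadcl' namespace)\n",
    "# and the classification head trainable; freeze the rest of the backbone.\n",
    "for name, param in model.named_parameters():\n",
    "    param.requires_grad = (NAMESPACE in name) or name.startswith('classifier.')\n",
    "\n",
    "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in model.parameters())\n",
    "print(f'trainable parameters: {trainable} / {total}')"
   ]
  },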
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c0b0a48e-0b0b-43de-ae78-d2521cfee69e",
   "metadata": { "tags": [] },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2546, -0.0352, -0.4110,  ...,  0.0189,  0.4121,  0.2206],\n",
       "        [ 0.0670,  0.0600,  0.4493,  ..., -0.4346,  0.4130, -0.3507],\n",
       "        [ 0.0827,  0.3569,  0.0943,  ..., -0.3451, -0.1879,  0.0831],\n",
       "        ...,\n",
       "        [-0.0489, -0.2570, -0.3328,  ..., -0.4109,  0.0884, -0.0290],\n",
       "        [-0.2705, -0.3854,  0.4559,  ..., -0.0480, -0.4039,  0.4245],\n",
       "        [-0.1941,  0.2237,  0.3494,  ..., -0.1199, -0.3030, -0.1530]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old_w = peft_module.get_weights()\n",
    "old_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "d3753569-c95f-4f8e-99ec-e6f990ec55a8",
   "metadata": { "tags": [] },
   "outputs": [],
   "source": [
    "# tokens = tokenizer(\"Hi bye\", return_tensors='pt').to(DEVICE)\n",
    "\n",
    "# model.eval()\n",
    "# with torch.no_grad():\n",
    "#     out = model(**tokens)\n",
    "# out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "23b6f5b1-bbb7-43b9-b5a9-e62d313f4244",
   "metadata": { "tags": [] },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2bc8f1df0934619941ae8e37e2be807",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": { "tags": [] },
   "outputs": [
    {
     "data": {
      "text/html": [
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyF1-score-1F1-score-ma
1No log0.6558670.6912750.8174600.408730
20.5778000.6397710.7631830.8489300.650629
30.5778000.5078090.7660590.8491970.663915
40.5287000.5238200.7708530.8521950.671300
50.5287000.4802760.7948230.8617570.731994
60.4998000.5060560.7766060.8559060.679552
70.4998000.4757240.7957810.8631980.730276
80.4829000.4949710.7909880.8606140.721495
90.4829000.4787710.7861940.8585920.710239
100.4657000.5024140.7804410.8589030.682151
110.4657000.4981160.7948230.8665840.711300
120.4618000.5371170.7804410.8602810.673988
130.4618000.4658510.8024930.8682860.736825
140.4450000.4873900.7957810.8659530.718691
150.4356000.4404230.8015340.8644400.747068
160.4356000.4838970.8034520.8693440.736413
170.4235000.4617270.8063280.8721520.736471
180.4235000.4910340.7948230.8659150.714590
190.4104000.4514040.8063280.8684900.750608
200.4104000.4398620.8082450.8726110.742507
210.4081000.4432580.7948230.8659150.714590
220.4081000.4507560.8053690.8714380.735522
230.4046000.4830010.7976990.8673790.720558
240.4046000.4810940.7948230.8664170.712134
250.3972000.5097310.7986580.8679250.722269
260.3972000.4684570.8139980.8728700.763221
270.3881000.4506460.8024930.8697850.730527
280.3799000.5189120.8005750.8688520.726426
290.3799000.4749390.8034520.8709880.729257
300.3758000.4681940.7996160.8686360.723207
310.3758000.4471160.8101630.8724230.750818
320.3707000.5370910.8024930.8701130.729057
330.3707000.4752610.8072870.8710710.744834
340.3679000.4872070.8024930.8706030.726799
350.3679000.4377850.8063280.8713380.739932
360.3588000.5088990.8082450.8727740.741834
370.3588000.5524090.8005750.8693470.724147
380.3557000.4966870.8024930.8715710.722093
390.3557000.5048410.8168740.8755700.764464
400.3455000.4832540.7909880.8659290.696008
410.3455000.5125040.7967400.8683230.711472
420.3517000.4971100.8005750.8704860.718576
430.3399000.4712160.7986580.8677580.723036
440.3399000.5314870.8053690.8707830.738304
450.3413000.5408430.8072870.8707400.746104
460.3413000.4768090.8034520.8698410.734334
470.3374000.4794550.8197510.8771240.769497
480.3374000.4460180.8159160.8751630.762399
490.3342000.5489590.8130390.8750800.751826
500.3342000.5003710.7976990.8673790.720558
510.3317000.5031510.8082450.8711340.748301
520.3317000.5562160.7986580.8685860.719129
530.3267000.4788570.8168740.8752450.765550
540.3267000.5086740.8063280.8703470.743885
550.3267000.5102410.8072870.8707400.746104
560.3278000.5104370.8034520.8703350.732197
570.3278000.5165600.8044100.8712120.732419
580.3206000.4821750.8101630.8722580.751428
590.3206000.5345510.8092040.8705270.754025
600.3116000.5295130.8044100.8690630.741350
610.3116000.5290380.8120810.8728920.756299
620.3179000.5518850.7976990.8660320.726558
630.3179000.5004190.8082450.8709680.748917
640.3151000.4660860.8092040.8718610.749251
650.3151000.4927290.8111220.8728210.752984
660.3063000.4632670.8139980.8743520.758209
670.3063000.5685360.8111220.8728210.752984
680.3085000.5390110.8034520.8681670.741052
690.3085000.5261970.8082450.8713000.747680
700.3049000.5060410.8111220.8726570.753583
710.3027000.5819290.7986580.8667510.727493
720.3027000.5164970.8101630.8722580.751428
730.3080000.5071280.8072870.8702390.747969
740.3080000.5209960.8034520.8681670.741052
750.3049000.5175480.8063280.8696770.746406
760.3049000.5038170.8044100.8687260.742634
770.2981000.5088800.8092040.8715300.750476
780.2981000.5056060.8082450.8708010.749527
790.3049000.5265730.8024930.8676090.739465
800.3049000.5235810.8044100.8687260.742634

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", "/home/mohalisad/anaconda3/envs/deep/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=21440, training_loss=0.37007259682043275, metrics={'train_runtime': 421.6464, 'train_samples_per_second': 1622.402, 'train_steps_per_second': 50.848, 'total_flos': 8141300538608160.0, 'train_loss': 0.37007259682043275, 'epoch': 80.0})" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n", "from sklearn.metrics import classification_report\n", "\n", "\n", "def compute_metrics(pred):\n", " true_labels = pred.label_ids.ravel()\n", " pred_labels = pred.predictions.argmax(-1).ravel()\n", " report = classification_report(true_labels, pred_labels, output_dict=True)\n", " return {\n", " 'accuracy': report['accuracy'],\n", " 'f1-score-1': report['1']['f1-score'],\n", " 'f1-score-ma': report['macro avg']['f1-score']\n", " }\n", "\n", "\n", "# def train_model(input_model, task_name, train_dataset, eval_dataset, col_fn):\n", "# training_args = TrainingArguments(\n", "# evaluation_strategy=\"epoch\",\n", "# save_strategy=\"epoch\",\n", "# # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n", "# remove_unused_columns=False,\n", "# **config.hf_trainer_params.to_dict()\n", "# )\n", "\n", "# trainer = Trainer(\n", "# model=input_model,\n", "# args=training_args,\n", "# train_dataset=train_dataset,\n", "# eval_dataset=eval_dataset,\n", "# data_collator=col_fn,\n", "# compute_metrics=compute_metrics\n", "# )\n", "# trainer.train()\n", "\n", "col_fn = DataCollatorWithPadding(\n", " tokenizer, return_tensors='pt', padding='longest'\n", ")\n", "\n", "loader_out = autoload.get_and_map(tokenizer, \"glue:cola\")\n", "num_labels = len(loader_out['output']['range'])\n", "\n", "training_args = TrainingArguments(\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n", " remove_unused_columns=False,\n", " **config.hf_trainer_params.to_dict()\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=loader_out['train'],\n", " eval_dataset=loader_out['valid'],\n", " data_collator=col_fn,\n", " compute_metrics=compute_metrics\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", 
"execution_count": 72, "id": "00bc804c-6133-4ccc-b6c4-697d859f94cf", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", ")" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.bert.encoder.layer[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "ad825227-f073-4d58-9b06-15dc882e2f74", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenizer(\"hi bye\", return_tensors='pt')" ] }, { "cell_type": "code", "execution_count": null, "id": "a015bcd2-6768-4289-a189-74ae9c7b08de", "metadata": { "tags": [] }, "outputs": [], "source": [ "for param_name, weights in model.named_parameters():\n", " if weights.requires_grad:\n", " print(param_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "a009912e-7366-4deb-aae0-e78331b7e160", "metadata": { "tags": [] }, "outputs": [], "source": [ "new_w = peft_module.get_weights()\n", "new_w\n", "(new_w - old_w).sum()" ] }, { "cell_type": "code", "execution_count": null, "id": "98004c44-7a88-4f58-b956-d2ce6c3c56a8", "metadata": { "tags": [] }, "outputs": [], "source": [ "new_w" ] }, { "cell_type": "code", "execution_count": null, "id": "4301e6f1-8212-4a67-b93d-589717098b15", "metadata": { "tags": [] }, "outputs": [], "source": [ "new_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)\n", "new_model.to(DEVICE)\n", "nm_sd = new_model.state_dict()" ] }, { "cell_type": "code", "execution_count": null, "id": "4c001d2d-e862-4926-8a31-e218489654e1", "metadata": { "tags": [] }, "outputs": [], "source": [ "old_sd = model.state_dict()" ] }, { "cell_type": "code", "execution_count": null, "id": "db972023-2457-40fa-8b24-c9df81e9cb51", "metadata": { "tags": [] }, "outputs": [], "source": [ "for key, val in old_sd.items():\n", " if key not in nm_sd:\n", " print(key)" ] }, { "cell_type": "code", "execution_count": null, "id": "0f3ba7b5-441e-411d-a848-f907ec27496e", "metadata": { "tags": [] }, "outputs": [], "source": [ "model.bert.embeddings.word_embeddings.sadcl_soft_prompt.learned_embedding.requires_grad" ] }, { "cell_type": "code", "execution_count": null, "id": "10c1e258-df79-46ce-bc99-cbd024948379", "metadata": { "tags": [] }, "outputs": [], "source": [ "(old_sd['bert.embeddings.word_embeddings.emb_layer.weight'] - nm_sd['bert.embeddings.word_embeddings.weight']).sum()" ] }, { "cell_type": "code", "execution_count": null, "id": "8c8ba1bf-dc9d-4864-8aa6-57140c6eb116", "metadata": { "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": 
"ed628433-610c-4a45-9780-42927e84198e", "metadata": { "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "7247d9a1-5b2a-4379-8983-22aaf2d20b7b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "61c894cd-d601-43e0-b7b1-c3c67b5cd1f7", "metadata": { "tags": [] }, "outputs": [], "source": [ "a = \"hi bye\"\n", "model.eval()\n", "with torch.no_grad():\n", " tokens = tokenizer(a, return_tensors='pt').to(DEVICE)\n", " o = model.bert.embeddings.word_embeddings(peft_module.post_tokenizer(input_ids=tokens['input_ids'])['input_ids'])" ] }, { "cell_type": "code", "execution_count": null, "id": "e1dad24f", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 10, 768])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "632a6a0b-e2e3-4d3b-9d12-021f08e39b6e", "metadata": { "tags": [] }, "outputs": [], "source": [ "o" ] }, { "cell_type": "code", "execution_count": null, "id": "811ee37b-61fb-4103-8900-64a5ffd2e7c7", "metadata": { "tags": [] }, "outputs": [], "source": [ "len(out.hidden_states)" ] }, { "cell_type": "code", "execution_count": null, "id": "3d768fa9-2d1d-4598-8fb1-3681a8f53897", "metadata": { "tags": [] }, "outputs": [], "source": [ "setattr(c, 'a', 3)" ] }, { "cell_type": "code", "execution_count": null, "id": "3293b793-32b7-4cee-986b-1286604e361b", "metadata": { "tags": [] }, "outputs": [], "source": [ "class c:\n", " def a(self, i):\n", " print(i + 1)\n", " \n", " def b(self):\n", " c = self.a\n", " self.a = lambda i: c(i + 1)\n", "\n", "o = c()\n", "o.a(3)\n", "o.b()\n", "o.a(3)" ] }, { "cell_type": "code", "execution_count": null, "id": "41adc585-a890-4971-95b9-df994adaada2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:deep]", "language": "python", "name": "conda-env-deep-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 }