{ "cells": [
 { "cell_type": "markdown", "id": "19c25879-e13a-4f5e-8b5a-67d6bb77c3f6", "metadata": { "tags": [] }, "source": [ "# Intro" ] },
 { "cell_type": "code", "execution_count": 1, "id": "ca485005-54c1-4126-8c1e-53ca633b7f26", "metadata": { "tags": [] },
  "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python version is: 3.10.11\n", "Torch version is: 1.13.1+cu117\n", "Nvidia device is: NVIDIA GeForce RTX 4090\n", "Transformers version is: 4.32.1\n", "Adapterhub not found!!!\n" ] } ],
  "source": [
   "from transformers import GPT2TokenizerFast, GPT2Model, DataCollatorWithPadding\n",
   "from transformers.modeling_outputs import SequenceClassifierOutputWithPast\n",
   "import torch\n",
   "import torch.nn as nn\n",
   "from utils import print_system_info\n",
   "from typing import Literal, Optional, List, Dict, Callable\n",
   "from types import SimpleNamespace\n",
   "from dataclasses import dataclass\n",
   "\n",
   "print_system_info()"
  ] },
 { "cell_type": "code", "execution_count": 2, "id": "931ebd25-5e5a-4fdf-b2db-92d4ccf7f88e", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
   "MODEL_NAME = 'gpt2'\n",
   "NAMESPACE = 'sadcl'  # parameters whose names contain this stay trainable (GPT2LLL.auto_freeze)\n",
   "\n",
   "INIT_TEXT = \"sentiment or value or relation of the previous text is\"  # seeds the soft prompts\n",
   "N_LAST_LAYERS = 10  # number of final GPT-2 blocks that receive a soft prompt\n",
   "\n",
   "N_TOKENS = 5  # soft-prompt length per injected layer"
  ] },
 { "cell_type": "markdown", "id": "e879d47e-0b67-452a-91c2-f36383efbed8", "metadata": {}, "source": [ "# Class" ] },
 { "cell_type": "code", "execution_count": 3, "id": "7c61fcde-f9e7-4d30-b989-511010e6298b", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "def initialize_embedding(\n",
   "    emb_dim: int,\n",
   "    n_tokens: int,\n",
   "    random_range: float,\n",
   "    initialize_from: Optional[torch.Tensor]\n",
   ") -> torch.Tensor:\n",
   "    \"\"\"Return the initial (n_tokens, emb_dim) soft-prompt matrix.\n",
   "\n",
   "    If `initialize_from` is None, values are drawn uniformly from\n",
   "    [-random_range, random_range]; otherwise the first `n_tokens` rows of\n",
   "    `initialize_from` are copied and detached so no autograd graph is kept.\n",
   "    \"\"\"\n",
   "    if initialize_from is None:\n",
   "        return torch.empty(n_tokens, emb_dim, dtype=torch.float32).uniform_(-random_range, random_range)\n",
   "\n",
   "    assert initialize_from.shape[0] >= n_tokens, 'initialize_from has fewer rows than n_tokens'\n",
   "    assert initialize_from.shape[1] == emb_dim, 'initialize_from width differs from emb_dim'\n",
   "    return initialize_from[:n_tokens, :].detach().clone()\n",
   "\n",
   "\n",
   "class SoftEmbedding(nn.Module):\n",
   "    \"\"\"Learned prompt vectors prepended to the input of one GPT-2 block.\n",
   "\n",
   "    The sequence length is kept constant. Inputs are expected to end with\n",
   "    `n_tokens` padding slots (see `GPT2LLL.generate_tokenizer_map`): the first\n",
   "    injected layer drops those trailing slots, later layers drop the leading\n",
   "    `n_tokens` positions (the previous layer's prompt outputs).\n",
   "    The parameter name contains NAMESPACE so `GPT2LLL.auto_freeze` keeps it trainable.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(\n",
   "        self,\n",
   "        emb_dim: int,\n",
   "        n_tokens: int,\n",
   "        first_layer_flag: bool = False,\n",
   "        random_range: float = 0.1,\n",
   "        initialize_from: Optional[torch.Tensor] = None\n",
   "    ):\n",
   "        super().__init__()\n",
   "\n",
   "        self.emb_dim = emb_dim\n",
   "        self.n_tokens = n_tokens\n",
   "        self.first_layer_flag = first_layer_flag\n",
   "\n",
   "        self.sadcl_learned_embedding = nn.parameter.Parameter(\n",
   "            initialize_embedding(emb_dim, n_tokens, random_range, initialize_from)\n",
   "        )\n",
   "        assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)\n",
   "\n",
   "    def forward(self, input_embedding, attention_mask, sequnce_lengths):\n",
   "        \"\"\"Prepend the prompt to `input_embedding` and shift `attention_mask` to match.\n",
   "\n",
   "        input_embedding: (batch, seq, emb_dim) hidden states entering the block.\n",
   "        attention_mask: mask of shape (batch, 1, 1, seq), or None.\n",
   "        sequnce_lengths: unused; kept for interface compatibility.\n",
   "        Returns (output_embedding, attention_mask) with the sequence length unchanged.\n",
   "        \"\"\"\n",
   "        batch_size = input_embedding.size(0)\n",
   "        prompt = self.sadcl_learned_embedding.repeat(batch_size, 1, 1)  # (batch, n_tokens, emb_dim)\n",
   "\n",
   "        if self.first_layer_flag:\n",
   "            kept = input_embedding[:, :-self.n_tokens]  # drop the trailing padding slots\n",
   "        else:\n",
   "            kept = input_embedding[:, self.n_tokens:]  # drop the previous layer's prompt\n",
   "        output_embedding = torch.cat([prompt, kept], dim=1)\n",
   "\n",
   "        if attention_mask is not None:\n",
   "            # Zeros for the prompt positions (presumably 'attend' in GPT-2's additive mask -\n",
   "            # TODO confirm). The dtype follows the mask so half-precision runs stay consistent.\n",
   "            prompt_mask = torch.zeros(\n",
   "                (batch_size, 1, 1, self.n_tokens),\n",
   "                dtype=attention_mask.dtype,\n",
   "                device=attention_mask.device\n",
   "            )\n",
   "            attention_mask = torch.cat([prompt_mask, attention_mask[:, :, :, :-self.n_tokens]], dim=-1)\n",
   "        return output_embedding, attention_mask\n",
   "\n",
   "    def get_weights(self):\n",
   "        \"\"\"Return a detached copy of the learned prompt matrix.\"\"\"\n",
   "        return self.sadcl_learned_embedding.detach().clone()\n",
   "\n",
   "\n",
   "class GPT2ModuleWrapper(nn.Module):\n",
   "    \"\"\"Apply a SoftEmbedding to a GPT-2 block's input, then run the block itself.\"\"\"\n",
   "\n",
   "    def __init__(\n",
   "        self,\n",
   "        module,\n",
   "        emb_dim: int,\n",
   "        n_tokens: int,\n",
   "        get_sequnce_lengths: Callable[[], List[int]],\n",
   "        first_layer_flag: bool,\n",
   "        initialize_from: Optional[torch.Tensor] = None\n",
   "    ):\n",
   "        super().__init__()\n",
   "        self.original_module = module\n",
   "        self.soft_prompt = SoftEmbedding(\n",
   "            emb_dim=emb_dim,\n",
   "            n_tokens=n_tokens,\n",
   "            first_layer_flag=first_layer_flag,\n",
   "            initialize_from=initialize_from\n",
   "        )\n",
   "        self.get_sequnce_lengths = get_sequnce_lengths\n",
   "\n",
   "    def forward(self, hidden_states, *args, **kwargs):\n",
   "        # attention_mask is expected as a keyword argument; a missing one is treated as None.\n",
   "        output_embedding, attention_mask = self.soft_prompt(\n",
   "            hidden_states,\n",
   "            kwargs.get('attention_mask'),\n",
   "            self.get_sequnce_lengths()\n",
   "        )\n",
   "        kwargs['attention_mask'] = attention_mask\n",
   "        return self.original_module(output_embedding, *args, **kwargs)\n",
   "\n",
   "\n",
   "class GPT2Injector:\n",
   "    \"\"\"Installs one PEFT's soft prompts into a GPT2Model and records sequence lengths.\n",
   "\n",
   "    The wrappers (and therefore the learned prompts) are built on the first\n",
   "    `mutate` and reused afterwards, so deactivating and re-activating a PEFT\n",
   "    keeps its trained weights.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self):\n",
   "        self.sequnce_lengths = None\n",
   "        self.wrappers = None\n",
   "\n",
   "    def get_sequnce_lengths(self):\n",
   "        \"\"\"Last real-token index per row, from the most recent forward call.\"\"\"\n",
   "        return self.sequnce_lengths\n",
   "\n",
   "    def _mutate_model_forward(self, model):\n",
   "        \"\"\"Patch `model.forward` so every call first records the sequence lengths.\"\"\"\n",
   "        old_forward = model.forward\n",
   "        pad_token_id = model.config.pad_token_id\n",
   "\n",
   "        def new_forward(*args, **kwargs):\n",
   "            input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else args[0]\n",
   "            # First pad position minus one = index of the last real token. A row with no\n",
   "            # pad would give -1; generate_tokenizer_map always appends n_tokens pads.\n",
   "            self.sequnce_lengths = (\n",
   "                torch.eq(input_ids, pad_token_id).long().argmax(-1) - 1\n",
   "            ).detach().cpu().tolist()\n",
   "            return old_forward(*args, **kwargs)\n",
   "\n",
   "        model.forward = new_forward\n",
   "\n",
   "    def _reverse_mutate_model_forward(self, model):\n",
   "        \"\"\"Drop the patched instance attribute so the class's own forward is used again.\"\"\"\n",
   "        model.__dict__.pop('forward', None)\n",
   "\n",
   "    def mutate(self, model, n_layers, n_tokens, init_prompts):\n",
   "        \"\"\"Wrap the last `n_layers` blocks of `model.h` and return the wrapped blocks.\n",
   "\n",
   "        On first use, block idx's prompt is seeded from init_prompts[idx][0].\n",
   "        Raises ValueError if n_layers is out of range or differs from the first call.\n",
   "        \"\"\"\n",
   "        module_list = model.h  # the model passed in, not a notebook global\n",
   "        if not 1 <= n_layers <= len(module_list):\n",
   "            raise ValueError(f'n_layers must be between 1 and {len(module_list)}, got {n_layers}')\n",
   "        start = len(module_list) - n_layers\n",
   "\n",
   "        if self.wrappers is None:\n",
   "            self.wrappers = [\n",
   "                GPT2ModuleWrapper(\n",
   "                    module=module_list[idx],\n",
   "                    emb_dim=model.embed_dim,\n",
   "                    n_tokens=n_tokens,\n",
   "                    get_sequnce_lengths=self.get_sequnce_lengths,\n",
   "                    first_layer_flag=(idx == start),\n",
   "                    initialize_from=init_prompts[idx][0]\n",
   "                )\n",
   "                for idx in range(start, len(module_list))\n",
   "            ]\n",
   "        elif len(self.wrappers) != n_layers:\n",
   "            raise ValueError('n_layers differs from the value used when this PEFT was first activated')\n",
   "\n",
   "        self._mutate_model_forward(model)\n",
   "        for idx, wrapper in zip(range(start, len(module_list)), self.wrappers):\n",
   "            module_list[idx] = wrapper\n",
   "        return module_list[start:]\n",
   "\n",
   "    def reverse_mutate(self, model):\n",
   "        \"\"\"Undo `mutate`: restore the original forward and unwrap every wrapped block.\"\"\"\n",
   "        self._reverse_mutate_model_forward(model)\n",
   "        module_list = model.h\n",
   "        for idx in range(len(module_list)):\n",
   "            if isinstance(module_list[idx], GPT2ModuleWrapper):\n",
   "                module_list[idx] = module_list[idx].original_module\n"
  ] },
 { "cell_type": "code", "execution_count": 4, "id": "f215af71-8f06-4466-a1cc-bf27b1193627", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "class MixHeadModel(nn.Module):\n",
   "    \"\"\"Backbone plus task head; the `labels` keyword goes to the head only.\"\"\"\n",
   "\n",
   "    def __init__(self, model, head):\n",
   "        super().__init__()\n",
   "        self.model = model\n",
   "        self.sadcl_head = head\n",
   "\n",
   "    def forward(self, *args, **kwargs):\n",
   "        labels = kwargs.pop('labels', None)\n",
   "        transformer_outputs = self.model(*args, **kwargs)\n",
   "        return self.sadcl_head(transformer_outputs=transformer_outputs, labels=labels)"
  ] },
 { "cell_type": "code", "execution_count": 5, "id": "cea800ea-d538-4aab-8aca-41feaba49b7d", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "class GPT2ClassificationHead(nn.Module):\n",
   "    \"\"\"Linear classifier over the hidden state of each row's last real token.\n",
   "\n",
   "    The prepended prompt shifts real tokens right by `n_tokens`, so the index\n",
   "    from `get_sequnce_lengths()` is offset by `n_tokens` before gathering.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(\n",
   "        self,\n",
   "        emb_dim: int,\n",
   "        n_labels: int,\n",
   "        get_sequnce_lengths: Callable[[], List[int]],\n",
   "        n_tokens: int,\n",
   "        init_range: float,\n",
   "        bias=True\n",
   "    ):\n",
   "        super().__init__()\n",
   "\n",
   "        self.get_sequnce_lengths = get_sequnce_lengths\n",
   "        self.n_labels = n_labels\n",
   "        self.n_tokens = n_tokens\n",
   "        self.loss_func = nn.CrossEntropyLoss()\n",
   "\n",
   "        self.score = nn.Linear(emb_dim, n_labels, bias)  # Bias is false in huggingface implementation\n",
   "\n",
   "        self._init_weights(init_range)\n",
   "\n",
   "    def _init_weights(self, init_range):\n",
   "        self.score.weight.data.normal_(mean=0.0, std=init_range)\n",
   "        if self.score.bias is not None:\n",
   "            self.score.bias.data.zero_()\n",
   "\n",
   "    def forward(self, transformer_outputs, labels=None):\n",
   "        last_text_token_per_batch = self.get_sequnce_lengths()\n",
   "        # Position of the last real token after the n_tokens-long prompt is prepended.\n",
   "        shifted_last_token = [seqlen + self.n_tokens for seqlen in last_text_token_per_batch]\n",
   "        last_hidden_state = transformer_outputs.last_hidden_state\n",
   "        batch_size = last_hidden_state.size(0)\n",
   "\n",
   "        pooled = last_hidden_state[range(batch_size), shifted_last_token]\n",
   "        logits = self.score(pooled)\n",
   "\n",
   "        loss = None\n",
   "        if labels is not None:\n",
   "            loss = self.loss_func(logits.view(-1, self.n_labels), labels.view(-1))\n",
   "\n",
   "        return SequenceClassifierOutputWithPast(\n",
   "            loss=loss,\n",
   "            logits=logits,\n",
   "            past_key_values=transformer_outputs.past_key_values,\n",
   "            hidden_states=transformer_outputs.hidden_states,\n",
   "            attentions=transformer_outputs.attentions,\n",
   "        )"
  ] },
 { "cell_type": "code", "execution_count": 6, "id": "577784eb-ab61-424d-a633-7b030d6d06d3", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "@dataclass\n",
   "class PEFTConfig:\n",
   "    \"\"\"Description of one PEFT; GPT2LLL.add_peft currently supports kind='classification' only.\"\"\"\n",
   "    name: str\n",
   "    kind: Literal['regression', 'classification', 'generation']\n",
   "    n_labels: Optional[int] = None  # only for classification\n",
   "\n",
   "    @classmethod\n",
   "    def classification(cls, name: str, n_labels: int):\n",
   "        return cls(name=name, n_labels=n_labels, kind='classification')\n",
   "\n",
   "\n",
   "class GPT2LLL:\n",
   "    \"\"\"Frozen GPT-2 backbone plus a registry of named soft-prompt PEFTs.\n",
   "\n",
   "    At most one PEFT is active at a time: its prompts are injected into the\n",
   "    last `n_last_layers` blocks and paired with its head in `current_mix_model`.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(\n",
   "        self,\n",
   "        n_tokens=N_TOKENS,\n",
   "        n_last_layers=N_LAST_LAYERS,\n",
   "        model_name=MODEL_NAME,\n",
   "        device=DEVICE,\n",
   "        init_text=INIT_TEXT\n",
   "    ):\n",
   "        self.n_tokens = n_tokens\n",
   "        self.n_last_layers = n_last_layers\n",
   "        self.model_name = model_name\n",
   "        self.device = device\n",
   "\n",
   "        self.pefts = {}\n",
   "\n",
   "        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name, add_prefix_space=True)\n",
   "        # EOS doubles as the pad token; GPT2Injector locates sequence ends by this id.\n",
   "        self.tokenizer.pad_token = self.tokenizer.eos_token\n",
   "\n",
   "        self.model = GPT2Model.from_pretrained(model_name, pad_token_id=self.tokenizer.pad_token_id)\n",
   "        self.model.to(device)\n",
   "\n",
   "        # Per-layer hidden states of init_text; GPT2Injector.mutate seeds each layer's prompt from them.\n",
   "        init_tokens = self.tokenizer(init_text, return_tensors='pt').to(device)\n",
   "        with torch.no_grad():\n",
   "            self.init_prompts = self.model(**init_tokens, output_hidden_states=True).hidden_states\n",
   "\n",
   "        self.current_peft_name = None\n",
   "        self.current_mix_model = None\n",
   "\n",
   "    @property\n",
   "    def current_peft(self):\n",
   "        \"\"\"Namespace (head, injector) of the active PEFT, or None.\"\"\"\n",
   "        if self.current_peft_name is None:\n",
   "            return None\n",
   "        return self.pefts[self.current_peft_name]\n",
   "\n",
   "    def generate_tokenizer_map(self):\n",
   "        \"\"\"Return a function that tokenizes `rows` and appends `n_tokens` pad slots per row.\n",
   "\n",
   "        The appended pad ids (attention 0) reserve the positions that the first\n",
   "        soft-prompt layer drops, so the sequence length stays unchanged.\n",
   "        \"\"\"\n",
   "        n_tokens = self.n_tokens\n",
   "        tokenizer = self.tokenizer\n",
   "\n",
   "        def return_function(rows):\n",
   "            outputs_dict = tokenizer(rows)\n",
   "            for row in outputs_dict['input_ids']:\n",
   "                row.extend([tokenizer.pad_token_id] * n_tokens)\n",
   "            for row in outputs_dict['attention_mask']:\n",
   "                row.extend([0] * n_tokens)\n",
   "            return outputs_dict\n",
   "\n",
   "        return return_function\n",
   "\n",
   "    def deactivate_peft(self):\n",
   "        \"\"\"Unwrap the active PEFT, if any; its prompts stay stored in its injector.\"\"\"\n",
   "        if self.current_peft_name is None:\n",
   "            return\n",
   "        self.current_peft.injector.reverse_mutate(self.model)\n",
   "        self.current_peft_name = None\n",
   "        self.current_mix_model = None\n",
   "\n",
   "    def activate_peft(self, name):\n",
   "        \"\"\"Inject PEFT `name` into the backbone and build `current_mix_model`.\"\"\"\n",
   "        if name not in self.pefts:\n",
   "            raise KeyError(f'unknown PEFT {name!r}; call add_peft first')\n",
   "        # Unwrap any active PEFT first so GPT-2 blocks are never wrapped twice.\n",
   "        self.deactivate_peft()\n",
   "        self.current_peft_name = name\n",
   "\n",
   "        self.current_peft.injector.mutate(\n",
   "            model=self.model,\n",
   "            n_layers=self.n_last_layers,\n",
   "            n_tokens=self.n_tokens,\n",
   "            init_prompts=self.init_prompts\n",
   "        )\n",
   "        self.current_mix_model = MixHeadModel(\n",
   "            head=self.current_peft.head,\n",
   "            model=self.model\n",
   "        )\n",
   "\n",
   "    def auto_freeze(self):\n",
   "        \"\"\"Make only parameters whose name contains NAMESPACE trainable, and list them.\"\"\"\n",
   "        print('Unfreezed params are:')\n",
   "        for param_name, weights in self.current_mix_model.named_parameters():\n",
   "            is_peft_param = NAMESPACE in param_name\n",
   "            weights.requires_grad = is_peft_param\n",
   "            if is_peft_param:\n",
   "                print('- ' + param_name)\n",
   "\n",
   "    def add_peft(self, config: PEFTConfig):\n",
   "        \"\"\"Register a new PEFT (injector + head) under `config.name` without activating it.\"\"\"\n",
   "        assert config.name not in self.pefts\n",
   "        if config.kind != 'classification':\n",
   "            raise NotImplementedError(f'only classification PEFTs are supported, got {config.kind!r}')\n",
   "        injector = GPT2Injector()\n",
   "        head = GPT2ClassificationHead(\n",
   "            emb_dim=self.model.embed_dim,\n",
   "            n_labels=config.n_labels,\n",
   "            get_sequnce_lengths=injector.get_sequnce_lengths,\n",
   "            n_tokens=self.n_tokens,\n",
   "            init_range=self.model.config.initializer_range,\n",
   "            bias=False\n",
   "        )\n",
   "        head.to(self.device)\n",
   "        self.pefts[config.name] = SimpleNamespace(\n",
   "            head=head,\n",
   "            injector=injector\n",
   "        )"
  ] },
 { "cell_type": "markdown", "id": "8fcfcb04-6513-4321-917f-d13c2dba886e", "metadata": { "tags": [] }, "source": [ "# Train" ] },
 { "cell_type": "markdown", "id": "003fb992-fb75-4655-b60a-284ef0dcf4eb", "metadata": { "tags": [] }, "source": [ "## Prepare" ] },
 { "cell_type": "code", "execution_count": 7, "id": "cfe42619-bb12-430e-9359-5ee2d2e40bdc", "metadata": { "tags": [] },
  "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unfreezed params are:\n", "- model.h.2.soft_prompt.sadcl_learned_embedding\n", "- model.h.3.soft_prompt.sadcl_learned_embedding\n", "- model.h.4.soft_prompt.sadcl_learned_embedding\n", "- model.h.5.soft_prompt.sadcl_learned_embedding\n", "- model.h.6.soft_prompt.sadcl_learned_embedding\n", "- model.h.7.soft_prompt.sadcl_learned_embedding\n", "- model.h.8.soft_prompt.sadcl_learned_embedding\n", "- model.h.9.soft_prompt.sadcl_learned_embedding\n", "- model.h.10.soft_prompt.sadcl_learned_embedding\n", "- model.h.11.soft_prompt.sadcl_learned_embedding\n", "- sadcl_head.score.weight\n" ] } ],
  "source": [
   "peft_name = 'peft1'\n",
   "\n",
   "manager = GPT2LLL()\n",
   "manager.add_peft(PEFTConfig.classification(name=peft_name, n_labels=2))\n",
   "manager.activate_peft(peft_name)\n",
   "manager.auto_freeze()"
  ] },
 { "cell_type": "code", "execution_count": 8, "id": "072bf63b-de2f-4c05-a6a2-6fde3bb5aa6d", "metadata": { "tags": [] }, "outputs": [],
  "source": [ "from config import load_config\n", "config = load_config('config.yaml')" ] },
 { "cell_type": "code", "execution_count": 9, "id": "3b3827aa-a61c-4e34-9e67-86768bd8b446", "metadata": { "tags": [] },
  "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "ac4726d36f6241c59be6dbeee759fce2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", " \n", " \n", " [14609/42880 04:58 < 09:37, 48.96 it/s, Epoch 54.51/160]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyF1-score-1F1-score-ma
1No log0.6179170.6912750.8172530.411713
20.6182000.6202590.6912750.8172530.411713
30.6182000.6122360.6912750.8172530.411713
40.6165000.6137890.6912750.8172530.411713
50.6165000.6159890.6912750.8172530.411713
60.6128000.6149610.6912750.8172530.411713
70.6128000.6126220.6912750.8172530.411713
80.6113000.6136910.6912750.8172530.411713
90.6113000.6138890.6912750.8172530.411713
100.6094000.6161570.6912750.8172530.411713
110.6094000.6144040.6912750.8172530.411713
120.6097000.6140050.6912750.8172530.411713
130.6097000.6117220.6912750.8172530.411713
140.6071000.6098910.6922340.8177170.415012
150.6066000.6123380.6912750.8172530.411713
160.6066000.6148020.6912750.8172530.411713
170.6046000.6142890.6912750.8172530.411713
180.6046000.6106620.6922340.8177170.415012
190.6036000.6108670.6922340.8177170.415012
200.6036000.6154600.6912750.8172530.411713
210.6026000.6120300.6922340.8177170.415012
220.6026000.6112540.6922340.8177170.415012
230.6019000.6127360.6912750.8172530.411713
240.6019000.6138390.6912750.8172530.411713
250.6048000.6123030.6912750.8172530.411713
260.6048000.6121390.6912750.8172530.411713
270.6034000.6121060.6912750.8172530.411713
280.6023000.6145600.6912750.8172530.411713
290.6023000.6135810.6912750.8172530.411713
300.6028000.6159650.6912750.8172530.411713
310.6028000.6137150.6922340.8177170.415012
320.6014000.6135450.6922340.8177170.415012
330.6014000.6126310.6922340.8177170.415012
340.6014000.6118810.6922340.8177170.415012
350.6014000.6145030.6912750.8172530.411713
360.6007000.6109120.6922340.8177170.415012
370.6007000.6119160.6922340.8177170.415012
380.6008000.6114090.6922340.8177170.415012
390.6008000.6136520.6922340.8177170.415012
400.6006000.6124130.6922340.8177170.415012
410.6006000.6136730.6912750.8172530.411713
420.6004000.6111540.6922340.8177170.415012
430.6000000.6112160.6922340.8177170.415012
440.6000000.6101180.6922340.8177170.415012
450.6019000.6115730.6922340.8177170.415012
460.6019000.6135710.6922340.8179240.412058
470.5987000.6118530.6912750.8172530.411713
480.5987000.6112130.6912750.8172530.411713
490.5976000.6118550.6922340.8179240.412058
500.5976000.6118710.6922340.8179240.412058
510.6001000.6120860.6922340.8179240.412058
520.6001000.6106660.6922340.8179240.412058
530.5996000.6134060.6922340.8179240.412058
540.5996000.6170410.6922340.8179240.412058

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[15], line 43\u001b[0m\n\u001b[1;32m 34\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m=\u001b[39mmanager\u001b[38;5;241m.\u001b[39mcurrent_mix_model, \u001b[38;5;66;03m# manager.current_mix_model\u001b[39;00m\n\u001b[1;32m 36\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 40\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics\n\u001b[1;32m 41\u001b[0m )\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# trainer.label_names = ['labels']\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpast_key_values\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1555\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1553\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1555\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1556\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 
1557\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1558\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1559\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1560\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1837\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1834\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1836\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1837\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1839\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1840\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1841\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1842\u001b[0m 
\u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1843\u001b[0m ):\n\u001b[1;32m 1844\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1845\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:2693\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2691\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 2692\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2693\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2695\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/accelerate/accelerator.py:1923\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 1921\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1922\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1923\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in 
\u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ],
  "source": [
   "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
   "from sklearn.metrics import classification_report\n",
   "\n",
   "\n",
   "def compute_metrics(pred):\n",
   "    \"\"\"Accuracy, class-1 F1 and macro F1 for the Trainer's evaluation loop.\"\"\"\n",
   "    true_labels = pred.label_ids.ravel()\n",
   "    pred_labels = pred.predictions.argmax(-1).ravel()\n",
   "    # zero_division=0 yields the same scores as the default but without a warning\n",
   "    # every epoch when a class is never predicted (e.g. a majority-class model).\n",
   "    report = classification_report(true_labels, pred_labels, output_dict=True, zero_division=0)\n",
   "    return {\n",
   "        'accuracy': report['accuracy'],\n",
   "        'f1-score-1': report.get('1', {}).get('f1-score', 0.0),\n",
   "        'f1-score-ma': report['macro avg']['f1-score']\n",
   "    }\n",
   "\n",
   "col_fn = DataCollatorWithPadding(\n",
   "    manager.tokenizer, return_tensors='pt', padding='longest'\n",
   ")\n",
   "\n",
   "training_args = TrainingArguments(\n",
   "    output_dir='/disks/part4/trash',  # TODO: machine-specific absolute path\n",
   "    evaluation_strategy='epoch',\n",
   "    save_strategy='epoch',\n",
   "    # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
   "    remove_unused_columns=False,\n",
   "    label_names=['labels'],\n",
   "    num_train_epochs=160,\n",
   "    learning_rate=0.00001,\n",
   "    per_device_train_batch_size=32,\n",
   "    per_device_eval_batch_size=32,\n",
   ")\n",
   "\n",
   "trainer = Trainer(\n",
   "    model=manager.current_mix_model,\n",
   "    args=training_args,\n",
   "    train_dataset=dataset['train'],\n",
   "    eval_dataset=dataset['validation'],\n",
   "    data_collator=col_fn,\n",
   "    compute_metrics=compute_metrics\n",
   ")\n",
   "# past_key_values are not needed for metrics; leave them out of the gathered eval outputs.\n",
   "trainer.train(ignore_keys_for_eval=['past_key_values'])"
  ] },
 { "cell_type": "code", "execution_count": 14, "id": "60f9a209-1a05-4f6c-b450-773f788b93a0", "metadata": { "tags": [] },
  "outputs": [ { "data": { "text/plain": [ "0.6912751677852349" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ],
  "source": [
   "# Share of label 1 in validation = accuracy of always predicting class 1.\n",
   "import numpy as np\n",
   "np.mean(dataset['validation']['label'].numpy())"
  ] },
 { "cell_type": "markdown", "id": "8806b4a8-2f34-4e42-b0cd-d6b53d2b465b", "metadata": {}, "source": [ "# Debug" ] },
 { "cell_type": "code", "execution_count": null, "id": "559e0faa-f179-4b0a-b4b3-246260fc9056", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "inputs = col_fn(dataset['validation'][0:50]).to(DEVICE)\n",
   "outputs = manager.current_mix_model(**inputs)\n",
   "outputs.loss.backward()"
  ] },
 { "cell_type": "code", "execution_count": null, "id": 
"43cf9e69-b9a8-4059-a8e5-cab21635388c", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "# Total |grad| of every injected soft prompt (run the backward cell above first).\n",
   "for idx, block in enumerate(manager.current_mix_model.model.h):\n",
   "    if isinstance(block, GPT2ModuleWrapper):\n",
   "        grad = block.soft_prompt.sadcl_learned_embedding.grad\n",
   "        print(idx, None if grad is None else grad.abs().sum().item())"
  ] },
 { "cell_type": "code", "execution_count": null, "id": "ff6980bd-52c0-497b-b272-0be58044ee2f", "metadata": { "tags": [] }, "outputs": [], "source": [ "manager.current_mix_model.sadcl_head.score.weight.grad" ] },
 { "cell_type": "code", "execution_count": null, "id": "dd30fcbb-8385-4cf0-b2ce-26fa961e26c9", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "# Deliberate stop so Run All does not continue into the experiments below.\n",
   "raise Exception('Stopping here: the cells below are optional experiments.')"
  ] },
 { "cell_type": "code", "execution_count": null, "id": "0a788c20-5f98-447e-970a-4e96a4694976", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "# Baseline for comparison: stock GPT2ForSequenceClassification on the same data and collator.\n",
   "from transformers import GPT2ForSequenceClassification\n",
   "\n",
   "mtest = GPT2ForSequenceClassification.from_pretrained(MODEL_NAME, pad_token_id=manager.tokenizer.pad_token_id)\n",
   "mtest.to(DEVICE)\n",
   "\n",
   "# Separate names so the main `training_args` / `trainer` are not overwritten.\n",
   "baseline_training_args = TrainingArguments(\n",
   "    output_dir='/home/mohalisad/Developer/Thesis/cp3',  # TODO: machine-specific absolute path\n",
   "    evaluation_strategy='epoch',\n",
   "    save_strategy='epoch',\n",
   "    # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
   "    remove_unused_columns=False,\n",
   "    label_names=['labels'],\n",
   "    num_train_epochs=80,\n",
   "    learning_rate=0.00001,\n",
   "    per_device_train_batch_size=32,\n",
   "    per_device_eval_batch_size=32,\n",
   ")\n",
   "\n",
   "baseline_trainer = Trainer(\n",
   "    model=mtest,\n",
   "    args=baseline_training_args,\n",
   "    train_dataset=dataset['train'],\n",
   "    eval_dataset=dataset['validation'],\n",
   "    data_collator=col_fn,\n",
   "    compute_metrics=compute_metrics\n",
   ")\n",
   "baseline_trainer.train()"
  ] },
 { "cell_type": "markdown", "id": "ebb6c1f3-104d-4185-a6db-1aade4a4c9c9", "metadata": {}, "source": [ "# Trash" ] },
 { "cell_type": "code", "execution_count": null, "id": 
"f135d55d-b0cf-4b2b-bc4a-209ee70ca88b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def map_inputs(str_list):\n",
    "    \"\"\"Tokenize raw strings with the manager's tokenizer map and pad them into one batch on DEVICE.\"\"\"\n",
    "    tokens = manager.generate_tokenizer_map()(str_list)\n",
    "    # Local name kept distinct from the notebook-level `col_fn` used by the Trainer.\n",
    "    collator = DataCollatorWithPadding(manager.tokenizer)\n",
    "    return collator(tokens).to(DEVICE)\n",
    "\n",
    "\n",
    "inputs = map_inputs([\"Hello, my dog is cute\", \"bye\", \"why are\"])\n",
    "labels = torch.tensor([0, 1, 1], device=DEVICE)\n",
    "# Pass `labels`: that is the key DataCollatorWithPadding produces for the batches the model is\n",
    "# trained and debugged with (the debug cell calls the model on `col_fn` output and reads `outputs.loss`).\n",
    "outputs = manager.current_mix_model(labels=labels, **inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a6d396c-59b4-43f8-9961-620ae96df172",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "token_ids = manager.tokenizer(INIT_TEXT, return_tensors='pt')['input_ids'].to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abb508ae-8fa0-47bf-a988-41147b739685",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "token_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e29dee57-e30e-4248-b579-e0a77391339c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "manager.model.wte(token_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ef27833-e9a5-44eb-a235-7a63c1d273d9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "outputs.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a6f1ffd-ca5c-4381-94f5-5d09285a4c93",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "manager.model.h[9].original_module.attn.c_attn.weight.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09e588e0-f7bd-4b55-9201-207e6065da06",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "(torch.tensor([0, 1, 0]) == 0).any()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb556578-7ec2-4d08-a215-029c677e4878",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "manager.model.h[9].soft_prompt.sadcl_learned_embedding.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c30d534-2edb-4ab9-bde9-44e4b83de259",
   "metadata": {
    "tags": []
   },
"outputs": [],
   "source": [
    "outputs.last_hidden_state.sum().backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c13de4da-54f1-4611-93cd-c2d821112d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer([\"Hello, my dog is cute\", \"bye\"])\n",
    "outputs = model(**inputs)\n",
    "\n",
    "last_hidden_states = outputs.last_hidden_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "057a79bc-efe1-4b73-b990-d4d6d445ff3e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10235ef7-ed2c-47dd-ab5e-096072bc6cd0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "inputs = tokenize_dataset([\"Hello, my dog is cute\", \"bye\"])\n",
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8ca1b81-f0a6-40f5-853d-a154585b61b3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d88682f-9e3c-4e80-84d0-448f7ed93bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = nn.Parameter(torch.arange(27).reshape(3, 3, 3).float())\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3115c7f4-94a3-4526-8373-ec524fe46c1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "b = nn.Parameter(torch.tensor([7, 7, 7]).float())\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ec47b9-1636-491b-85c6-24415933fb3d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `x` is a leaf Parameter that requires grad; autograd rejects in-place writes to it,\n",
    "# so perform the slice assignment with gradient tracking disabled.\n",
    "with torch.no_grad():\n",
    "    x[0, 0, :] = torch.tensor([7, 7, 7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a87c7c-9740-412e-9ad3-a7f3f06e90bc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deep]",
   "language": "python",
   "name": "conda-env-deep-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 }