|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "id": "19c25879-e13a-4f5e-8b5a-67d6bb77c3f6",
- "metadata": {
- "tags": []
- },
- "source": [
- "# Intro"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "ca485005-54c1-4126-8c1e-53ca633b7f26",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Python version is: 3.10.11\n",
- "Torch version is: 1.13.1+cu117\n",
- "Nvidia device is: NVIDIA GeForce RTX 4090\n",
- "Transformers version is: 4.32.1\n",
- "Adapterhub not found!!!\n"
- ]
- }
- ],
- "source": [
- "from transformers import GPT2TokenizerFast, GPT2Model, DataCollatorWithPadding\n",
- "from transformers.modeling_outputs import SequenceClassifierOutputWithPast\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from utils import print_system_info\n",
- "from typing import Literal, Optional, List, Dict, Callable\n",
- "from types import SimpleNamespace\n",
- "from dataclasses import dataclass\n",
- "\n",
- "print_system_info()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "931ebd25-5e5a-4fdf-b2db-92d4ccf7f88e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
- "MODEL_NAME = 'gpt2'\n",
- "NAMESPACE = 'sadcl'\n",
- "\n",
- "INIT_TEXT = \"sentiment or value or relation of the previous text is\"\n",
- "N_LAST_LAYERS = 10\n",
- "\n",
- "N_TOKENS = 5"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e879d47e-0b67-452a-91c2-f36383efbed8",
- "metadata": {},
- "source": [
- "# Class"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "7c61fcde-f9e7-4d30-b989-511010e6298b",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def initialize_embedding(\n",
- " emb_dim: int,\n",
- " n_tokens: int, \n",
- " random_range: float,\n",
- " initialize_from: Optional[torch.Tensor]\n",
- "):\n",
- " if initialize_from is None:\n",
- " return torch.FloatTensor(n_tokens, emb_dim).uniform_(-random_range, random_range)\n",
- "\n",
- " assert initialize_from.shape[0] >= n_tokens\n",
- " assert initialize_from.shape[1] == emb_dim\n",
- " return initialize_from[:n_tokens, :].detach().clone()\n",
- "\n",
- "class SoftEmbedding(nn.Module):\n",
- " def __init__(\n",
- " self,\n",
- " emb_dim: int,\n",
- " n_tokens: int,\n",
- " first_layer_flag: bool = False,\n",
- " random_range: float = 0.1,\n",
- " initialize_from: Optional[torch.Tensor] = None\n",
- " ):\n",
- " super().__init__()\n",
- " \n",
- " self.emb_dim = emb_dim\n",
- " self.n_tokens = n_tokens\n",
- " self.first_layer_flag = first_layer_flag\n",
- " \n",
- " self.sadcl_learned_embedding = nn.parameter.Parameter(\n",
- " initialize_embedding(\n",
- " emb_dim,\n",
- " n_tokens,\n",
- " random_range,\n",
- " initialize_from\n",
- " )\n",
- " )\n",
- " # self.sadcl_mlp = nn.Sequential(\n",
- " # nn.Linear(emb_dim, 24, bias=False),\n",
- " # nn.ReLU(),\n",
- " # nn.Linear(24, 768, bias=False)\n",
- " # )\n",
- "\n",
- " assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)\n",
- " \n",
- " def forward(self, input_embedding, attention_mask, sequnce_lengths):\n",
- " # input_embedding.shape = (batch_size, num_of_input_tokens+n_tokens, emb_dim)\n",
- " # output_embedding = []\n",
- " \n",
- " learned_embedding = self.sadcl_learned_embedding# + self.sadcl_mlp(self.sadcl_learned_embedding)\n",
- " \n",
- " batch_size = input_embedding.size(0)\n",
- " learned_embedding = learned_embedding.repeat(batch_size, 1, 1) # (batch_size, n_tokens, emb_dim)\n",
- " \n",
- " attention_mask_shift = torch.zeros((batch_size, 1, 1, self.n_tokens), device=attention_mask.device)\n",
- " attention_mask = torch.cat([attention_mask_shift, attention_mask[:, :, :, :-self.n_tokens]], dim=-1)\n",
- " if self.first_layer_flag:\n",
- " output_embedding = torch.cat([learned_embedding, input_embedding[:, :-self.n_tokens]], dim=1)\n",
- " else:\n",
- " output_embedding = torch.cat([learned_embedding, input_embedding[:, self.n_tokens:]], dim=1)\n",
- " # print(attention_mask == 0)\n",
- " return output_embedding, attention_mask\n",
- " \n",
- " def get_weights(self):\n",
- " return self.sadcl_learned_embedding.detach().clone()\n",
- "\n",
- "\n",
- "class GPT2ModuleWrapper(nn.Module):\n",
- " def __init__(\n",
- " self,\n",
- " module,\n",
- " emb_dim:int,\n",
- " n_tokens:int,\n",
- " get_sequnce_lengths:Callable[[], List[int]],\n",
- " first_layer_flag:bool,\n",
- " initialize_from:Optional[torch.Tensor] = None\n",
- " ):\n",
- " super().__init__()\n",
- " self.original_module = module\n",
- " self.soft_prompt = SoftEmbedding(\n",
- " emb_dim=emb_dim,\n",
- " n_tokens=n_tokens,\n",
- " first_layer_flag=first_layer_flag,\n",
- " initialize_from=initialize_from\n",
- " )\n",
- " self.get_sequnce_lengths = get_sequnce_lengths\n",
- " \n",
- " \n",
- " def forward(self, hidden_states, *args, **kwargs):\n",
- " output_embedding, attention_mask = self.soft_prompt(\n",
- " hidden_states,\n",
- " kwargs['attention_mask'],\n",
- " self.get_sequnce_lengths()\n",
- " )\n",
- " kwargs['attention_mask'] = attention_mask\n",
- " return self.original_module(output_embedding, *args, **kwargs)\n",
- "\n",
- "class GPT2Injector:\n",
- " def __init__(self):\n",
- " self.sequnce_lengths = None\n",
- " \n",
- " def get_sequnce_lengths(self):\n",
- " return self.sequnce_lengths\n",
- " \n",
- " def _mutate_model_forward(self, model):\n",
- " old_forward = model.forward\n",
- " pad_token_id = model.config.pad_token_id\n",
- " def new_forward(*args, **kwargs):\n",
- " input_ids = kwargs['input_ids']\n",
- " self.sequnce_lengths = (\n",
- " torch.eq(input_ids, pad_token_id).long().argmax(-1) - 1\n",
- " ).detach().cpu().tolist()\n",
- " return old_forward(*args, **kwargs)\n",
- " model.forward = new_forward\n",
- " \n",
- " def _reverse_mutate_model_forward(self, model):\n",
- " orig_class = type(model)\n",
- " model.forward = orig_class.forward.__get__(model, orig_class)\n",
- " \n",
- " def mutate(self, model, n_layers, n_tokens, init_prompts):\n",
- " self._mutate_model_forward(model)\n",
- " module_list = model.h\n",
- " start = len(module_list) - n_layers\n",
- " for idx in range(start, len(module_list)):\n",
- " module_list[idx] = GPT2ModuleWrapper(\n",
- " module=module_list[idx],\n",
- " emb_dim=model.embed_dim,\n",
- " n_tokens=n_tokens,\n",
- " get_sequnce_lengths=self.get_sequnce_lengths,\n",
- " first_layer_flag=(idx == start),\n",
- " initialize_from=init_prompts[idx][0]\n",
- " )\n",
- " return module_list[start:]\n",
- " \n",
- " def reverse_mutate(self, model):\n",
- " self._reverse_mutate_model_forward(model)\n",
- " module_list = model.h\n",
- " for idx in range(len(module_list)):\n",
- " if type(module_list[idx]) is GPT2ModuleWrapper:\n",
- " module_list[idx] = module_list[idx].original_module\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "f215af71-8f06-4466-a1cc-bf27b1193627",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "class MixHeadModel(nn.Module):\n",
- " def __init__(self, model, head):\n",
- " super().__init__()\n",
- " self.model = model\n",
- " self.sadcl_head = head\n",
- " \n",
- " def forward(self, *args, **kwargs):\n",
- " labels = kwargs.pop('labels', None)\n",
- " transformer_outputs = self.model(*args, **kwargs)\n",
- " out = self.sadcl_head(\n",
- " transformer_outputs=transformer_outputs,\n",
- " labels=labels\n",
- " )\n",
- " return out"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "cea800ea-d538-4aab-8aca-41feaba49b7d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "class GPT2ClassificationHead(nn.Module):\n",
- " def __init__(\n",
- " self,\n",
- " emb_dim: int,\n",
- " n_labels: int,\n",
- " get_sequnce_lengths: Callable[[], List[int]],\n",
- " n_tokens: int,\n",
- " init_range: float,\n",
- " bias=True\n",
- " ):\n",
- " super().__init__()\n",
- " \n",
- " self.get_sequnce_lengths = get_sequnce_lengths\n",
- " self.n_labels = n_labels\n",
- " self.n_tokens = n_tokens\n",
- " self.loss_func = nn.CrossEntropyLoss()\n",
- " \n",
- " self.score = nn.Linear(emb_dim, n_labels, bias) # Bias is false in huggingface implementation\n",
- " \n",
- " self._init_weights(init_range)\n",
- " \n",
- " def _init_weights(self, init_range):\n",
- " self.score.weight.data.normal_(mean=0.0, std=init_range)\n",
- " if self.score.bias is not None:\n",
- " self.score.bias.data.zero_()\n",
- " \n",
- " def forward(self, transformer_outputs, labels=None):\n",
- " last_text_token_per_batch = self.get_sequnce_lengths()\n",
- " last_prompt_token_per_batch = [\n",
- " seqlen + self.n_tokens for seqlen in last_text_token_per_batch\n",
- " ]\n",
- " last_hidden_state = transformer_outputs.last_hidden_state\n",
- " batch_size = last_hidden_state.size(0)\n",
- " \n",
- " # last_text_token = last_hidden_state[range(batch_size), last_text_token_per_batch]\n",
- " last_prompt_token = last_hidden_state[range(batch_size), last_prompt_token_per_batch]\n",
- " logits = self.score(last_prompt_token)\n",
- " \n",
- " loss = None\n",
- " if labels is not None:\n",
- " loss = self.loss_func(logits.view(-1, self.n_labels), labels.view(-1))\n",
- " \n",
- " return SequenceClassifierOutputWithPast(\n",
- " loss=loss,\n",
- " logits=logits,\n",
- " past_key_values=transformer_outputs.past_key_values,\n",
- " hidden_states=transformer_outputs.hidden_states,\n",
- " attentions=transformer_outputs.attentions,\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "577784eb-ab61-424d-a633-7b030d6d06d3",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "@dataclass\n",
- "class PEFTConfig:\n",
- " name: str\n",
- " kind: Literal['regression', 'classification', 'generation']\n",
- " n_labels: Optional[int] # only for classification\n",
- " @classmethod\n",
- " def classification(cls, name: str, n_labels: int):\n",
- " return cls(name=name, n_labels=n_labels, kind='classification')\n",
- "\n",
- "class GPT2LLL:\n",
- " def __init__(\n",
- " self,\n",
- " n_tokens=N_TOKENS,\n",
- " n_last_layers=N_LAST_LAYERS,\n",
- " model_name=MODEL_NAME,\n",
- " device=DEVICE,\n",
- " init_text=INIT_TEXT\n",
- " ):\n",
- " self.n_tokens = n_tokens\n",
- " self.n_last_layers = n_last_layers\n",
- " self.model_name = model_name\n",
- " self.device = device\n",
- " \n",
- " self.pefts = {}\n",
- " \n",
- " self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name, add_prefix_space=True)\n",
- " self.tokenizer.pad_token = self.tokenizer.eos_token\n",
- " \n",
- " self.model = GPT2Model.from_pretrained(model_name, pad_token_id=self.tokenizer.pad_token_id)\n",
- " self.model.to(device)\n",
- " \n",
- " init_tokens = self.tokenizer(init_text, return_tensors='pt').to(device)\n",
- " with torch.no_grad():\n",
- " self.init_prompts = self.model(**init_tokens, output_hidden_states=True).hidden_states\n",
- " \n",
- " self.current_peft_name = None\n",
- " self.current_mix_model = None\n",
- " \n",
- " @property\n",
- " def current_peft(self):\n",
- " if self.current_peft_name is None:\n",
- " return None\n",
- " return self.pefts[self.current_peft_name]\n",
- " \n",
- " def generate_tokenizer_map(self):\n",
- " n_tokens = self.n_tokens\n",
- " tokenizer = self.tokenizer\n",
- " def return_function(rows):\n",
- " outputs_dict = tokenizer(rows)\n",
- " for row in outputs_dict['input_ids']:\n",
- " row.extend([tokenizer.pad_token_id] * n_tokens)\n",
- " for row in outputs_dict['attention_mask']:\n",
- " row.extend([0] * n_tokens)\n",
- " return outputs_dict\n",
- " return return_function\n",
- " \n",
- " def activate_peft(self, name):\n",
- " self.current_peft_name = name\n",
- " \n",
- " self.current_peft.injector.mutate(\n",
- " model=self.model,\n",
- " n_layers=self.n_last_layers,\n",
- " n_tokens=self.n_tokens,\n",
- " init_prompts=self.init_prompts\n",
- " )\n",
- " self.current_mix_model = MixHeadModel(\n",
- " head=self.current_peft.head,\n",
- " model=self.model\n",
- " )\n",
- " \n",
- " def auto_freeze(self):\n",
- " print(\"Unfreezed params are:\")\n",
- " for param_name, weights in self.current_mix_model.named_parameters():\n",
- " if NAMESPACE in param_name:\n",
- " weights.requires_grad = True\n",
- " print(\"- \" + param_name)\n",
- " else:\n",
- " weights.requires_grad = False\n",
- " \n",
- " def add_peft(self, config: PEFTConfig):\n",
- " assert config.name not in self.pefts\n",
- " injector = GPT2Injector()\n",
- " head = GPT2ClassificationHead(\n",
- " emb_dim=self.model.embed_dim,\n",
- " n_labels=config.n_labels,\n",
- " get_sequnce_lengths=injector.get_sequnce_lengths,\n",
- " n_tokens=self.n_tokens,\n",
- " init_range=self.model.config.initializer_range,\n",
- " bias=False\n",
- " )\n",
- " head.to(self.device)\n",
- " self.pefts[config.name] = SimpleNamespace(\n",
- " head=head,\n",
- " injector=injector\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8fcfcb04-6513-4321-917f-d13c2dba886e",
- "metadata": {
- "tags": []
- },
- "source": [
- "# Train"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "003fb992-fb75-4655-b60a-284ef0dcf4eb",
- "metadata": {
- "tags": []
- },
- "source": [
- "## Prepare"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "cfe42619-bb12-430e-9359-5ee2d2e40bdc",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Unfreezed params are:\n",
- "- model.h.2.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.3.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.4.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.5.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.6.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.7.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.8.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.9.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.10.soft_prompt.sadcl_learned_embedding\n",
- "- model.h.11.soft_prompt.sadcl_learned_embedding\n",
- "- sadcl_head.score.weight\n"
- ]
- }
- ],
- "source": [
- "peft_name = 'peft1'\n",
- "\n",
- "manager = GPT2LLL()\n",
- "manager.add_peft(PEFTConfig.classification(name=peft_name, n_labels=2))\n",
- "manager.activate_peft(peft_name)\n",
- "manager.auto_freeze()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "072bf63b-de2f-4c05-a6a2-6fde3bb5aa6d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from config import load_config\n",
- "config = load_config('config.yaml')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "3b3827aa-a61c-4e34-9e67-86768bd8b446",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "ac4726d36f6241c59be6dbeee759fce2",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/3 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f7a02c6d65621ecd.arrow\n",
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c36341ab82d2d37d.arrow\n",
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9f7663dac81ea13b.arrow\n"
- ]
- }
- ],
- "source": [
- "from datasets import load_dataset\n",
- "dataset = load_dataset('glue', 'cola')\n",
- "tokenizer_map = manager.generate_tokenizer_map()\n",
- "dataset = dataset.map(lambda x: tokenizer_map(x['sentence']), batched=True)\n",
- "dataset.set_format(type='torch', columns=[\n",
- " 'input_ids', 'attention_mask', 'label' # 'token_type_ids',\n",
- "])"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5331fedd-e6ec-4387-a1ec-55488d144f45",
- "metadata": {},
- "source": [
- "## Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "e13c9012-089f-45c1-baea-4f0850ccfbaa",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " <div>\n",
- " \n",
- " <progress value='14609' max='42880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- " [14609/42880 04:58 < 09:37, 48.96 it/s, Epoch 54.51/160]\n",
- " </div>\n",
- " <table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- " <th>Epoch</th>\n",
- " <th>Training Loss</th>\n",
- " <th>Validation Loss</th>\n",
- " <th>Accuracy</th>\n",
- " <th>F1-score-1</th>\n",
- " <th>F1-score-ma</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>No log</td>\n",
- " <td>0.617917</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>0.618200</td>\n",
- " <td>0.620259</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>0.618200</td>\n",
- " <td>0.612236</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>4</td>\n",
- " <td>0.616500</td>\n",
- " <td>0.613789</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>5</td>\n",
- " <td>0.616500</td>\n",
- " <td>0.615989</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>6</td>\n",
- " <td>0.612800</td>\n",
- " <td>0.614961</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>7</td>\n",
- " <td>0.612800</td>\n",
- " <td>0.612622</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>8</td>\n",
- " <td>0.611300</td>\n",
- " <td>0.613691</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>9</td>\n",
- " <td>0.611300</td>\n",
- " <td>0.613889</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>10</td>\n",
- " <td>0.609400</td>\n",
- " <td>0.616157</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>11</td>\n",
- " <td>0.609400</td>\n",
- " <td>0.614404</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>12</td>\n",
- " <td>0.609700</td>\n",
- " <td>0.614005</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>13</td>\n",
- " <td>0.609700</td>\n",
- " <td>0.611722</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>14</td>\n",
- " <td>0.607100</td>\n",
- " <td>0.609891</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>15</td>\n",
- " <td>0.606600</td>\n",
- " <td>0.612338</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>16</td>\n",
- " <td>0.606600</td>\n",
- " <td>0.614802</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>17</td>\n",
- " <td>0.604600</td>\n",
- " <td>0.614289</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>18</td>\n",
- " <td>0.604600</td>\n",
- " <td>0.610662</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>19</td>\n",
- " <td>0.603600</td>\n",
- " <td>0.610867</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>20</td>\n",
- " <td>0.603600</td>\n",
- " <td>0.615460</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>21</td>\n",
- " <td>0.602600</td>\n",
- " <td>0.612030</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>22</td>\n",
- " <td>0.602600</td>\n",
- " <td>0.611254</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>23</td>\n",
- " <td>0.601900</td>\n",
- " <td>0.612736</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>24</td>\n",
- " <td>0.601900</td>\n",
- " <td>0.613839</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>25</td>\n",
- " <td>0.604800</td>\n",
- " <td>0.612303</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>26</td>\n",
- " <td>0.604800</td>\n",
- " <td>0.612139</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>27</td>\n",
- " <td>0.603400</td>\n",
- " <td>0.612106</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>28</td>\n",
- " <td>0.602300</td>\n",
- " <td>0.614560</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>29</td>\n",
- " <td>0.602300</td>\n",
- " <td>0.613581</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>30</td>\n",
- " <td>0.602800</td>\n",
- " <td>0.615965</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>31</td>\n",
- " <td>0.602800</td>\n",
- " <td>0.613715</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>32</td>\n",
- " <td>0.601400</td>\n",
- " <td>0.613545</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>33</td>\n",
- " <td>0.601400</td>\n",
- " <td>0.612631</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>34</td>\n",
- " <td>0.601400</td>\n",
- " <td>0.611881</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>35</td>\n",
- " <td>0.601400</td>\n",
- " <td>0.614503</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>36</td>\n",
- " <td>0.600700</td>\n",
- " <td>0.610912</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>37</td>\n",
- " <td>0.600700</td>\n",
- " <td>0.611916</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>38</td>\n",
- " <td>0.600800</td>\n",
- " <td>0.611409</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>39</td>\n",
- " <td>0.600800</td>\n",
- " <td>0.613652</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>40</td>\n",
- " <td>0.600600</td>\n",
- " <td>0.612413</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>41</td>\n",
- " <td>0.600600</td>\n",
- " <td>0.613673</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>42</td>\n",
- " <td>0.600400</td>\n",
- " <td>0.611154</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>43</td>\n",
- " <td>0.600000</td>\n",
- " <td>0.611216</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>44</td>\n",
- " <td>0.600000</td>\n",
- " <td>0.610118</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>45</td>\n",
- " <td>0.601900</td>\n",
- " <td>0.611573</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817717</td>\n",
- " <td>0.415012</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>46</td>\n",
- " <td>0.601900</td>\n",
- " <td>0.613571</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>47</td>\n",
- " <td>0.598700</td>\n",
- " <td>0.611853</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>48</td>\n",
- " <td>0.598700</td>\n",
- " <td>0.611213</td>\n",
- " <td>0.691275</td>\n",
- " <td>0.817253</td>\n",
- " <td>0.411713</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>49</td>\n",
- " <td>0.597600</td>\n",
- " <td>0.611855</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>50</td>\n",
- " <td>0.597600</td>\n",
- " <td>0.611871</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>51</td>\n",
- " <td>0.600100</td>\n",
- " <td>0.612086</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>52</td>\n",
- " <td>0.600100</td>\n",
- " <td>0.610666</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>53</td>\n",
- " <td>0.599600</td>\n",
- " <td>0.613406</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>54</td>\n",
- " <td>0.599600</td>\n",
- " <td>0.617041</td>\n",
- " <td>0.692234</td>\n",
- " <td>0.817924</td>\n",
- " <td>0.412058</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table><p>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[15], line 43\u001b[0m\n\u001b[1;32m 34\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m=\u001b[39mmanager\u001b[38;5;241m.\u001b[39mcurrent_mix_model, \u001b[38;5;66;03m# manager.current_mix_model\u001b[39;00m\n\u001b[1;32m 36\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 40\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics\n\u001b[1;32m 41\u001b[0m )\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# trainer.label_names = ['labels']\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpast_key_values\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1555\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1553\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1555\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1556\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1557\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1558\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1559\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1560\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1837\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1834\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1836\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1837\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1839\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1840\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1841\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1842\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1843\u001b[0m ):\n\u001b[1;32m 1844\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1845\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:2693\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2691\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 2692\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2693\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2695\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/accelerate/accelerator.py:1923\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 1921\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1922\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1923\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
- "from sklearn.metrics import classification_report\n",
- "\n",
- "\n",
- "def compute_metrics(pred):\n",
- " true_labels = pred.label_ids.ravel()\n",
- " pred_labels = pred.predictions.argmax(-1).ravel()\n",
- " report = classification_report(true_labels, pred_labels, output_dict=True)\n",
- " return {\n",
- " 'accuracy': report['accuracy'],\n",
- " 'f1-score-1': report['1']['f1-score'],\n",
- " 'f1-score-ma': report['macro avg']['f1-score']\n",
- " }\n",
- "\n",
- "col_fn = DataCollatorWithPadding(\n",
- " manager.tokenizer, return_tensors='pt', padding='longest'\n",
- ")\n",
- "\n",
- "training_args = TrainingArguments(\n",
- " evaluation_strategy=\"epoch\",\n",
- " save_strategy=\"epoch\",\n",
- " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
- " remove_unused_columns=False,\n",
- " label_names=['labels'],\n",
- " **{\n",
- " 'output_dir': '/disks/part4/trash',\n",
- " 'num_train_epochs': 160,\n",
- " 'learning_rate': 0.00001,\n",
- " 'per_device_train_batch_size': 32,\n",
- " 'per_device_eval_batch_size': 32\n",
- " }\n",
- ")\n",
- "\n",
- "trainer = Trainer(\n",
- " model=manager.current_mix_model, # manager.current_mix_model\n",
- " args=training_args,\n",
- " train_dataset=dataset['train'],\n",
- " eval_dataset=dataset['validation'],\n",
- " data_collator=col_fn,\n",
- " compute_metrics=compute_metrics\n",
- ")\n",
- "# trainer.label_names = ['labels']\n",
- "trainer.train(ignore_keys_for_eval=[\"past_key_values\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "60f9a209-1a05-4f6c-b450-773f788b93a0",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0.6912751677852349"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import numpy as np\n",
- "np.mean(dataset['validation']['label'].numpy())"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8806b4a8-2f34-4e42-b0cd-d6b53d2b465b",
- "metadata": {},
- "source": [
- "# debug"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "559e0faa-f179-4b0a-b4b3-246260fc9056",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "inputs = col_fn(dataset['validation'][0:50]).to(DEVICE)\n",
- "outputs = manager.current_mix_model(**inputs)\n",
- "outputs.loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "43cf9e69-b9a8-4059-a8e5-cab21635388c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "for i in range(6, 12):\n",
- " o = manager.current_mix_model.model.h[i].soft_prompt.sadcl_learned_embedding.grad.abs().sum().item()\n",
- " print(i, o)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff6980bd-52c0-497b-b272-0be58044ee2f",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "manager.current_mix_model.sadcl_head.score.weight.grad"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dd30fcbb-8385-4cf0-b2ce-26fa961e26c9",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "raise Exception()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0a788c20-5f98-447e-970a-4e96a4694976",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from transformers import GPT2ForSequenceClassification\n",
- "\n",
- "mtest = GPT2ForSequenceClassification.from_pretrained('gpt2', pad_token_id=manager.tokenizer.pad_token_id)\n",
- "mtest.to(DEVICE)\n",
- "\n",
- "training_args = TrainingArguments(\n",
- " evaluation_strategy=\"epoch\",\n",
- " save_strategy=\"epoch\",\n",
- " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
- " remove_unused_columns=False,\n",
- " label_names=['labels'],\n",
- " **\n",
- " {\n",
- " 'output_dir': '/home/mohalisad/Developer/Thesis/cp3',\n",
- " 'num_train_epochs': 80,\n",
- " 'learning_rate': 0.00001,\n",
- " 'per_device_train_batch_size': 32,\n",
- " 'per_device_eval_batch_size': 32\n",
- " }\n",
- ")\n",
- "\n",
- "trainer = Trainer(\n",
- " model=mtest, # manager.current_mix_model\n",
- " args=training_args,\n",
- " train_dataset=dataset['train'],\n",
- " eval_dataset=dataset['validation'],\n",
- " data_collator=col_fn,\n",
- " compute_metrics=compute_metrics\n",
- ")\n",
- "# trainer.label_names = ['labels']\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ebb6c1f3-104d-4185-a6db-1aade4a4c9c9",
- "metadata": {},
- "source": [
- "# Trash"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f135d55d-b0cf-4b2b-bc4a-209ee70ca88b",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def map_inputs(str_list):\n",
- " tokens = manager.generate_tokenizer_map()(str_list)\n",
- " col_fn = DataCollatorWithPadding(manager.tokenizer)\n",
- " return col_fn(tokens).to(DEVICE)\n",
- " \n",
- "inputs = map_inputs([\"Hello, my dog is cute\", \"bye\", \"why are\"])\n",
- "label = torch.tensor([0, 1, 1], device=DEVICE)\n",
- "outputs = manager.current_mix_model(label=label, **inputs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1a6d396c-59b4-43f8-9961-620ae96df172",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "token_ids = manager.tokenizer(INIT_TEXT, return_tensors='pt')['input_ids'].to(DEVICE)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "abb508ae-8fa0-47bf-a988-41147b739685",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "token_ids"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e29dee57-e30e-4248-b579-e0a77391339c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "manager.model.wte(token_ids).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5ef27833-e9a5-44eb-a235-7a63c1d273d9",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "outputs.loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6a6f1ffd-ca5c-4381-94f5-5d09285a4c93",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "manager.model.h[9].original_module.attn.c_attn.weight.grad"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "09e588e0-f7bd-4b55-9201-207e6065da06",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "(torch.tensor([0, 1, 0]) == 0).any()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "eb556578-7ec2-4d08-a215-029c677e4878",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "manager.model.h[9].soft_prompt.sadcl_learned_embedding.grad"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8c30d534-2edb-4ab9-bde9-44e4b83de259",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "outputs.last_hidden_state.sum().backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c13de4da-54f1-4611-93cd-c2d821112d0c",
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "\n",
- "last_hidden_states = outputs.last_hidden_state\n",
- "inputs = tokenizer([\"Hello, my dog is cute\", \"bye\"])\n",
- "outputs = model(**inputs)\n",
- "\n",
- "last_hidden_states = outputs.last_hidden_state"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "057a79bc-efe1-4b73-b990-d4d6d445ff3e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "inputs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "10235ef7-ed2c-47dd-ab5e-096072bc6cd0",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "\n",
- "inputs = tokenize_dataset([\"Hello, my dog is cute\", \"bye\"])\n",
- "inputs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a8ca1b81-f0a6-40f5-853d-a154585b61b3",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "tokenizer.eos_token"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7d88682f-9e3c-4e80-84d0-448f7ed93bc4",
- "metadata": {},
- "outputs": [],
- "source": [
- "x = nn.Parameter(torch.arange(27).reshape(3, 3, 3).float())\n",
- "x"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3115c7f4-94a3-4526-8373-ec524fe46c1c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "b = nn.Parameter(torch.tensor([7, 7, 7]).float())\n",
- "b"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "30ec47b9-1636-491b-85c6-24415933fb3d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "x[0, 0, :] = torch.tensor([7, 7, 7])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b0a87c7c-9740-412e-9ad3-a7f3f06e90bc",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:deep]",
- "language": "python",
- "name": "conda-env-deep-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
|