{ "cells": [ { "cell_type": "code", "execution_count": 13, "id": "e86f79a2-61c6-4f4f-99c6-28fdf93453c6", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python version is: 3.10.11\n", "Torch version is: 1.13.1+cu117\n", "Nvidia device is: NVIDIA GeForce RTX 4090\n", "Transformers version is: 4.32.1\n", "Adapterhub not found!!!\n" ] } ], "source": [ "import torch\n", "from transformers import GPT2ForSequenceClassification, GPT2TokenizerFast, GPT2Model\n", "from utils import print_system_info\n", "\n", "print_system_info()" ] }, { "cell_type": "code", "execution_count": 14, "id": "fbdbd4cc-c433-498f-af81-2d607ee4a8c5", "metadata": { "tags": [] }, "outputs": [], "source": [ "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "MODEL_NAME = 'gpt2'" ] }, { "cell_type": "code", "execution_count": 15, "id": "f3e17bdf-4fea-48cb-8e1d-0c87c26421b0", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME, add_prefix_space=True, padding_side='left')\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "model = GPT2ForSequenceClassification.from_pretrained(MODEL_NAME, pad_token_id=tokenizer.pad_token_id)" ] }, { "cell_type": "code", "execution_count": 16, "id": "76375dd7-0d4f-41e5-9789-7f8c410c44a0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 10,752 || all params: 124,450,560 || trainable%: 0.008639575426579036\n" ] } ], "source": [ "from peft import get_peft_model, PromptTuningConfig, PromptTuningInit, TaskType\n", "\n", "peft_config = 
PromptTuningConfig(\n", "    task_type=TaskType.SEQ_CLS,\n", "    prompt_tuning_init=PromptTuningInit.TEXT,\n", "    # initialize the 10 virtual tokens from the embeddings of the init text below\n", "    num_virtual_tokens=10,\n", "    prompt_tuning_init_text=\"sentiment or value or relation of the previous text is\",\n", "    tokenizer_name_or_path=MODEL_NAME\n", ")\n", "\n", "# Wrap the frozen base model with the prompt-tuning adapter.\n", "peft_model = get_peft_model(model, peft_config)\n", "peft_model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": 17, "id": "92dff26c-e778-4f7a-8ee3-ef9ed017b50f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43b9b0397309423889316a45de0d6eba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "  0%|          | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/8551 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/1043 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/1063 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "dataset = load_dataset('glue', 'cola')\n", "dataset = dataset.map(lambda x: tokenizer(x['sentence']), batched=True)\n", "dataset.set_format(type='torch', columns=[\n", " 'input_ids', 'attention_mask', 'label' # 'token_type_ids',\n", "])" ] }, { "cell_type": "code", "execution_count": null, "id": "78e5481c-d16b-417a-9c52-f90a80b77b1f", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", "GPT2ForSequenceClassification will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`\n" ] }, { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Accuracy | \n", "F1-score-1 | \n", "F1-score-ma | \n", "
---|---|---|---|---|---|
1 | \n", "No log | \n", "0.659724 | \n", "0.692234 | \n", "0.817717 | \n", "0.415012 | \n", "
2 | \n", "0.825400 | \n", "0.643644 | \n", "0.692234 | \n", "0.817717 | \n", "0.415012 | \n", "
3 | \n", "0.825400 | \n", "0.634887 | \n", "0.689358 | \n", "0.815490 | \n", "0.416836 | \n", "
4 | \n", "0.673500 | \n", "0.632855 | \n", "0.691275 | \n", "0.817253 | \n", "0.411713 | \n", "
5 | \n", "0.673500 | \n", "0.632050 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
6 | \n", "0.663300 | \n", "0.630056 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
7 | \n", "0.663300 | \n", "0.627022 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
8 | \n", "0.642900 | \n", "0.625471 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
9 | \n", "0.642900 | \n", "0.624368 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
10 | \n", "0.629500 | \n", "0.625301 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
11 | \n", "0.629500 | \n", "0.623609 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
12 | \n", "0.624800 | \n", "0.622582 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
13 | \n", "0.624800 | \n", "0.620383 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
14 | \n", "0.617300 | \n", "0.619082 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
15 | \n", "0.614200 | \n", "0.620724 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
16 | \n", "0.614200 | \n", "0.621190 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
17 | \n", "0.612000 | \n", "0.620213 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
18 | \n", "0.612000 | \n", "0.617329 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
19 | \n", "0.606200 | \n", "0.617191 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
20 | \n", "0.606200 | \n", "0.620259 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
21 | \n", "0.605500 | \n", "0.618004 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
22 | \n", "0.605500 | \n", "0.617333 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
23 | \n", "0.605100 | \n", "0.619251 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
24 | \n", "0.605100 | \n", "0.620442 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
25 | \n", "0.604600 | \n", "0.617622 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
26 | \n", "0.604600 | \n", "0.617024 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
27 | \n", "0.604100 | \n", "0.616203 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
28 | \n", "0.601800 | \n", "0.621876 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
29 | \n", "0.601800 | \n", "0.619959 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
30 | \n", "0.600700 | \n", "0.621087 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
31 | \n", "0.600700 | \n", "0.619117 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
32 | \n", "0.601400 | \n", "0.619615 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
33 | \n", "0.601400 | \n", "0.618296 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
34 | \n", "0.600400 | \n", "0.616326 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
35 | \n", "0.600400 | \n", "0.620853 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
36 | \n", "0.598400 | \n", "0.615045 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
37 | \n", "0.598400 | \n", "0.616010 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
38 | \n", "0.598600 | \n", "0.616971 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
39 | \n", "0.598600 | \n", "0.617972 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
40 | \n", "0.598900 | \n", "0.619221 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
41 | \n", "0.598900 | \n", "0.618343 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
42 | \n", "0.597800 | \n", "0.616757 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
43 | \n", "0.599200 | \n", "0.616346 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
44 | \n", "0.599200 | \n", "0.616001 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
45 | \n", "0.599300 | \n", "0.617817 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
46 | \n", "0.599300 | \n", "0.618765 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
47 | \n", "0.594900 | \n", "0.617520 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
48 | \n", "0.594900 | \n", "0.615305 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
49 | \n", "0.595300 | \n", "0.615376 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
50 | \n", "0.595300 | \n", "0.614802 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
51 | \n", "0.597500 | \n", "0.617168 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
52 | \n", "0.597500 | \n", "0.614939 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
53 | \n", "0.598300 | \n", "0.617452 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
54 | \n", "0.598300 | \n", "0.622289 | \n", "0.691275 | \n", "0.817460 | \n", "0.408730 | \n", "
55 | \n", "0.596200 | \n", "0.620482 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
56 | \n", "0.594900 | \n", "0.619634 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
57 | \n", "0.594900 | \n", "0.615826 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
58 | \n", "0.594400 | \n", "0.618395 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
59 | \n", "0.594400 | \n", "0.618274 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
60 | \n", "0.596600 | \n", "0.616338 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
61 | \n", "0.596600 | \n", "0.614465 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
62 | \n", "0.594100 | \n", "0.615096 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
63 | \n", "0.594100 | \n", "0.615145 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
64 | \n", "0.594600 | \n", "0.617211 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
65 | \n", "0.594600 | \n", "0.618484 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
66 | \n", "0.593700 | \n", "0.611892 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
67 | \n", "0.593700 | \n", "0.615543 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
68 | \n", "0.595800 | \n", "0.615577 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
69 | \n", "0.595800 | \n", "0.613829 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
70 | \n", "0.593700 | \n", "0.616256 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
71 | \n", "0.593600 | \n", "0.615421 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
72 | \n", "0.593600 | \n", "0.614953 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
73 | \n", "0.594600 | \n", "0.615790 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
74 | \n", "0.594600 | \n", "0.616779 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
75 | \n", "0.593200 | \n", "0.614842 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
76 | \n", "0.593200 | \n", "0.613461 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
77 | \n", "0.595000 | \n", "0.613352 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
78 | \n", "0.595000 | \n", "0.611748 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
79 | \n", "0.591900 | \n", "0.613381 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
80 | \n", "0.591900 | \n", "0.614556 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
81 | \n", "0.592300 | \n", "0.615140 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
82 | \n", "0.592300 | \n", "0.613348 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
83 | \n", "0.592400 | \n", "0.612780 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
84 | \n", "0.592900 | \n", "0.613358 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
85 | \n", "0.592900 | \n", "0.614089 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
86 | \n", "0.591400 | \n", "0.616400 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
87 | \n", "0.591400 | \n", "0.616220 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
88 | \n", "0.591500 | \n", "0.613058 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
89 | \n", "0.591500 | \n", "0.614206 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
90 | \n", "0.592600 | \n", "0.614179 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
91 | \n", "0.592600 | \n", "0.614048 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
92 | \n", "0.592700 | \n", "0.614766 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
93 | \n", "0.592700 | \n", "0.615107 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
94 | \n", "0.590500 | \n", "0.613957 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
95 | \n", "0.590500 | \n", "0.612804 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
96 | \n", "0.589600 | \n", "0.610184 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
97 | \n", "0.589600 | \n", "0.614211 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
98 | \n", "0.593300 | \n", "0.612123 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
99 | \n", "0.591200 | \n", "0.611470 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
100 | \n", "0.591200 | \n", "0.612707 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
101 | \n", "0.593300 | \n", "0.612491 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
102 | \n", "0.593300 | \n", "0.614426 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
103 | \n", "0.591700 | \n", "0.614888 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
104 | \n", "0.591700 | \n", "0.613895 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
105 | \n", "0.589600 | \n", "0.615002 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
106 | \n", "0.589600 | \n", "0.612817 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
107 | \n", "0.589800 | \n", "0.613109 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
108 | \n", "0.589800 | \n", "0.611671 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
109 | \n", "0.591900 | \n", "0.612564 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
110 | \n", "0.591900 | \n", "0.612609 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
111 | \n", "0.589900 | \n", "0.615618 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
112 | \n", "0.590200 | \n", "0.612373 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
113 | \n", "0.590200 | \n", "0.614096 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
114 | \n", "0.591000 | \n", "0.612907 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
115 | \n", "0.591000 | \n", "0.613617 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
116 | \n", "0.590200 | \n", "0.612939 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
117 | \n", "0.590200 | \n", "0.613484 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
118 | \n", "0.591000 | \n", "0.613280 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
119 | \n", "0.591000 | \n", "0.614021 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
120 | \n", "0.589800 | \n", "0.610681 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
121 | \n", "0.589800 | \n", "0.611854 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
122 | \n", "0.592400 | \n", "0.612633 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
123 | \n", "0.592400 | \n", "0.610687 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
124 | \n", "0.591700 | \n", "0.611704 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
125 | \n", "0.589300 | \n", "0.612493 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
126 | \n", "0.589300 | \n", "0.613478 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
127 | \n", "0.589400 | \n", "0.612765 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
128 | \n", "0.589400 | \n", "0.613636 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
129 | \n", "0.589000 | \n", "0.612400 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
130 | \n", "0.589000 | \n", "0.614144 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
131 | \n", "0.590200 | \n", "0.613008 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
132 | \n", "0.590200 | \n", "0.612329 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
133 | \n", "0.592500 | \n", "0.613047 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
134 | \n", "0.592500 | \n", "0.612284 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
135 | \n", "0.592000 | \n", "0.612038 | \n", "0.692234 | \n", "0.817924 | \n", "0.412058 | \n", "
"
],
"text/plain": [
"