12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "id": "896de91a-4ab9-40f5-a3c1-914535b6e0a7",
- "metadata": {},
- "source": [
 -     "# Introduction"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "5f17aae6-73f5-4793-95a3-09147ea89e04",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Python version is: 3.10.11\n",
- "Scikit-learn version is: 1.2.2\n",
- "Torch version is: 1.13.1+cu117\n",
- "Nvidia device is: NVIDIA GeForce RTX 4090\n",
- "Transformers version is: 4.32.1\n",
- "Adapterhub not found!!!\n"
- ]
- }
- ],
- "source": [
- "from typing import Optional\n",
- "\n",
- "import numpy as np\n",
- "from tqdm.notebook import tqdm\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from transformers import T5TokenizerFast, T5ForConditionalGeneration\n",
- "\n",
- "from _utils import print_system_info, generate_dataloader\n",
- "from _datasets import AutoLoad\n",
- "from _mydelta import T5Wrapper, auto_freeze\n",
- "from _trainer import train_loop, valid_loop\n",
- "\n",
- "print_system_info()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "fb5ef784-fef0-4b7b-98e7-ec5d3575a9a8",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from types import SimpleNamespace\n",
- "config = SimpleNamespace(\n",
- " model_name='google/t5-base-lm-adapt',\n",
- " n_tokens=30,\n",
- " n_layers=6,\n",
- " random_seed=42,\n",
- " task=['glue:cola'],\n",
- " hot_modules=['sadcl'],\n",
- " train_batch_size=32,\n",
- " valid_batch_size=32,\n",
- " balancify_sample=False,\n",
- " learning_rate=0.01,\n",
- " num_epochs=200\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "d3802d01-7c5a-4c11-beaf-f683a2fb9d80",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
- "\n",
- "np.random.seed(config.random_seed)\n",
- "slected_tokens = torch.from_numpy(np.random.randint(0, 32128, size=(config.n_tokens,)))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1e785d49-beca-4333-986e-b198bbaadf7d",
- "metadata": {},
- "source": [
 -     "# Load model and data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "afcc6244-978a-425a-9fa9-8b11dd0df8ba",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "model = T5ForConditionalGeneration.from_pretrained(config.model_name)\n",
- "tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "894a8474-e2e1-4f9d-b9ab-58d911808ec0",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "encoder.block.6.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.7.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.8.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.9.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.10.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.11.soft_prompt.sadcl_learned_embedding\n"
- ]
- }
- ],
- "source": [
- "delta_module = T5Wrapper.mutate(\n",
- " model=model,\n",
- " config=config,\n",
- " slected_tokens=slected_tokens\n",
- ")\n",
- "auto_freeze(model, config.hot_modules, verbose=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "9453d3cc-c04c-4a27-83aa-eaac3e49c14e",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "shared.weight\n",
- "encoder.block.0.layer.0.SelfAttention.q.weight\n",
- "encoder.block.0.layer.0.SelfAttention.k.weight\n",
- "encoder.block.0.layer.0.SelfAttention.v.weight\n",
- "encoder.block.0.layer.0.SelfAttention.o.weight\n",
- "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
- "encoder.block.0.layer.0.layer_norm.weight\n",
- "encoder.block.0.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.0.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.0.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.0.layer.1.layer_norm.weight\n",
- "encoder.block.1.layer.0.SelfAttention.q.weight\n",
- "encoder.block.1.layer.0.SelfAttention.k.weight\n",
- "encoder.block.1.layer.0.SelfAttention.v.weight\n",
- "encoder.block.1.layer.0.SelfAttention.o.weight\n",
- "encoder.block.1.layer.0.layer_norm.weight\n",
- "encoder.block.1.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.1.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.1.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.1.layer.1.layer_norm.weight\n",
- "encoder.block.2.layer.0.SelfAttention.q.weight\n",
- "encoder.block.2.layer.0.SelfAttention.k.weight\n",
- "encoder.block.2.layer.0.SelfAttention.v.weight\n",
- "encoder.block.2.layer.0.SelfAttention.o.weight\n",
- "encoder.block.2.layer.0.layer_norm.weight\n",
- "encoder.block.2.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.2.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.2.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.2.layer.1.layer_norm.weight\n",
- "encoder.block.3.layer.0.SelfAttention.q.weight\n",
- "encoder.block.3.layer.0.SelfAttention.k.weight\n",
- "encoder.block.3.layer.0.SelfAttention.v.weight\n",
- "encoder.block.3.layer.0.SelfAttention.o.weight\n",
- "encoder.block.3.layer.0.layer_norm.weight\n",
- "encoder.block.3.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.3.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.3.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.3.layer.1.layer_norm.weight\n",
- "encoder.block.4.layer.0.SelfAttention.q.weight\n",
- "encoder.block.4.layer.0.SelfAttention.k.weight\n",
- "encoder.block.4.layer.0.SelfAttention.v.weight\n",
- "encoder.block.4.layer.0.SelfAttention.o.weight\n",
- "encoder.block.4.layer.0.layer_norm.weight\n",
- "encoder.block.4.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.4.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.4.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.4.layer.1.layer_norm.weight\n",
- "encoder.block.5.layer.0.SelfAttention.q.weight\n",
- "encoder.block.5.layer.0.SelfAttention.k.weight\n",
- "encoder.block.5.layer.0.SelfAttention.v.weight\n",
- "encoder.block.5.layer.0.SelfAttention.o.weight\n",
- "encoder.block.5.layer.0.layer_norm.weight\n",
- "encoder.block.5.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.5.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.5.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.5.layer.1.layer_norm.weight\n",
- "encoder.block.6.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.6.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.6.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.6.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.6.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.6.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.6.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.6.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.6.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.6.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.7.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.7.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.7.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.7.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.7.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.7.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.7.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.7.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.7.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.7.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.8.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.8.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.8.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.8.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.8.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.8.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.8.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.8.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.8.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.8.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.9.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.9.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.9.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.9.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.9.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.9.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.9.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.9.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.9.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.9.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.10.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.10.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.10.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.10.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.10.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.10.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.10.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.10.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.10.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.10.soft_prompt.sadcl_learned_embedding\n",
- "encoder.block.11.original_module.layer.0.SelfAttention.q.weight\n",
- "encoder.block.11.original_module.layer.0.SelfAttention.k.weight\n",
- "encoder.block.11.original_module.layer.0.SelfAttention.v.weight\n",
- "encoder.block.11.original_module.layer.0.SelfAttention.o.weight\n",
- "encoder.block.11.original_module.layer.0.layer_norm.weight\n",
- "encoder.block.11.original_module.layer.1.DenseReluDense.wi_0.weight\n",
- "encoder.block.11.original_module.layer.1.DenseReluDense.wi_1.weight\n",
- "encoder.block.11.original_module.layer.1.DenseReluDense.wo.weight\n",
- "encoder.block.11.original_module.layer.1.layer_norm.weight\n",
- "encoder.block.11.soft_prompt.sadcl_learned_embedding\n",
- "encoder.final_layer_norm.weight\n",
- "decoder.block.0.layer.0.SelfAttention.q.weight\n",
- "decoder.block.0.layer.0.SelfAttention.k.weight\n",
- "decoder.block.0.layer.0.SelfAttention.v.weight\n",
- "decoder.block.0.layer.0.SelfAttention.o.weight\n",
- "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
- "decoder.block.0.layer.0.layer_norm.weight\n",
- "decoder.block.0.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.0.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.0.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.0.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.0.layer.1.layer_norm.weight\n",
- "decoder.block.0.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.0.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.0.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.0.layer.2.layer_norm.weight\n",
- "decoder.block.1.layer.0.SelfAttention.q.weight\n",
- "decoder.block.1.layer.0.SelfAttention.k.weight\n",
- "decoder.block.1.layer.0.SelfAttention.v.weight\n",
- "decoder.block.1.layer.0.SelfAttention.o.weight\n",
- "decoder.block.1.layer.0.layer_norm.weight\n",
- "decoder.block.1.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.1.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.1.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.1.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.1.layer.1.layer_norm.weight\n",
- "decoder.block.1.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.1.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.1.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.1.layer.2.layer_norm.weight\n",
- "decoder.block.2.layer.0.SelfAttention.q.weight\n",
- "decoder.block.2.layer.0.SelfAttention.k.weight\n",
- "decoder.block.2.layer.0.SelfAttention.v.weight\n",
- "decoder.block.2.layer.0.SelfAttention.o.weight\n",
- "decoder.block.2.layer.0.layer_norm.weight\n",
- "decoder.block.2.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.2.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.2.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.2.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.2.layer.1.layer_norm.weight\n",
- "decoder.block.2.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.2.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.2.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.2.layer.2.layer_norm.weight\n",
- "decoder.block.3.layer.0.SelfAttention.q.weight\n",
- "decoder.block.3.layer.0.SelfAttention.k.weight\n",
- "decoder.block.3.layer.0.SelfAttention.v.weight\n",
- "decoder.block.3.layer.0.SelfAttention.o.weight\n",
- "decoder.block.3.layer.0.layer_norm.weight\n",
- "decoder.block.3.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.3.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.3.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.3.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.3.layer.1.layer_norm.weight\n",
- "decoder.block.3.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.3.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.3.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.3.layer.2.layer_norm.weight\n",
- "decoder.block.4.layer.0.SelfAttention.q.weight\n",
- "decoder.block.4.layer.0.SelfAttention.k.weight\n",
- "decoder.block.4.layer.0.SelfAttention.v.weight\n",
- "decoder.block.4.layer.0.SelfAttention.o.weight\n",
- "decoder.block.4.layer.0.layer_norm.weight\n",
- "decoder.block.4.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.4.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.4.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.4.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.4.layer.1.layer_norm.weight\n",
- "decoder.block.4.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.4.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.4.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.4.layer.2.layer_norm.weight\n",
- "decoder.block.5.layer.0.SelfAttention.q.weight\n",
- "decoder.block.5.layer.0.SelfAttention.k.weight\n",
- "decoder.block.5.layer.0.SelfAttention.v.weight\n",
- "decoder.block.5.layer.0.SelfAttention.o.weight\n",
- "decoder.block.5.layer.0.layer_norm.weight\n",
- "decoder.block.5.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.5.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.5.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.5.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.5.layer.1.layer_norm.weight\n",
- "decoder.block.5.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.5.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.5.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.5.layer.2.layer_norm.weight\n",
- "decoder.block.6.layer.0.SelfAttention.q.weight\n",
- "decoder.block.6.layer.0.SelfAttention.k.weight\n",
- "decoder.block.6.layer.0.SelfAttention.v.weight\n",
- "decoder.block.6.layer.0.SelfAttention.o.weight\n",
- "decoder.block.6.layer.0.layer_norm.weight\n",
- "decoder.block.6.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.6.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.6.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.6.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.6.layer.1.layer_norm.weight\n",
- "decoder.block.6.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.6.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.6.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.6.layer.2.layer_norm.weight\n",
- "decoder.block.7.layer.0.SelfAttention.q.weight\n",
- "decoder.block.7.layer.0.SelfAttention.k.weight\n",
- "decoder.block.7.layer.0.SelfAttention.v.weight\n",
- "decoder.block.7.layer.0.SelfAttention.o.weight\n",
- "decoder.block.7.layer.0.layer_norm.weight\n",
- "decoder.block.7.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.7.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.7.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.7.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.7.layer.1.layer_norm.weight\n",
- "decoder.block.7.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.7.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.7.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.7.layer.2.layer_norm.weight\n",
- "decoder.block.8.layer.0.SelfAttention.q.weight\n",
- "decoder.block.8.layer.0.SelfAttention.k.weight\n",
- "decoder.block.8.layer.0.SelfAttention.v.weight\n",
- "decoder.block.8.layer.0.SelfAttention.o.weight\n",
- "decoder.block.8.layer.0.layer_norm.weight\n",
- "decoder.block.8.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.8.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.8.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.8.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.8.layer.1.layer_norm.weight\n",
- "decoder.block.8.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.8.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.8.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.8.layer.2.layer_norm.weight\n",
- "decoder.block.9.layer.0.SelfAttention.q.weight\n",
- "decoder.block.9.layer.0.SelfAttention.k.weight\n",
- "decoder.block.9.layer.0.SelfAttention.v.weight\n",
- "decoder.block.9.layer.0.SelfAttention.o.weight\n",
- "decoder.block.9.layer.0.layer_norm.weight\n",
- "decoder.block.9.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.9.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.9.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.9.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.9.layer.1.layer_norm.weight\n",
- "decoder.block.9.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.9.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.9.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.9.layer.2.layer_norm.weight\n",
- "decoder.block.10.layer.0.SelfAttention.q.weight\n",
- "decoder.block.10.layer.0.SelfAttention.k.weight\n",
- "decoder.block.10.layer.0.SelfAttention.v.weight\n",
- "decoder.block.10.layer.0.SelfAttention.o.weight\n",
- "decoder.block.10.layer.0.layer_norm.weight\n",
- "decoder.block.10.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.10.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.10.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.10.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.10.layer.1.layer_norm.weight\n",
- "decoder.block.10.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.10.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.10.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.10.layer.2.layer_norm.weight\n",
- "decoder.block.11.layer.0.SelfAttention.q.weight\n",
- "decoder.block.11.layer.0.SelfAttention.k.weight\n",
- "decoder.block.11.layer.0.SelfAttention.v.weight\n",
- "decoder.block.11.layer.0.SelfAttention.o.weight\n",
- "decoder.block.11.layer.0.layer_norm.weight\n",
- "decoder.block.11.layer.1.EncDecAttention.q.weight\n",
- "decoder.block.11.layer.1.EncDecAttention.k.weight\n",
- "decoder.block.11.layer.1.EncDecAttention.v.weight\n",
- "decoder.block.11.layer.1.EncDecAttention.o.weight\n",
- "decoder.block.11.layer.1.layer_norm.weight\n",
- "decoder.block.11.layer.2.DenseReluDense.wi_0.weight\n",
- "decoder.block.11.layer.2.DenseReluDense.wi_1.weight\n",
- "decoder.block.11.layer.2.DenseReluDense.wo.weight\n",
- "decoder.block.11.layer.2.layer_norm.weight\n",
- "decoder.final_layer_norm.weight\n",
- "lm_head.weight\n"
- ]
- }
- ],
- "source": [
- "for x, y in model.named_parameters():\n",
- " print(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "4a34e1f1-1fc1-4577-a87d-efeac33894b1",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "22d7491179634c75ab8a5c70e9e4188f",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/3 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-63df8ebe4567b55a.arrow\n",
- "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-bb3872c77bcda3cd.arrow\n"
- ]
- }
- ],
- "source": [
- "data_loader = AutoLoad(tokenizer)\n",
- "dataset = data_loader.get_and_map(config.task[0])\n",
- "train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "cf5aea38-4866-4026-b6d4-a8e8b50153b0",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# model(**next(iter(train_loader))).loss.backward()\n",
- "# for i in range(6, 12):\n",
- "# o = model.encoder.block[i].soft_prompt.sadcl_learned_embedding.grad.abs().sum().item()\n",
- "# print(i, o)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6281dbae-3023-4e95-82c9-c9d818c37622",
- "metadata": {},
- "source": [
 -     "# Train the model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "dd92aff9-e4cb-4b1b-aece-0a7eee27e0e4",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n",
- "KeyboardInterrupt\n",
- "\n"
- ]
- }
- ],
- "source": [
- "import wandb\n",
- "wandb.init(\n",
- " # set the wandb project where this run will be logged\n",
- " project=\"my-awesome-project\",\n",
- " # track hyperparameters and run metadata\n",
- " config=config.__dict__\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "74f04c24-2298-4152-abde-c1ee6a0ea739",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 0%| | 0/268 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:08<00:00, 32.56it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 8.963883996009827, 'valid_loss': 6.972279635342685, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 7.36324492141382, 'valid_loss': 5.521347826177424, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.97it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 6.192735992260833, 'valid_loss': 4.384567484711155, 'valid_accuracy': 0.0, 'valid_f1-score-1': 0.0, 'valid_f1-score-ma': 0.0}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.89it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 5.118913385405469, 'valid_loss': 3.335551644816543, 'valid_accuracy': 0.05465004793863854, 'valid_f1-score-1': 0.14091470951792337, 'valid_f1-score-ma': 0.009394313967861558}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.93it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 4.148920764674002, 'valid_loss': 2.2682720783985024, 'valid_accuracy': 0.174496644295302, 'valid_f1-score-1': 0.36804853387259856, 'valid_f1-score-ma': 0.02164991375721168}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.94it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 3.1643025679374808, 'valid_loss': 1.2784492048350247, 'valid_accuracy': 0.5148609779482263, 'valid_f1-score-1': 0.7208053691275169, 'valid_f1-score-ma': 0.05544656685596284}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.87it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 2.235221519843856, 'valid_loss': 0.6245457141688375, 'valid_accuracy': 0.6625119846596357, 'valid_f1-score-1': 0.7915151515151515, 'valid_f1-score-ma': 0.09733333333333333}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.91it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 1.5051592252592543, 'valid_loss': 0.4341738431742697, 'valid_accuracy': 0.6768935762224353, 'valid_f1-score-1': 0.8077147866744595, 'valid_f1-score-ma': 0.0810680363745307}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.99it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 1.0515665002723238, 'valid_loss': 0.3996329452052261, 'valid_accuracy': 0.6826462128475551, 'valid_f1-score-1': 0.8151116199198626, 'valid_f1-score-ma': 0.08151116199198626}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.04it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 0.8272433252032123, 'valid_loss': 0.3832718174565922, 'valid_accuracy': 0.6855225311601151, 'valid_f1-score-1': 0.8162100456621004, 'valid_f1-score-ma': 0.11660143509458577}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 0.7063313028705653, 'valid_loss': 0.36713372216080176, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8175598631698974, 'valid_f1-score-ma': 0.16351197263397949}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 33.92it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 0.6392747724234168, 'valid_loss': 0.36563855906327564, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8170940170940172, 'valid_f1-score-ma': 0.2042735042735043}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [00:07<00:00, 34.00it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'train_loss': 0.5930970842713741, 'valid_loss': 0.3603471038919507, 'valid_accuracy': 0.6874400767018217, 'valid_f1-score-1': 0.8170940170940172, 'valid_f1-score-ma': 0.2042735042735043}\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 40%|█████████████████████████████████████▏ | 106/268 [00:03<00:04, 33.89it/s]\n",
- "\n",
- "KeyboardInterrupt\n",
- "\n"
- ]
- }
- ],
- "source": [
- "optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n",
- "\n",
- "model.to(DEVICE)\n",
- "\n",
- "for epoch in range(config.num_epochs):\n",
- " train_out = train_loop(model=model, loader=train_loader, optimizer=optimizer)\n",
- " valid_out = valid_loop(model=model, loader=valid_loader)\n",
- " wandb.log({\n",
- " **train_out,\n",
- " **valid_out\n",
- " })\n",
- " \n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4368eb4c-fb7b-41bf-89e3-bf20cdfec967",
- "metadata": {},
- "outputs": [],
- "source": [
- "# pip uninstall bitsandbytes -y"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0de2e02c-c5fb-4d13-81fd-5ecb53d42b6c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "dataset['train'].set_format(columns=['label', 'labels'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "687f7994-e875-4f2e-b151-89460bf78eea",
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset['train'][0:100]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5493b857-49e0-4963-b220-7f00422b7511",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "x = load_dataset(\"glue\", \"sst2\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d055f9b6-294e-4f53-9941-21ecc040e92b",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "Counter(x['train']['label'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "abf72943-aa98-4543-beec-942f6f601b89",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "g = x['train']\n",
- "l = g.features['label']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "919fb5a0-fc9c-4fcc-9fe6-21cff8960b51",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "l.int2str(1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b286f811-82bc-4f59-914f-e1c5cd5cd1ef",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "29780 / (29780 + 37569)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "b581a87a-120d-4f7a-a8f4-b39c6a6d1843",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "from types import SimpleNamespace\n",
- "config = SimpleNamespace(\n",
- " model_name='google/t5-base-lm-adapt',\n",
- " peft_params={\n",
- " 'n_tokens': 30,\n",
- " 'n_layers': 6\n",
- " },\n",
- " random_seed=42,\n",
- " task=['glue:cola'],\n",
- " hot_modules=['sadcl'],\n",
- " train_batch_size=32,\n",
- " valid_batch_size=32,\n",
- " balancify_sample=False,\n",
- " learning_rate=0.01,\n",
- " num_epochs=50\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "e7dbb2d9-d545-48e1-a0ac-6d79258a393b",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{\"model_name\": \"google/t5-base-lm-adapt\", \"peft_params\": {\"n_tokens\": 30, \"n_layers\": 6}, \"random_seed\": 42, \"task\": [\"glue:cola\"], \"hot_modules\": [\"sadcl\"], \"train_batch_size\": 32, \"valid_batch_size\": 32, \"balancify_sample\": false, \"learning_rate\": 0.01, \"num_epochs\": 50}\n"
- ]
- }
- ],
- "source": [
- "import json\n",
- "print(json.dumps(config.__dict__))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "d18551c3-68e8-4ee0-8936-65bdec51f4eb",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import T5TokenizerFast, T5ForConditionalGeneration\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "ddac3321-27f4-4b89-aab6-f91ae8bbc86a",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "tokenizer = T5TokenizerFast.from_pretrained(\"google/t5-large-lm-adapt\", model_max_length=2048)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "7456f182-5d9d-44da-b36d-2ec4052fbaf6",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "import numpy as np"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "c476e007-c4fe-4eb3-939c-25d0b0add711",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([23830, 2611, 19567, 10149, 20142, 6737, 26963, 6788, 3871,\n",
- " 28330, 724, 7406, 11474, 18399, 2289, 25511, 25299, 23308,\n",
- " 25412, 370, 32091, 28829, 6148, 29154, 30369, 12979, 8560,\n",
- " 6872, 23228, 8051, 19537, 3741, 22206, 20744, 17051, 27857,\n",
- " 3830, 15329, 21857, 8296, 10768, 7854, 5710, 5405, 27449,\n",
- " 11528, 8599, 12695, 15427, 23726, 389, 3231, 15270, 26906,\n",
- " 23085, 15113, 31792, 8766, 9814, 15904, 6320, 23716, 19682,\n",
- " 2690, 30766, 21262, 11415, 2523, 26538, 3647, 13971, 21655,\n",
- " 287, 19479, 28945, 25134, 17673, 9792, 17556, 31293, 25795,\n",
- " 2753, 8955, 21049, 28409, 24281, 3610, 26070, 2189, 25611,\n",
- " 9641, 23766, 29195, 779, 18660, 10731, 19732, 1664, 2176,\n",
- " 2254])"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.random.randint(0, tokenizer.vocab_size, size=(100,))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "e83e5c34-0860-4e2a-9b2a-016f37b35003",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "0f032c7f-ddd4-4350-81c5-9296c8376d8c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "w = torch.load('best.pt')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "112c06a2-1295-4989-9ce6-5f6204e809ef",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[ 0.5470, -0.8095, -1.4617, ..., 0.8100, -1.1746, 0.5768],\n",
- " [-0.9284, -0.6230, -2.4697, ..., 0.3947, -0.5427, -0.3088],\n",
- " [ 1.4407, 0.8760, 0.2499, ..., 0.1860, -0.3176, 2.0041],\n",
- " ...,\n",
- " [ 0.8714, 1.1013, -2.7711, ..., -0.2819, 0.7087, -0.6164],\n",
- " [ 0.8026, -0.7928, -0.8946, ..., -1.5204, 1.0164, -1.3527],\n",
- " [ 0.4650, -2.1778, 0.0213, ..., -1.1430, -2.3895, -0.0235]],\n",
- " device='cuda:0')"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "w.pop('sadcl_learned_embedding')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "4055f80b-ea2f-44db-bf32-98ea3ffe9597",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "OrderedDict([('sadcl_mlp.0.weight',\n",
- " tensor([[ 0.1171, -0.7743, 0.5095, ..., -1.0615, 1.5754, 0.7036],\n",
- " [-0.2675, 0.0969, 0.0543, ..., 0.7276, -0.0671, 0.8296],\n",
- " [-0.2987, -0.0700, -1.0519, ..., 0.6090, 0.0193, 0.0410],\n",
- " ...,\n",
- " [-0.1463, -0.8924, 0.7947, ..., 0.2265, -0.6957, 0.5928],\n",
- " [-0.4365, -0.9251, -1.0378, ..., -0.8628, -0.5243, 0.0860],\n",
- " [ 0.4860, 0.0648, -0.9160, ..., -0.5342, 0.1072, -0.1397]],\n",
- " device='cuda:0')),\n",
- " ('sadcl_mlp.0.bias',\n",
- " tensor([-0.6311, -1.0433, -1.0390, -1.6997, -1.0766, -0.2802, -0.9433, -0.7127,\n",
- " 0.5315, -1.0400, -0.3756, -0.2602, -0.7607, 0.7578, -0.7066, -0.3561,\n",
- " -0.5580, -0.7671, -0.2557, -1.6528, -0.1438, -0.4875, -0.6291, -1.2763,\n",
- " -0.2484, -0.6396, -0.7225, -0.8314, -1.3913, -0.7696, 0.0864, -0.7268,\n",
- " -0.7812, -1.0606, -0.9011, 0.3322, 0.5159, -0.4453, -0.6409, 0.0714,\n",
- " -0.2788, -0.1620, -0.9408, 0.1440, -0.8897, -0.9288, -1.2605, -1.2384,\n",
- " -0.0090, -0.0661, -0.5203, -1.5729, -0.5143, -0.4943, -0.9472, -0.8107,\n",
- " -0.5748, -1.1438, -0.8919, -0.8606, -1.0831, -1.4380, -1.0802, 0.0522,\n",
- " 0.0785, -2.4277, -1.0447, -0.3124, 0.1173, -0.8195, -0.0623, -0.1913,\n",
- " -1.4551, -0.0732, -1.1574, -0.2217, -0.6697, -0.5846, -0.2473, -0.0144,\n",
- " -1.2317, -0.5024, -0.2301, 0.2265, -0.6478, -0.8726, -0.8367, -0.0312,\n",
- " -0.4783, -0.3132, -0.6115, -1.5002, -0.6820, -0.9731, -0.6438, -0.8716,\n",
- " -0.2628, -0.8308, -0.8588, 0.8616, -0.3398, 0.2025, -0.6247, -0.4494,\n",
- " -1.2737, -0.9406, -0.5297, -0.4886, -1.6481, -2.5021, -0.1344, -0.8274,\n",
- " -1.6135, -0.9598, -0.8659, -1.3385, -1.4567, -1.0869, -0.1999, -1.3751,\n",
- " -0.4536, -1.0839, -1.0037, -0.0429, -0.5243, -0.8836, -0.9716, -1.1037],\n",
- " device='cuda:0')),\n",
- " ('sadcl_mlp.2.weight',\n",
- " tensor([[ 0.3282, 0.2861, 0.4277, ..., -1.1185, 0.3197, 0.6003],\n",
- " [-0.9305, 0.1462, -0.4269, ..., -0.1129, -0.7909, -0.6872],\n",
- " [ 0.0067, -0.7521, -1.6837, ..., -0.2374, -0.2790, -0.9895],\n",
- " ...,\n",
- " [-1.8292, 0.9060, 1.3090, ..., -0.0273, -1.0552, -0.2187],\n",
- " [-0.3804, 0.0945, 0.0337, ..., -1.6941, 0.0693, -0.0288],\n",
- " [-1.3038, 0.2590, -0.2965, ..., 0.9425, -0.0090, -1.2449]],\n",
- " device='cuda:0')),\n",
- " ('sadcl_mlp.3.weight',\n",
- " tensor([2.4987, 2.5515, 2.0518, ..., 1.9383, 1.2583, 1.1634], device='cuda:0')),\n",
- " ('sadcl_mlp.3.bias',\n",
- " tensor([ 1.0355, -0.7098, -1.4075, ..., -0.3350, -0.7165, -0.9371],\n",
- " device='cuda:0'))])"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "w"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fa50b148-1f34-41c9-b7f4-c26c3e4cbce6",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:deep]",
- "language": "python",
- "name": "conda-env-deep-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
|