@@ -0,0 +1,6 @@ | |||
datasets/ | |||
__pycache__ | |||
.ipynb_checkpoints | |||
wandb | |||
lab/ | |||
@@ -0,0 +1,115 @@ | |||
from typing import Optional | |||
import numpy as np | |||
from tqdm import tqdm | |||
import wandb | |||
import torch | |||
import torch.nn as nn | |||
from transformers import T5TokenizerFast, T5ForConditionalGeneration | |||
from _config import load_config | |||
from _utils import print_system_info, silent_logs | |||
from _datasets import AutoLoad, generate_dataloader | |||
from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper | |||
from _trainer import train_loop, valid_loop, BestFinder | |||
configs = load_config('./config.yaml') | |||
RANDOM_SEED = configs.shared.random_seed | |||
WANDB_PROJECT_NAME = configs.shared.project_name | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
USE_TQDM = configs.shared.use_tqdm | |||
def run_experiment(config):
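    """Run one experiment described by `config`: load the T5 backbone,
    optionally attach a PEFT delta module, freeze all but the hot modules,
    then train, validate, and log each epoch to wandb."""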
np.random.seed(RANDOM_SEED) | |||
# ______________________LOAD MODEL_____________________________ | |||
model = T5ForConditionalGeneration.from_pretrained(config.model_name) | |||
tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048) | |||
# ______________________MUTATE MODEL_____________________________ | |||
if config.peft_params is not None: | |||
peft_params = config.peft_params.to_dict() | |||
        # initialize the soft prompt from randomly sampled vocabulary token ids
        selected_tokens = torch.from_numpy(
            np.random.randint(0, tokenizer.vocab_size, size=(peft_params['n_tokens'],))
        )
peft_class = { | |||
't5_encoder': T5Wrapper, | |||
'encoder_emb': EmbeddingWrapper | |||
}[peft_params.pop('kind')] | |||
delta_module = peft_class.mutate( | |||
model=model, | |||
            slected_tokens=selected_tokens,  # keyword spelling matches _mydelta's signature
**peft_params | |||
) | |||
elif config.best_finder.save: | |||
        raise NotImplementedError('best_finder.save requires a PEFT delta module to snapshot')
freeze_notes = auto_freeze(model, config.hot_modules) | |||
# ______________________LOAD DATA_____________________________ | |||
data_loader = AutoLoad(tokenizer) | |||
dataset = data_loader.get_and_map(config.tasks[0]) | |||
train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config) | |||
# ______________________TRAIN_____________________________ | |||
wandb.init( | |||
name=config.wandb_name, | |||
project=WANDB_PROJECT_NAME, | |||
config=config.to_dict(), | |||
notes=freeze_notes | |||
) | |||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) | |||
best_finder = BestFinder(config.best_finder.higher_better) | |||
model.to(DEVICE) | |||
epochs_range = range(config.num_epochs) | |||
if USE_TQDM: | |||
epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False) | |||
for epoch in epochs_range: | |||
epoch_results = {} | |||
epoch_results.update( | |||
train_loop( | |||
model=model, | |||
loader=train_loader, | |||
optimizer=optimizer, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
epoch_results.update( | |||
valid_loop( | |||
model=model, | |||
loader=valid_loader, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
if config.best_finder.save: | |||
if best_finder.is_better(epoch_results[config.best_finder.metric]): | |||
torch.save(delta_module.peft_state_dict(), './best.pt') | |||
wandb.log(epoch_results) | |||
wandb.finish() | |||
if __name__ == '__main__': | |||
print_system_info() | |||
silent_logs() | |||
run_configs = configs.run_configs | |||
if USE_TQDM: | |||
run_configs = tqdm(run_configs, position=0, desc="Experiment") | |||
for run_config in run_configs: | |||
        run_experiment(run_config)
@@ -0,0 +1,111 @@ | |||
from typing import Optional | |||
import numpy as np | |||
from tqdm import tqdm | |||
import wandb | |||
import torch | |||
import torch.nn as nn | |||
from transformers import T5TokenizerFast, T5ForConditionalGeneration | |||
from _config import load_config | |||
from _utils import print_system_info, silent_logs | |||
from _datasets import AutoLoad, generate_dataloader | |||
from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper | |||
from _trainer import train_loop, valid_loop | |||
configs = load_config('./config.yaml') | |||
RANDOM_SEED = configs.shared.random_seed | |||
WANDB_PROJECT_NAME = configs.shared.project_name | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
USE_TQDM = configs.shared.use_tqdm | |||
def run_experiment(config):
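    """Variant of the training entry point that warm-starts the PEFT module
    from ./best.pt (minus the learned prompt embedding) before training."""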
np.random.seed(RANDOM_SEED) | |||
# ______________________LOAD MODEL_____________________________ | |||
model = T5ForConditionalGeneration.from_pretrained(config.model_name) | |||
tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048) | |||
# ______________________MUTATE MODEL_____________________________ | |||
if config.peft_params is not None: | |||
peft_params = config.peft_params.to_dict() | |||
        # initialize the soft prompt from randomly sampled vocabulary token ids
        selected_tokens = torch.from_numpy(
            np.random.randint(0, tokenizer.vocab_size, size=(peft_params['n_tokens'],))
        )
peft_class = { | |||
't5_encoder': T5Wrapper, | |||
'encoder_emb': EmbeddingWrapper | |||
}[peft_params.pop('kind')] | |||
delta_module = peft_class.mutate( | |||
model=model, | |||
            slected_tokens=selected_tokens,  # keyword spelling matches _mydelta's signature
**peft_params | |||
) | |||
        # warm-start from the saved PEFT weights, dropping the learned prompt
        # embedding so only the remaining PEFT parameters are restored
        loaded_weights = torch.load('./best.pt')
        loaded_weights.pop('sadcl_learned_embedding')
        delta_module.load_peft_state_dict(loaded_weights)
freeze_notes = auto_freeze(model, config.hot_modules) | |||
# ______________________LOAD DATA_____________________________ | |||
data_loader = AutoLoad(tokenizer) | |||
dataset = data_loader.get_and_map(config.tasks[0]) | |||
train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config) | |||
# ______________________TRAIN_____________________________ | |||
wandb.init( | |||
name=config.wandb_name, | |||
project=WANDB_PROJECT_NAME, | |||
config=config.to_dict(), | |||
notes=freeze_notes | |||
) | |||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) | |||
model.to(DEVICE) | |||
epochs_range = range(config.num_epochs) | |||
if USE_TQDM: | |||
epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False) | |||
for epoch in epochs_range: | |||
epoch_results = {} | |||
epoch_results.update( | |||
train_loop( | |||
model=model, | |||
loader=train_loader, | |||
optimizer=optimizer, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
epoch_results.update( | |||
valid_loop( | |||
model=model, | |||
loader=valid_loader, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
wandb.log(epoch_results) | |||
wandb.finish() | |||
if __name__ == '__main__': | |||
print_system_info() | |||
silent_logs() | |||
run_configs = configs.run_configs | |||
if USE_TQDM: | |||
run_configs = tqdm(run_configs, position=0, desc="Experiment") | |||
for run_config in run_configs: | |||
        run_experiment(run_config)
@@ -0,0 +1,105 @@ | |||
shared: | |||
project_name: continual_prompt_pretrained_mlp | |||
use_tqdm: true | |||
random_seed: 42 | |||
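# `&default` is a YAML anchor; each run config below merges it via `<<: *default`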
default: &default | |||
model_name: google/t5-large-lm-adapt | |||
wandb_name: null | |||
train_batch_size: 32 | |||
valid_batch_size: 32 | |||
num_epochs: 100 | |||
peft_params: null # no mutation | |||
hot_modules: null # fine-tune all | |||
balancify_train: false | |||
best_finder: | |||
save: true | |||
metric: valid_f1-score-ma | |||
higher_better: true | |||
tasks: | |||
- glue:cola | |||
run_configs: | |||
# - <<: *default | |||
# wandb_name: large_5t_mlp128 | |||
# learning_rate: 0.02 | |||
# hot_modules: | |||
# - sadcl_learned_embeddin | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 5 | |||
# mlp_emb: 128 | |||
# - <<: *default | |||
# wandb_name: large_10t_mlp128 | |||
# learning_rate: 0.02 | |||
# hot_modules: | |||
# - sadcl_learned_embeddin | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 10 | |||
# mlp_emb: 128 | |||
# - <<: *default | |||
# wandb_name: large_5t_mlp128_not_freeze | |||
# learning_rate: 0.02 | |||
# hot_modules: | |||
# - sadcl | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 5 | |||
# mlp_emb: 128 | |||
# - <<: *default | |||
# wandb_name: large_10t_mlp128_not_freeze | |||
# learning_rate: 0.02 | |||
# hot_modules: | |||
# - sadcl | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 10 | |||
# mlp_emb: 128 | |||
# - <<: *default | |||
# wandb_name: large_5t_mlp128_not_freeze_lowlr | |||
# learning_rate: 0.001 | |||
# hot_modules: | |||
# - sadcl | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 5 | |||
# mlp_emb: 128 | |||
# - <<: *default | |||
# wandb_name: large_10t_mlp128_not_freeze_lowlr | |||
# learning_rate: 0.001 | |||
# hot_modules: | |||
# - sadcl | |||
# train_batch_size: 24 | |||
# valid_batch_size: 24 | |||
# peft_params: | |||
# kind: encoder_emb | |||
# n_tokens: 10 | |||
# mlp_emb: 128 | |||
- <<: *default | |||
wandb_name: large_100t_mlp128_lr.02 | |||
learning_rate: 0.02 | |||
hot_modules: | |||
      - sadcl_learned_embedding
train_batch_size: 24 | |||
valid_batch_size: 24 | |||
peft_params: | |||
kind: encoder_emb | |||
n_tokens: 100 | |||
mlp_emb: 128 |
@@ -0,0 +1,417 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"id": "a50443d6-fe09-4905-b913-1be5f88c8c03", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"import numpy as np\n", | |||
"from tqdm import tqdm\n", | |||
"from sklearn.model_selection import train_test_split\n", | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"from transformers import T5Model" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"id": "4e677034-dc27-4939-8ea2-71fcbb2da57d", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"np_rng = np.random.default_rng(seed=42)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"id": "3d139e0a-b8e3-427b-a537-44bc0f14ba46", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"array([[ 0.09141512, -0.31199523],\n", | |||
" [ 0.22513536, 0.28216941],\n", | |||
" [-0.58531056, -0.39065385],\n", | |||
" [ 0.03835212, -0.09487278],\n", | |||
" [-0.00504035, -0.25591318],\n", | |||
" [ 0.26381939, 0.23333758],\n", | |||
" [ 0.01980921, 0.33817236],\n", | |||
" [ 0.1402528 , -0.25778774],\n", | |||
" [ 0.11062524, -0.28766478],\n", | |||
" [ 0.26353509, -0.01497777],\n", | |||
" [-0.05545871, -0.20427886],\n", | |||
" [ 0.3667624 , -0.04635884],\n", | |||
" [-0.12849835, -0.10564007],\n", | |||
" [ 0.15969276, 0.10963322],\n", | |||
" [ 0.12381978, 0.1292463 ],\n", | |||
" [ 0.64249428, -0.1219245 ],\n", | |||
" [-0.15367282, -0.24413182],\n", | |||
" [ 0.18479383, 0.33869169],\n", | |||
" [-0.03418424, -0.25204694],\n", | |||
" [-0.24734436, 0.19517784],\n", | |||
" [ 0.22297625, 0.16294628],\n", | |||
" [-0.19965291, 0.0696484 ],\n", | |||
" [ 0.03500574, 0.06560658],\n", | |||
" [ 0.26142863, 0.06707866],\n", | |||
" [ 0.20367407, 0.02027372],\n", | |||
" [ 0.08673582, 0.18938647],\n", | |||
" [-0.43714675, -0.09590136],\n", | |||
" [-0.1411118 , -0.19166335],\n", | |||
" [-0.08254268, 0.44848239],\n", | |||
" [-0.25974933, 0.29048351],\n", | |||
" [-0.50486093, -0.10046551],\n", | |||
" [ 0.04882592, 0.1758667 ]])" | |||
] | |||
}, | |||
"execution_count": 8, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"np_rng.normal(loc=0, scale=0.3, size=(32, 2))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"id": "544207bc-37fc-4376-9c63-bff44c72b32f", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"# BOTTLENECK_SIZE = 128\n", | |||
"TRAIN_BATCH_SIZE = 8192\n", | |||
"VALID_BATCH_SIZE = 8192\n", | |||
"RANDOM_SEED = 42\n", | |||
"\n", | |||
"DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"id": "37d2d256-a348-402b-999d-1a4edce360c5", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"def train_valid_test_split(total_range, random_seed=RANDOM_SEED):\n", | |||
" train, testvalid = train_test_split(total_range, random_state=RANDOM_SEED, test_size=0.2)\n", | |||
" test, valid = train_test_split(testvalid, random_state=RANDOM_SEED, test_size=0.5)\n", | |||
" return train, valid, test\n", | |||
"\n", | |||
"def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED):\n", | |||
" np_rng = np.random.default_rng(seed=random_seed)\n", | |||
" while True:\n", | |||
" word_ids = np_rng.choice(words_ids, size=(batch_size, 2))\n", | |||
" additive_noise = np_rng.normal(loc=0, scale=0.1, size=(batch_size, emb_dim))\n", | |||
" alpha = np_rng.uniform(size=(batch_size, 1))\n", | |||
" yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)\n", | |||
" \n", | |||
"class FakeEpoch:\n", | |||
" def __init__(self, dataloader, each_epoch_size):\n", | |||
" self.dataloader_iter = iter(dataloader)\n", | |||
" self.each_epoch_size = each_epoch_size\n", | |||
" \n", | |||
" def __len__(self):\n", | |||
" return self.each_epoch_size\n", | |||
" \n", | |||
" def __iter__(self):\n", | |||
" for _ in range(self.each_epoch_size):\n", | |||
" yield next(self.dataloader_iter)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"id": "644ae479-3f9a-426a-bd0b-4ec7694bc675", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"\n", | |||
"def ez_freeze(module):\n", | |||
" for param in module.parameters():\n", | |||
" param.requires_grad = False\n", | |||
" \n", | |||
"def ez_mlp(linear_dims, last_layer_bias=False):\n", | |||
" layers = []\n", | |||
" pairs_count = len(linear_dims) - 1\n", | |||
" for idx in range(pairs_count):\n", | |||
" in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]\n", | |||
" if idx == pairs_count - 1:\n", | |||
" layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))\n", | |||
" else:\n", | |||
" layers.append(nn.Linear(in_dim, out_dim, bias=True))\n", | |||
" layers.append(nn.ReLU())\n", | |||
" return nn.Sequential(*layers)\n", | |||
"\n", | |||
"def auto_encoder_model(linear_dims):\n", | |||
" return nn.Sequential(\n", | |||
" ez_mlp(linear_dims, last_layer_bias=False),\n", | |||
" nn.LayerNorm(linear_dims[-1]),\n", | |||
" ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)\n", | |||
" )\n", | |||
"\n", | |||
"class AutoEncoderModel(nn.Module):\n", | |||
" def __init__(self, pretrained_name, bottleneck_sizes):\n", | |||
" super().__init__()\n", | |||
" \n", | |||
" self.bottleneck_size = bottleneck_sizes\n", | |||
" \n", | |||
" model = T5Model.from_pretrained(pretrained_name)\n", | |||
" self.emb_layer = model.get_encoder().get_input_embeddings()\n", | |||
" ez_freeze(self.emb_layer)\n", | |||
" \n", | |||
" self.auto_encoder = auto_encoder_model([\n", | |||
" self.embedding_dim,\n", | |||
" *bottleneck_sizes\n", | |||
" ])\n", | |||
" \n", | |||
" self.loss_fn = nn.MSELoss()\n", | |||
" \n", | |||
" def forward(self, word_ids, additive_noise, alpha):\n", | |||
" # word_ids.shape = (batch_size, 2)\n", | |||
" # additive_noise.shape = (batch_size, embedding_dim)\n", | |||
" # alpha.shape = (batch_size, 1)\n", | |||
" \n", | |||
" word_embs = self.emb_layer(word_ids)\n", | |||
" # word_embs.shape = (batch_size, 2, embedding_dim)\n", | |||
" \n", | |||
" word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)\n", | |||
" # word_combs.shape = (batch_size, embedding_dim)\n", | |||
" \n", | |||
" y_hat = self.auto_encoder(word_combs + additive_noise)\n", | |||
" loss = self.loss_fn(word_combs, y_hat)\n", | |||
" return loss, y_hat\n", | |||
" \n", | |||
" @property\n", | |||
" def embedding_dim(self):\n", | |||
" return self.emb_layer.embedding_dim\n", | |||
" \n", | |||
" @property\n", | |||
" def num_embeddings(self):\n", | |||
" return self.emb_layer.num_embeddings " | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"id": "aba28049-20bf-4ae6-9445-2f7c294686d8", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 768, 512, 512, 256, 256, 128, 128])" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 16, | |||
"id": "cac6bc39-ba12-4052-bd5f-8834f57cfa15", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"tensor(96.9082)" | |||
] | |||
}, | |||
"execution_count": 16, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"(model.emb_layer.weight**2).mean()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"id": "afe2efbf-e703-4c43-8f7b-a87d303ea89e", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))\n", | |||
"train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)\n", | |||
"valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"id": "c24ccc1c-4cbe-4373-871e-9090dceb69a1", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"train_loader = FakeEpoch(train_loader, 1000)\n", | |||
"valid_loader = FakeEpoch(valid_loader, 100)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"id": "71936e43-d718-45ef-8115-7fc63999ebd9", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"def _prefix_dict_keys(prefix, input_dict):\n", | |||
" return {f'{prefix}_{key}': val for key, val in input_dict.items()}\n", | |||
"\n", | |||
"def train_loop(model, loader, optimizer, use_tqdm=False):\n", | |||
" model.train()\n", | |||
"\n", | |||
" batch_losses = []\n", | |||
" \n", | |||
" if use_tqdm:\n", | |||
" loader = tqdm(loader, position=2, desc=\"Train Loop\", leave=False)\n", | |||
" \n", | |||
" for row in loader:\n", | |||
" optimizer.zero_grad()\n", | |||
" \n", | |||
" out = model(*(item.to(DEVICE) for item in row))\n", | |||
" loss = out[0]\n", | |||
" \n", | |||
" batch_loss_value = loss.item()\n", | |||
" loss.backward()\n", | |||
" optimizer.step()\n", | |||
" \n", | |||
" batch_losses.append(batch_loss_value)\n", | |||
" \n", | |||
" loss_value = np.mean(batch_losses)\n", | |||
" return _prefix_dict_keys('train', {\n", | |||
" 'loss': loss_value\n", | |||
" })\n", | |||
"\n", | |||
"def valid_loop(model, loader, use_tqdm=False):\n", | |||
" model.eval()\n", | |||
"\n", | |||
" batch_losses = []\n", | |||
" \n", | |||
" all_true = []\n", | |||
" all_pred = []\n", | |||
" \n", | |||
" if use_tqdm:\n", | |||
" loader = tqdm(loader, position=2, desc=\"Valid Loop\", leave=False)\n", | |||
" \n", | |||
" with torch.no_grad():\n", | |||
" for row in loader:\n", | |||
" out = model(*(item.to(DEVICE) for item in row))\n", | |||
" loss = out[0]\n", | |||
" \n", | |||
" batch_loss_value = loss.item()\n", | |||
"\n", | |||
" batch_losses.append(batch_loss_value)\n", | |||
"\n", | |||
" loss_value = np.mean(batch_losses)\n", | |||
" \n", | |||
" return_value = {\n", | |||
" 'loss': loss_value,\n", | |||
" }\n", | |||
" \n", | |||
" return _prefix_dict_keys('valid', return_value)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"id": "082b5384-827f-48b3-aa8e-40483668bbc0", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"ename": "KeyboardInterrupt", | |||
"evalue": "", | |||
"output_type": "error", | |||
"traceback": [ | |||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |||
"Cell \u001b[0;32mIn[9], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1000\u001b[39m):\n\u001b[1;32m 5\u001b[0m epoch_results \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 7\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[0;32m----> 8\u001b[0m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 16\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 17\u001b[0m valid_loop(\n\u001b[1;32m 18\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 21\u001b[0m )\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mprint\u001b[39m(epoch_results)\n", | |||
"Cell \u001b[0;32mIn[8], line 12\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(model, loader, optimizer, use_tqdm)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm:\n\u001b[1;32m 10\u001b[0m loader \u001b[38;5;241m=\u001b[39m tqdm(loader, position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrain Loop\u001b[39m\u001b[38;5;124m\"\u001b[39m, leave\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m loader:\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(\u001b[38;5;241m*\u001b[39m(item\u001b[38;5;241m.\u001b[39mto(DEVICE) \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m row))\n", | |||
"Cell \u001b[0;32mIn[3], line 24\u001b[0m, in \u001b[0;36mFakeEpoch.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meach_epoch_size):\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n", | |||
"Cell \u001b[0;32mIn[3], line 10\u001b[0m, in \u001b[0;36mcustom_dataloader\u001b[0;34m(words_ids, batch_size, emb_dim, random_seed)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m word_ids \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39mchoice(words_ids, size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m2\u001b[39m))\n\u001b[0;32m---> 10\u001b[0m additive_noise \u001b[38;5;241m=\u001b[39m \u001b[43mnp_rng\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43memb_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m alpha \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39muniform(size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(word_ids), torch\u001b[38;5;241m.\u001b[39mTensor(additive_noise), torch\u001b[38;5;241m.\u001b[39mTensor(alpha)\n", | |||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |||
] | |||
} | |||
], | |||
"source": [ | |||
"model.to(DEVICE)\n", | |||
"optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)\n", | |||
"\n", | |||
"for epoch in range(1000):\n", | |||
" epoch_results = {}\n", | |||
"\n", | |||
" epoch_results.update(\n", | |||
" train_loop(\n", | |||
" model=model,\n", | |||
" loader=train_loader,\n", | |||
" optimizer=optimizer,\n", | |||
" use_tqdm=False\n", | |||
" )\n", | |||
" )\n", | |||
"\n", | |||
" epoch_results.update(\n", | |||
" valid_loop(\n", | |||
" model=model,\n", | |||
" loader=valid_loader,\n", | |||
" use_tqdm=False\n", | |||
" )\n", | |||
" )\n", | |||
" print(epoch_results)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "53425637-6146-41d2-b59e-4617ae1f8521", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python [conda env:deep]", | |||
"language": "python", | |||
"name": "conda-env-deep-py" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.10.11" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 5 | |||
} |
@@ -0,0 +1,244 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# In[1]: | |||
import numpy as np | |||
from tqdm import tqdm | |||
from sklearn.model_selection import train_test_split | |||
import torch | |||
import torch.nn as nn | |||
from transformers import T5Model | |||
# In[2]: | |||
# BOTTLENECK_SIZE = 128 | |||
TRAIN_BATCH_SIZE = 64 | |||
VALID_BATCH_SIZE = 64 | |||
NOISE_SCALE = 0.5 | |||
RANDOM_SEED = 42 | |||
SEED_SHIFT = 0 | |||
DROP_OUT = 0.5 | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
# In[3]: | |||
def train_valid_test_split(total_range, random_seed=RANDOM_SEED): | |||
train, testvalid = train_test_split(total_range, random_state=random_seed, test_size=0.2) | |||
test, valid = train_test_split(testvalid, random_state=random_seed, test_size=0.5) | |||
return train, valid, test | |||
def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED+SEED_SHIFT): | |||
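    # infinite generator: each batch pairs random word ids with Gaussian noise
    # and uniform mixing weights (alpha) for interpolating the two embeddings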
np_rng = np.random.default_rng(seed=random_seed) | |||
while True: | |||
word_ids = np_rng.choice(words_ids, size=(batch_size, 2)) | |||
additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim)) | |||
alpha = np_rng.uniform(size=(batch_size, 1)) | |||
yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha) | |||
class FakeEpoch: | |||
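    """Present an infinite batch generator as a fixed-length epoch:
    len() and iteration both stop after `each_epoch_size` batches."""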
def __init__(self, dataloader, each_epoch_size): | |||
self.dataloader_iter = iter(dataloader) | |||
self.each_epoch_size = each_epoch_size | |||
def __len__(self): | |||
return self.each_epoch_size | |||
def __iter__(self): | |||
for _ in range(self.each_epoch_size): | |||
yield next(self.dataloader_iter) | |||
# In[4]: | |||
def ez_freeze(module): | |||
for param in module.parameters(): | |||
param.requires_grad = False | |||
def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None): | |||
layers = [] | |||
pairs_count = len(linear_dims) - 1 | |||
for idx in range(pairs_count): | |||
in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1] | |||
        if idx == pairs_count - 1:
            layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))
else: | |||
layers.append(nn.Linear(in_dim, out_dim, bias=True)) | |||
layers.append(nn.ReLU()) | |||
if drop_out is not None: | |||
layers.append(nn.Dropout(drop_out)) | |||
return nn.Sequential(*layers) | |||
def auto_encoder_model(linear_dims): | |||
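    # encoder MLP down to the bottleneck, ReLU+Dropout in the middle,
    # then a mirrored decoder MLP back up to the embedding dimension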
return nn.Sequential( | |||
ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT), | |||
nn.ReLU(), | |||
nn.Dropout(0.5), | |||
# nn.LayerNorm(linear_dims[-1]), | |||
ez_mlp(list(reversed(linear_dims)), last_layer_bias=True) | |||
) | |||
class AutoEncoderModel(nn.Module): | |||
def __init__(self, pretrained_name, bottleneck_sizes): | |||
super().__init__() | |||
self.bottleneck_size = bottleneck_sizes | |||
model = T5Model.from_pretrained(pretrained_name) | |||
self.emb_layer = model.get_encoder().get_input_embeddings() | |||
ez_freeze(self.emb_layer) | |||
self.auto_encoder = auto_encoder_model([ | |||
self.embedding_dim, | |||
*bottleneck_sizes | |||
]) | |||
self.loss_fn = nn.MSELoss() | |||
def forward(self, word_ids, additive_noise, alpha): | |||
# word_ids.shape = (batch_size, 2) | |||
# additive_noise.shape = (batch_size, embedding_dim) | |||
# alpha.shape = (batch_size, 1) | |||
word_embs = self.emb_layer(word_ids) | |||
# word_embs.shape = (batch_size, 2, embedding_dim) | |||
word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha) | |||
# word_combs.shape = (batch_size, embedding_dim) | |||
y_hat = self.auto_encoder(word_combs + additive_noise) | |||
loss = self.loss_fn(word_combs, y_hat) | |||
return loss, y_hat | |||
@property | |||
def embedding_dim(self): | |||
return self.emb_layer.embedding_dim | |||
@property | |||
def num_embeddings(self): | |||
return self.emb_layer.num_embeddings | |||
# In[5]: | |||
model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[4096]) | |||
print(model) | |||
# In[6]: | |||
train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings)) | |||
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim) | |||
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim) | |||
# In[7]: | |||
train_loader = FakeEpoch(train_loader, 2000) | |||
valid_loader = FakeEpoch(valid_loader, 100) | |||
# In[8]: | |||
def _prefix_dict_keys(prefix, input_dict): | |||
return {f'{prefix}_{key}': val for key, val in input_dict.items()} | |||
def train_loop(model, loader, optimizer, use_tqdm=False): | |||
model.train() | |||
batch_losses = [] | |||
if use_tqdm: | |||
loader = tqdm(loader, position=2, desc="Train Loop", leave=False) | |||
for row in loader: | |||
optimizer.zero_grad() | |||
out = model(*(item.to(DEVICE) for item in row)) | |||
loss = out[0] | |||
batch_loss_value = loss.item() | |||
loss.backward() | |||
optimizer.step() | |||
batch_losses.append(batch_loss_value) | |||
loss_value = np.mean(batch_losses) | |||
return _prefix_dict_keys('train', { | |||
'loss': loss_value | |||
}) | |||
def valid_loop(model, loader, use_tqdm=False): | |||
model.eval() | |||
batch_losses = [] | |||
if use_tqdm: | |||
loader = tqdm(loader, position=2, desc="Valid Loop", leave=False) | |||
with torch.no_grad(): | |||
for row in loader: | |||
out = model(*(item.to(DEVICE) for item in row)) | |||
loss = out[0] | |||
batch_loss_value = loss.item() | |||
batch_losses.append(batch_loss_value) | |||
loss_value = np.mean(batch_losses) | |||
return_value = { | |||
'loss': loss_value, | |||
} | |||
return _prefix_dict_keys('valid', return_value) | |||
# In[9]: | |||
model.to(DEVICE) | |||
# model.load_state_dict(torch.load('./ae_file/snap_72.pt')) | |||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) # was 0.001 | |||
for epoch in tqdm(range(1000), position=1): | |||
epoch_results = {} | |||
epoch_results.update( | |||
train_loop( | |||
model=model, | |||
loader=train_loader, | |||
optimizer=optimizer, | |||
use_tqdm=True | |||
) | |||
) | |||
epoch_results.update( | |||
valid_loop( | |||
model=model, | |||
loader=valid_loader, | |||
use_tqdm=True | |||
) | |||
) | |||
torch.save(model.state_dict(), f'/disks/ssd/ae_file4/snap_{epoch}.pt') | |||
print(epoch_results) | |||
# In[ ]: | |||
@@ -0,0 +1,254 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# In[1]: | |||
import numpy as np | |||
from tqdm import tqdm | |||
from sklearn.model_selection import train_test_split | |||
import torch | |||
import torch.nn as nn | |||
from transformers import T5Model | |||
# In[2]: | |||
# BOTTLENECK_SIZE = 128 | |||
TRAIN_BATCH_SIZE = 8192 | |||
VALID_BATCH_SIZE = 8192 | |||
NOISE_SCALE = 1 | |||
RANDOM_SEED = 42 | |||
SEED_SHIFT = 0 | |||
DROP_OUT = 0.2 | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
# In[3]: | |||
def train_valid_test_split(total_range, random_seed=RANDOM_SEED): | |||
train, testvalid = train_test_split(total_range, random_state=random_seed, test_size=0.2) | |||
test, valid = train_test_split(testvalid, random_state=random_seed, test_size=0.5) | |||
return train, valid, test | |||
def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED+SEED_SHIFT): | |||
np_rng = np.random.default_rng(seed=random_seed) | |||
while True: | |||
word_ids = np_rng.choice(words_ids, size=(batch_size, 2)) | |||
additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim)) | |||
alpha = np_rng.uniform(size=(batch_size, 1)) | |||
yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha) | |||
class FakeEpoch: | |||
def __init__(self, dataloader, each_epoch_size): | |||
self.dataloader_iter = iter(dataloader) | |||
self.each_epoch_size = each_epoch_size | |||
def __len__(self): | |||
return self.each_epoch_size | |||
def __iter__(self): | |||
for _ in range(self.each_epoch_size): | |||
yield next(self.dataloader_iter) | |||
# In[4]: | |||
def ez_freeze(module): | |||
for param in module.parameters(): | |||
param.requires_grad = False | |||
class ResLinear(nn.Module): | |||
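    """Linear -> ReLU -> Linear -> ReLU block that returns the sum of the two
    activations, giving a residual-style skip over the second layer."""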
def __init__(self, in_dim, out_dim): | |||
super().__init__() | |||
self.linear1 = nn.Linear(in_dim, out_dim) | |||
self.linear2 = nn.Linear(out_dim, out_dim) | |||
def forward(self, x): | |||
out1 = nn.functional.relu(self.linear1(x)) | |||
out2 = nn.functional.relu(self.linear2(out1)) | |||
return out1 + out2 | |||
def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None): | |||
layers = [] | |||
pairs_count = len(linear_dims) - 1 | |||
for idx in range(pairs_count): | |||
in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1] | |||
if idx == pairs_count - 1: | |||
layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias)) | |||
else: | |||
layers.append(ResLinear(in_dim, out_dim)) | |||
if drop_out is not None: | |||
layers.append(nn.Dropout(drop_out)) | |||
return nn.Sequential(*layers) | |||
def auto_encoder_model(linear_dims): | |||
return nn.Sequential( | |||
ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT), | |||
nn.LayerNorm(linear_dims[-1]), | |||
ez_mlp(list(reversed(linear_dims)), last_layer_bias=True) | |||
) | |||
class AutoEncoderModel(nn.Module): | |||
def __init__(self, pretrained_name, bottleneck_sizes): | |||
super().__init__() | |||
self.bottleneck_size = bottleneck_sizes | |||
model = T5Model.from_pretrained(pretrained_name) | |||
self.emb_layer = model.get_encoder().get_input_embeddings() | |||
ez_freeze(self.emb_layer) | |||
self.auto_encoder = auto_encoder_model([ | |||
self.embedding_dim, | |||
*bottleneck_sizes | |||
]) | |||
self.loss_fn = nn.MSELoss() | |||
def forward(self, word_ids, additive_noise, alpha): | |||
# word_ids.shape = (batch_size, 2) | |||
# additive_noise.shape = (batch_size, embedding_dim) | |||
# alpha.shape = (batch_size, 1) | |||
word_embs = self.emb_layer(word_ids) | |||
# word_embs.shape = (batch_size, 2, embedding_dim) | |||
word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha) | |||
# word_combs.shape = (batch_size, embedding_dim) | |||
y_hat = self.auto_encoder(word_combs + additive_noise) | |||
loss = self.loss_fn(word_combs, y_hat) | |||
return loss, y_hat | |||
@property | |||
def embedding_dim(self): | |||
return self.emb_layer.embedding_dim | |||
@property | |||
def num_embeddings(self): | |||
return self.emb_layer.num_embeddings | |||
# In[5]: | |||
model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 512, 256, 128]) | |||
print(model) | |||
# In[6]: | |||
train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings)) | |||
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim) | |||
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim) | |||
# In[7]: | |||
train_loader = FakeEpoch(train_loader, 1000) | |||
valid_loader = FakeEpoch(valid_loader, 100) | |||
# In[8]: | |||
def _prefix_dict_keys(prefix, input_dict): | |||
return {f'{prefix}_{key}': val for key, val in input_dict.items()} | |||
def train_loop(model, loader, optimizer, use_tqdm=False): | |||
model.train() | |||
batch_losses = [] | |||
if use_tqdm: | |||
loader = tqdm(loader, position=2, desc="Train Loop", leave=False) | |||
for row in loader: | |||
optimizer.zero_grad() | |||
out = model(*(item.to(DEVICE) for item in row)) | |||
loss = out[0] | |||
batch_loss_value = loss.item() | |||
loss.backward() | |||
optimizer.step() | |||
batch_losses.append(batch_loss_value) | |||
loss_value = np.mean(batch_losses) | |||
return _prefix_dict_keys('train', { | |||
'loss': loss_value | |||
}) | |||
def valid_loop(model, loader, use_tqdm=False): | |||
model.eval() | |||
batch_losses = [] | |||
all_true = [] | |||
all_pred = [] | |||
if use_tqdm: | |||
loader = tqdm(loader, position=2, desc="Valid Loop", leave=False) | |||
with torch.no_grad(): | |||
for row in loader: | |||
out = model(*(item.to(DEVICE) for item in row)) | |||
loss = out[0] | |||
batch_loss_value = loss.item() | |||
batch_losses.append(batch_loss_value) | |||
loss_value = np.mean(batch_losses) | |||
return_value = { | |||
'loss': loss_value, | |||
} | |||
return _prefix_dict_keys('valid', return_value) | |||
# In[9]: | |||
model.to(DEVICE) | |||
# model.load_state_dict(torch.load('./ae_file/snap_72.pt')) | |||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
for epoch in tqdm(range(1000), position=1): | |||
epoch_results = {} | |||
epoch_results.update( | |||
train_loop( | |||
model=model, | |||
loader=train_loader, | |||
optimizer=optimizer, | |||
use_tqdm=True | |||
) | |||
) | |||
epoch_results.update( | |||
valid_loop( | |||
model=model, | |||
loader=valid_loader, | |||
use_tqdm=True | |||
) | |||
) | |||
torch.save(model.state_dict(), f'./ae_file4_res_mlp/snap_{epoch}.pt') | |||
print(epoch_results) | |||
# In[ ]: | |||
@@ -0,0 +1,88 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"id": "4c6f353f-83e2-4780-9124-bf7f30e2a77d", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"from typing import Optional\n", | |||
"\n", | |||
"import numpy as np\n", | |||
"from tqdm import tqdm\n", | |||
"\n", | |||
"import wandb\n", | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"from transformers import T5TokenizerFast, T5ForConditionalGeneration\n", | |||
"\n", | |||
"from _config import load_config\n", | |||
"from _utils import print_system_info, silent_logs\n", | |||
"from _datasets import AutoLoad, generate_dataloader\n", | |||
"from _mydelta import T5Wrapper, auto_freeze, EmbeddingWrapper\n", | |||
"from _trainer import train_loop, valid_loop, BestFinder\n", | |||
"\n", | |||
"# configs = load_config('./config.yaml')\n", | |||
"\n", | |||
"# RANDOM_SEED = configs.shared.random_seed\n", | |||
"# WANDB_PROJECT_NAME = configs.shared.project_name\n", | |||
"# DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |||
"# USE_TQDM = configs.shared.use_tqdm\n", | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"id": "ead0c663-c9e4-4625-8f3b-11e53ca59920", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"model = T5ForConditionalGeneration.from_pretrained('google/t5-large-lm-adapt')\n", | |||
"tokenizer = T5TokenizerFast.from_pretrained('google/t5-large-lm-adapt', model_max_length=2048)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "e348f601-c713-49af-86e4-a40382c5a36f", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"num_tokens = 100" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "6d9a6602-f90d-440a-b11e-ddda2d36d2f7", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python [conda env:deep]", | |||
"language": "python", | |||
"name": "conda-env-deep-py" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.10.11" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 5 | |||
} |
@@ -0,0 +1,27 @@ | |||
from tqdm import tqdm | |||
import torch | |||
import os | |||
import sys | |||
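# make the parent directory importable so the shared modules (_config, _utils, ...) resolve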
sys.path.insert(1, os.path.join(sys.path[0], '..')) | |||
from _config import load_config | |||
from _utils import print_system_info, sp_encode | |||
from train_single import run_experiment
if __name__ == '__main__': | |||
print_system_info() | |||
configs = load_config(sys.argv[1]) | |||
run_configs = tqdm(configs.run_configs, position=0, desc="Experiment") | |||
for run_config in run_configs: | |||
tasks = tqdm(run_config.tasks, position=1, desc="Task:", leave=False) | |||
for task_name in tasks: | |||
tasks.set_description(f'Task: {task_name}') | |||
torch.cuda.empty_cache() | |||
            run_experiment(run_config, task_name)
@@ -0,0 +1,47 @@ | |||
import numpy as np | |||
import torch | |||
import os | |||
import sys | |||
sys.path.insert(1, os.path.join(sys.path[0], '..')) | |||
from _utils import silent_logs, sp_decode | |||
from _datasets import AutoLoad | |||
from _trainer import auto_train | |||
from _mydelta import auto_mutate | |||
from _models import auto_model | |||
from _config import Config | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
def run_experiment(config, task_name):
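    """Train `config.model_name` on `task_name`: build the model, optionally
    mutate it with a PEFT module, then load the task data and run auto_train."""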
silent_logs() | |||
np.random.seed(config.random_seed) | |||
# ______________________LOAD MODEL_____________________________ | |||
model, tokenizer = auto_model(config.model_name, AutoLoad.get_task_output(task_name)) | |||
# ______________________MUTATE MODEL_____________________________ | |||
n_prefix_token = 0 | |||
if config.peft_params is not None: | |||
n_prefix_token = config.peft_params.n_tokens | |||
delta_module = auto_mutate( | |||
model=model, | |||
tokenizer=tokenizer, | |||
peft_params=config.peft_params.to_dict(), | |||
remove_dropout=config.remove_dropout | |||
) | |||
# ______________________LOAD DATA_____________________________ | |||
autoload = AutoLoad(tokenizer, n_prefix_token=n_prefix_token) | |||
# ______________________TRAIN_____________________________ | |||
dataset = autoload.get_and_map(task_name) | |||
auto_train(model, tokenizer, dataset, config, device=DEVICE) | |||
if __name__ == '__main__': | |||
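    # the launcher passes config and task name encoded on the command line;
    # sp_decode is the counterpart of _utils.sp_encode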
config_json = sp_decode(sys.argv[1]) | |||
config = Config(config_json, '') | |||
task_name = sp_decode(sys.argv[2]) | |||
    run_experiment(config, task_name)
@@ -0,0 +1,62 @@ | |||
shared: | |||
project_name: lowdim_prompts | |||
use_tqdm: true | |||
random_seed: 42 | |||
default: &default | |||
model_name: google/t5-large-lm-adapt | |||
wandb_name: null | |||
train_batch_size: 32 | |||
valid_batch_size: 32 | |||
num_epochs: 200 | |||
peft_params: null # no mutation | |||
hot_modules: null # fine-tune all | |||
balancify_train: false | |||
best_finder: | |||
save: true | |||
metric: valid_f1-score-ma | |||
higher_better: true | |||
tasks: | |||
- glue:cola | |||
run_configs: | |||
# - <<: *default | |||
# wandb_name: n_tokens100_n_comb_tokens512 | |||
# learning_rate: 0.01 | |||
# hot_modules: | |||
# - sadcl | |||
# peft_params: | |||
# kind: comb_prompt | |||
# n_tokens: 100 | |||
# n_comb_tokens: 512 | |||
# - <<: *default | |||
# wandb_name: n_tokens100_n_comb_tokens2048 | |||
# learning_rate: 0.01 | |||
# hot_modules: | |||
# - sadcl | |||
# peft_params: | |||
# kind: comb_prompt | |||
# n_tokens: 100 | |||
# n_comb_tokens: 2048 | |||
- <<: *default | |||
wandb_name: large_n_tokens100_64_256 | |||
learning_rate: 0.01 | |||
hot_modules: | |||
- sadcl | |||
peft_params: | |||
kind: lowdim_prompt | |||
n_tokens: 100 | |||
dims: | |||
- 64 | |||
- 256 | |||
- <<: *default | |||
wandb_name: large_n_tokens100_256_512 | |||
learning_rate: 0.01 | |||
hot_modules: | |||
- sadcl | |||
peft_params: | |||
kind: lowdim_prompt | |||
n_tokens: 100 | |||
dims: | |||
- 256 | |||
- 512 |
@@ -0,0 +1,116 @@ | |||
from typing import Optional | |||
import numpy as np | |||
from tqdm import tqdm | |||
import wandb | |||
import torch | |||
import torch.nn as nn | |||
from transformers import T5TokenizerFast, T5ForConditionalGeneration | |||
import os | |||
import sys | |||
sys.path.insert(1, os.path.join(sys.path[0], '..')) | |||
from _config import load_config | |||
from _utils import print_system_info, silent_logs | |||
from _datasets import AutoLoad, generate_dataloader | |||
from _mydelta import auto_freeze, LowdimEmbeddingWrapper | |||
from _trainer import train_loop, valid_loop, BestFinder | |||
configs = load_config('./config.yaml') | |||
RANDOM_SEED = configs.shared.random_seed | |||
WANDB_PROJECT_NAME = configs.shared.project_name | |||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
USE_TQDM = configs.shared.use_tqdm | |||
def run_experiment(config):
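    """Training entry point for low-dimensional prompts: wraps the model with
    LowdimEmbeddingWrapper instead of a full-sized soft prompt."""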
np.random.seed(RANDOM_SEED) | |||
# ______________________LOAD MODEL_____________________________ | |||
model = T5ForConditionalGeneration.from_pretrained(config.model_name) | |||
tokenizer = T5TokenizerFast.from_pretrained(config.model_name, model_max_length=2048) | |||
# ______________________MUTATE MODEL_____________________________ | |||
if config.peft_params is not None: | |||
peft_params = config.peft_params.to_dict() | |||
peft_class = { | |||
'lowdim_prompt': LowdimEmbeddingWrapper | |||
}[peft_params.pop('kind')] | |||
delta_module = peft_class.mutate( | |||
model=model, | |||
**peft_params | |||
) | |||
elif config.best_finder.save: | |||
        raise NotImplementedError('best_finder.save requires a PEFT delta module to snapshot')
freeze_notes = auto_freeze(model, config.hot_modules) | |||
# ______________________LOAD DATA_____________________________ | |||
data_loader = AutoLoad(tokenizer) | |||
dataset = data_loader.get_and_map(config.tasks[0]) | |||
train_loader, valid_loader = generate_dataloader(tokenizer, dataset['train'], dataset['valid'], config) | |||
# ______________________TRAIN_____________________________ | |||
print(delta_module) | |||
wandb.init( | |||
name=config.wandb_name, | |||
project=WANDB_PROJECT_NAME, | |||
config=config.to_dict(), | |||
notes=freeze_notes | |||
) | |||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) | |||
best_finder = BestFinder(config.best_finder.higher_better) | |||
model.to(DEVICE) | |||
epochs_range = range(config.num_epochs) | |||
if USE_TQDM: | |||
epochs_range = tqdm(epochs_range, position=1, desc="EPOCHS", leave=False) | |||
for epoch in epochs_range: | |||
epoch_results = {} | |||
epoch_results.update( | |||
train_loop( | |||
model=model, | |||
loader=train_loader, | |||
optimizer=optimizer, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
epoch_results.update( | |||
valid_loop( | |||
model=model, | |||
loader=valid_loader, | |||
use_tqdm=USE_TQDM | |||
) | |||
) | |||
if config.best_finder.save: | |||
if best_finder.is_better(epoch_results[config.best_finder.metric]): | |||
torch.save(delta_module.peft_state_dict(), './best.pt') | |||
wandb.log(epoch_results) | |||
wandb.finish() | |||
if __name__ == '__main__': | |||
print_system_info() | |||
silent_logs() | |||
run_configs = configs.run_configs | |||
if USE_TQDM: | |||
run_configs = tqdm(run_configs, position=0, desc="Experiment") | |||
for run_config in run_configs: | |||
        run_experiment(run_config)
@@ -0,0 +1,219 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"id": "e6ecf439-a0db-42e0-a6b9-f512198b0e0e", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"import torch" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"id": "4bcc7c7e-711a-4cd9-b901-d6ff76938a75", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"best_path = '/home/msadraei/trained_final/iclr_resp_t5_small_glue-cola/10_attempt/best.pt'\n", | |||
"first_path = '/home/msadraei/trained_final/iclr_resp_t5_small_glue-cola/10_attempt/first.pt'" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"id": "eaa4a300-1e6c-46f0-8f0d-16e9c71c2388", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"best = torch.load(best_path)\n", | |||
"first = torch.load(first_path)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"id": "c5e0b6bb-3bde-4526-8a6a-5dac0a3b3cc3", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"sadcl_p_target\n", | |||
"tensor(42.7208, device='cuda:0')\n", | |||
"pretrained_tasks\n", | |||
"tensor(0., device='cuda:0')\n", | |||
"sadcl_attention_score.g_network.0.weight\n", | |||
"tensor(157.3032, device='cuda:0')\n", | |||
"sadcl_attention_score.g_network.2.weight\n", | |||
"tensor(154.6590, device='cuda:0')\n", | |||
"sadcl_attention_score.g_network.3.weight\n", | |||
"tensor(18.1127, device='cuda:0')\n", | |||
"sadcl_attention_score.g_network.3.bias\n", | |||
"tensor(19.0149, device='cuda:0')\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"for key in best.keys():\n", | |||
" print(key)\n", | |||
" v1 = first[key]\n", | |||
" v2 = best[key]\n", | |||
" print(torch.norm(v1 - v2))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 13, | |||
"id": "42815cf2-b8bf-4219-a3fd-ebbe92fb5c32", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"base_path = '/home/msadraei/trained_final/forward_transfer_test_t5_base_superglue-rte/10_combine_128_4tasks_new_impl_tie_50/100'\n", | |||
"last_path = f'{base_path}/last.pt'\n", | |||
"best_path = f'{base_path}/best.pt'\n", | |||
"first_path = f'{base_path}/first.pt'" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 14, | |||
"id": "880cb651-ddea-4564-93ab-c5f52e1f02dd", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"import torch\n", | |||
"last = torch.load(last_path)\n", | |||
"best = torch.load(best_path)\n", | |||
"first = torch.load(first_path)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 15, | |||
"id": "ee4b3287-203f-49b0-8b89-6070f9ff4062", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"import numpy as np\n", | |||
"def pretrained_coeff(state_dict):\n", | |||
" return np.stack([\n", | |||
" val.cpu().numpy()\n", | |||
" for key, val in state_dict.items()\n", | |||
" if 'sadcl_coeff_pretrained' in key\n", | |||
" ])" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 16, | |||
"id": "26518ecd-8cc1-4543-acaf-56637295bbe8", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"last_coeff = pretrained_coeff(best)\n", | |||
"best_coeff = pretrained_coeff(best)\n", | |||
"first_coeff = pretrained_coeff(first)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 17, | |||
"id": "5a850a65-724a-483d-abb3-b7de6118db31", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"array([[0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42],\n", | |||
" [0.43, 0.42, 0.42, 0.42]], dtype=float32)" | |||
] | |||
}, | |||
"execution_count": 17, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"np.round(last_coeff/ 100 , 2)\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 65, | |||
"id": "7182b595-5bb3-4c06-88dc-1f50ed774500", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"tensor(34.9105)" | |||
] | |||
}, | |||
"execution_count": 65, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"torch.linalg.vector_norm(torch.Tensor(best_coeff[0]), ord=1)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "9e2a2080-9450-4df2-b20e-4619e3f92c1b", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python [conda env:deep]", | |||
"language": "python", | |||
"name": "conda-env-deep-py" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.10.13" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 5 | |||
} |
@@ -0,0 +1,538 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"id": "3526e83a-baa5-4278-81ce-e142e0a6d208", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"import sys\n", | |||
"from pathlib import Path\n", | |||
"sys.path.append(Path('./').absolute().parent.__str__())\n", | |||
"from _datasets import AutoLoad" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 48, | |||
"id": "5a0264f8-4b67-44e2-8aa9-468ae8b249b5", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"(12, 15)\n", | |||
"{'a': 'b'}\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"class Test():\n", | |||
" def __new__(cls, *args, **kwargs):\n", | |||
" print(args)\n", | |||
" print(kwargs)\n", | |||
"Test(12, 15, a='b')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"id": "f0d8ead2-cfa6-4044-8e7a-6b7146bea9cd", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"from transformers import T5TokenizerFast\n", | |||
"\n", | |||
"tokenizer = T5TokenizerFast.from_pretrained('google/t5-small-lm-adapt')\n", | |||
"tokenizer._is_seq2seq = True\n", | |||
"loader = AutoLoad(tokenizer=tokenizer)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 19, | |||
"id": "07c556fd-780d-4aee-a5e9-ad81a474d94b", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"['sentence1', 'sentence2']" | |||
] | |||
}, | |||
"execution_count": 19, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"loader.glue_helper.get_task_input('stsb')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"id": "04feb162-ef3f-42a8-ab00-23d3faea5209", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "8165afbb7bcb474e80b9538b0c0c39da", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Map: 0%| | 0/5749 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "95318c2e7b684eabb280fd34d014f1d3", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Map: 0%| | 0/1500 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "0e47b3895f4d4f77920c8d82579ec683", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Map: 0%| | 0/1500 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"ds = loader.get_and_map('glue:stsb')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 43, | |||
"id": "9dcf1e0c-e703-4e30-9dab-bfc54cde7d3f", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "e703362287be445fa8f3949c592b1c26", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Downloading data: 0%| | 0.00/51.8M [00:00<?, ?B/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "2d231baabf80401eacf8c400a811c5ac", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Generating train split: 0%| | 0/100730 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "6c699b3fdf1e468e9ef8a442651d1f7c", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Generating validation split: 0%| | 0/10000 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "91acd57830124beeb29c9869f3b67788", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Generating test split: 0%| | 0/10000 [00:00<?, ? examples/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"from datasets import load_dataset\n", | |||
"\n", | |||
"ds = load_dataset('super_glue', 'record')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 46, | |||
"id": "c4d652d7-8237-4e5a-85e5-faf39a88eea5", | |||
"metadata": { | |||
"tags": [] | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'passage': \"For everyone who has ever thought about shooting their boss - metaphorically, o fcourse - this one is for you. An employee of a Texas armored car company got to do just that this week to 'demonstrate that they take client safety seriously'. And to further that demonstration, the CEO was sitting alone inside the Mercedes-Benz as 12 rounds from an AK-47 rained down upon the SUV. The company, Texas Armoring Corporation, has supplied protected vehicles to the Pope, celebrities like rapper T.I. and actor Steven Segal and oil executives in West Africa, according to My San Antonio. Texas Armoring Corp. & Jason Forston.\\n@highlight\\nTexas Armoring Corporation created a video to show the effectiveness of their armored\\n@highlight\\nCEO R. Trent Kimball sat in the drivers seat of a Mercedes-Benz SUV\\n@highlight\\nTotal of 12 rounds fired at the windscreen\\n@highlight\\nCompany known for working with celebrities, oil barons and even the Pope\",\n", | |||
" 'query': \"'When it comes to assuring our clients' safety, we take product testing extremely seriously,' @placeholder says in a video taken of the display.\",\n", | |||
" 'entities': ['Steven Segal',\n", | |||
" 'Texas Armoring Corp.',\n", | |||
" 'Trent Kimball',\n", | |||
" 'Texas Armoring Corporation',\n", | |||
" 'Texas',\n", | |||
" 'AK-47',\n", | |||
" 'Pope',\n", | |||
" 'Mercedes-Benz',\n", | |||
" 'San Antonio',\n", | |||
" 'West Africa',\n", | |||