{ "cells": [ { "cell_type": "code", "execution_count": 6, "id": "a50443d6-fe09-4905-b913-1be5f88c8c03", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from tqdm import tqdm\n", "from sklearn.model_selection import train_test_split\n", "import torch\n", "import torch.nn as nn\n", "from transformers import T5Model" ] }, { "cell_type": "code", "execution_count": 7, "id": "4e677034-dc27-4939-8ea2-71fcbb2da57d", "metadata": { "tags": [] }, "outputs": [], "source": [ "np_rng = np.random.default_rng(seed=42)" ] }, { "cell_type": "code", "execution_count": 8, "id": "3d139e0a-b8e3-427b-a537-44bc0f14ba46", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.09141512, -0.31199523],\n", " [ 0.22513536, 0.28216941],\n", " [-0.58531056, -0.39065385],\n", " [ 0.03835212, -0.09487278],\n", " [-0.00504035, -0.25591318],\n", " [ 0.26381939, 0.23333758],\n", " [ 0.01980921, 0.33817236],\n", " [ 0.1402528 , -0.25778774],\n", " [ 0.11062524, -0.28766478],\n", " [ 0.26353509, -0.01497777],\n", " [-0.05545871, -0.20427886],\n", " [ 0.3667624 , -0.04635884],\n", " [-0.12849835, -0.10564007],\n", " [ 0.15969276, 0.10963322],\n", " [ 0.12381978, 0.1292463 ],\n", " [ 0.64249428, -0.1219245 ],\n", " [-0.15367282, -0.24413182],\n", " [ 0.18479383, 0.33869169],\n", " [-0.03418424, -0.25204694],\n", " [-0.24734436, 0.19517784],\n", " [ 0.22297625, 0.16294628],\n", " [-0.19965291, 0.0696484 ],\n", " [ 0.03500574, 0.06560658],\n", " [ 0.26142863, 0.06707866],\n", " [ 0.20367407, 0.02027372],\n", " [ 0.08673582, 0.18938647],\n", " [-0.43714675, -0.09590136],\n", " [-0.1411118 , -0.19166335],\n", " [-0.08254268, 0.44848239],\n", " [-0.25974933, 0.29048351],\n", " [-0.50486093, -0.10046551],\n", " [ 0.04882592, 0.1758667 ]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np_rng.normal(loc=0, scale=0.3, size=(32, 2))" ] }, { "cell_type": "code", "execution_count": 9, "id": "544207bc-37fc-4376-9c63-bff44c72b32f", "metadata": { "tags": [] }, "outputs": [], "source": [ "# BOTTLENECK_SIZE = 128\n", "TRAIN_BATCH_SIZE = 8192\n", "VALID_BATCH_SIZE = 8192\n", "RANDOM_SEED = 42\n", "\n", "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "37d2d256-a348-402b-999d-1a4edce360c5", "metadata": { "tags": [] }, "outputs": [], "source": [ "def train_valid_test_split(total_range, random_seed=RANDOM_SEED):\n", " train, testvalid = train_test_split(total_range, random_state=RANDOM_SEED, test_size=0.2)\n", " test, valid = train_test_split(testvalid, random_state=RANDOM_SEED, test_size=0.5)\n", " return train, valid, test\n", "\n", "def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED):\n", " np_rng = np.random.default_rng(seed=random_seed)\n", " while True:\n", " word_ids = np_rng.choice(words_ids, size=(batch_size, 2))\n", " additive_noise = np_rng.normal(loc=0, scale=0.1, size=(batch_size, emb_dim))\n", " alpha = np_rng.uniform(size=(batch_size, 1))\n", " yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)\n", " \n", "class FakeEpoch:\n", " def __init__(self, dataloader, each_epoch_size):\n", " self.dataloader_iter = iter(dataloader)\n", " self.each_epoch_size = each_epoch_size\n", " \n", " def __len__(self):\n", " return self.each_epoch_size\n", " \n", " def __iter__(self):\n", " for _ in range(self.each_epoch_size):\n", " yield next(self.dataloader_iter)" ] }, { "cell_type": "code", 
"execution_count": 11, "id": "644ae479-3f9a-426a-bd0b-4ec7694bc675", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "def ez_freeze(module):\n", " for param in module.parameters():\n", " param.requires_grad = False\n", " \n", "def ez_mlp(linear_dims, last_layer_bias=False):\n", " layers = []\n", " pairs_count = len(linear_dims) - 1\n", " for idx in range(pairs_count):\n", " in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]\n", " if idx == pairs_count - 1:\n", " layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))\n", " else:\n", " layers.append(nn.Linear(in_dim, out_dim, bias=True))\n", " layers.append(nn.ReLU())\n", " return nn.Sequential(*layers)\n", "\n", "def auto_encoder_model(linear_dims):\n", " return nn.Sequential(\n", " ez_mlp(linear_dims, last_layer_bias=False),\n", " nn.LayerNorm(linear_dims[-1]),\n", " ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)\n", " )\n", "\n", "class AutoEncoderModel(nn.Module):\n", " def __init__(self, pretrained_name, bottleneck_sizes):\n", " super().__init__()\n", " \n", " self.bottleneck_size = bottleneck_sizes\n", " \n", " model = T5Model.from_pretrained(pretrained_name)\n", " self.emb_layer = model.get_encoder().get_input_embeddings()\n", " ez_freeze(self.emb_layer)\n", " \n", " self.auto_encoder = auto_encoder_model([\n", " self.embedding_dim,\n", " *bottleneck_sizes\n", " ])\n", " \n", " self.loss_fn = nn.MSELoss()\n", " \n", " def forward(self, word_ids, additive_noise, alpha):\n", " # word_ids.shape = (batch_size, 2)\n", " # additive_noise.shape = (batch_size, embedding_dim)\n", " # alpha.shape = (batch_size, 1)\n", " \n", " word_embs = self.emb_layer(word_ids)\n", " # word_embs.shape = (batch_size, 2, embedding_dim)\n", " \n", " word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)\n", " # word_combs.shape = (batch_size, embedding_dim)\n", " \n", " y_hat = self.auto_encoder(word_combs + additive_noise)\n", " loss = self.loss_fn(word_combs, y_hat)\n", " return loss, y_hat\n", " \n", " @property\n", " def embedding_dim(self):\n", " return self.emb_layer.embedding_dim\n", " \n", " @property\n", " def num_embeddings(self):\n", " return self.emb_layer.num_embeddings " ] }, { "cell_type": "code", "execution_count": 12, "id": "aba28049-20bf-4ae6-9445-2f7c294686d8", "metadata": { "tags": [] }, "outputs": [], "source": [ "model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 768, 512, 512, 256, 256, 128, 128])" ] }, { "cell_type": "code", "execution_count": 16, "id": "cac6bc39-ba12-4052-bd5f-8834f57cfa15", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor(96.9082)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(model.emb_layer.weight**2).mean()" ] }, { "cell_type": "code", "execution_count": 6, "id": "afe2efbf-e703-4c43-8f7b-a87d303ea89e", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))\n", "train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)\n", "valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c24ccc1c-4cbe-4373-871e-9090dceb69a1", "metadata": {}, "outputs": [], "source": [ "train_loader = FakeEpoch(train_loader, 1000)\n", "valid_loader = FakeEpoch(valid_loader, 100)" ] }, { "cell_type": "code", "execution_count": 8, "id": 
"71936e43-d718-45ef-8115-7fc63999ebd9", "metadata": { "tags": [] }, "outputs": [], "source": [ "def _prefix_dict_keys(prefix, input_dict):\n", " return {f'{prefix}_{key}': val for key, val in input_dict.items()}\n", "\n", "def train_loop(model, loader, optimizer, use_tqdm=False):\n", " model.train()\n", "\n", " batch_losses = []\n", " \n", " if use_tqdm:\n", " loader = tqdm(loader, position=2, desc=\"Train Loop\", leave=False)\n", " \n", " for row in loader:\n", " optimizer.zero_grad()\n", " \n", " out = model(*(item.to(DEVICE) for item in row))\n", " loss = out[0]\n", " \n", " batch_loss_value = loss.item()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " batch_losses.append(batch_loss_value)\n", " \n", " loss_value = np.mean(batch_losses)\n", " return _prefix_dict_keys('train', {\n", " 'loss': loss_value\n", " })\n", "\n", "def valid_loop(model, loader, use_tqdm=False):\n", " model.eval()\n", "\n", " batch_losses = []\n", " \n", " all_true = []\n", " all_pred = []\n", " \n", " if use_tqdm:\n", " loader = tqdm(loader, position=2, desc=\"Valid Loop\", leave=False)\n", " \n", " with torch.no_grad():\n", " for row in loader:\n", " out = model(*(item.to(DEVICE) for item in row))\n", " loss = out[0]\n", " \n", " batch_loss_value = loss.item()\n", "\n", " batch_losses.append(batch_loss_value)\n", "\n", " loss_value = np.mean(batch_losses)\n", " \n", " return_value = {\n", " 'loss': loss_value,\n", " }\n", " \n", " return _prefix_dict_keys('valid', return_value)" ] }, { "cell_type": "code", "execution_count": 9, "id": "082b5384-827f-48b3-aa8e-40483668bbc0", "metadata": { "tags": [] }, "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[9], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1000\u001b[39m):\n\u001b[1;32m 5\u001b[0m epoch_results \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 7\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[0;32m----> 8\u001b[0m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 16\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 17\u001b[0m valid_loop(\n\u001b[1;32m 18\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 21\u001b[0m )\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mprint\u001b[39m(epoch_results)\n", "Cell \u001b[0;32mIn[8], line 12\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(model, loader, optimizer, use_tqdm)\u001b[0m\n\u001b[1;32m 9\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m use_tqdm:\n\u001b[1;32m 10\u001b[0m loader \u001b[38;5;241m=\u001b[39m tqdm(loader, position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrain Loop\u001b[39m\u001b[38;5;124m\"\u001b[39m, leave\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m loader:\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(\u001b[38;5;241m*\u001b[39m(item\u001b[38;5;241m.\u001b[39mto(DEVICE) \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m row))\n", "Cell \u001b[0;32mIn[3], line 24\u001b[0m, in \u001b[0;36mFakeEpoch.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meach_epoch_size):\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[3], line 10\u001b[0m, in \u001b[0;36mcustom_dataloader\u001b[0;34m(words_ids, batch_size, emb_dim, random_seed)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m word_ids \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39mchoice(words_ids, size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m2\u001b[39m))\n\u001b[0;32m---> 10\u001b[0m additive_noise \u001b[38;5;241m=\u001b[39m \u001b[43mnp_rng\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43memb_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m alpha \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39muniform(size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(word_ids), torch\u001b[38;5;241m.\u001b[39mTensor(additive_noise), torch\u001b[38;5;241m.\u001b[39mTensor(alpha)\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "model.to(DEVICE)\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)\n", "\n", "for epoch in range(1000):\n", " epoch_results = {}\n", "\n", " epoch_results.update(\n", " train_loop(\n", " model=model,\n", " loader=train_loader,\n", " optimizer=optimizer,\n", " use_tqdm=False\n", " )\n", " )\n", "\n", " epoch_results.update(\n", " valid_loop(\n", " model=model,\n", " loader=valid_loader,\n", " use_tqdm=False\n", " )\n", " )\n", " print(epoch_results)" ] }, { "cell_type": "code", "execution_count": null, "id": 
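"c4e8d2f6-demo-eval-recon", "metadata": { "tags": [] }, "outputs": [], "source": [
"# Post-training sketch (illustrative, assumes the loop above has run): measure\n",
"# pure reconstruction error on a few held-out token embeddings. With both id\n",
"# columns equal and alpha = 0.5, word_combs reduces to the embedding itself,\n",
"# and zero additive noise isolates reconstruction quality.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    ids = torch.tensor(list(test_ds[:8]), device=DEVICE).unsqueeze(1).repeat(1, 2)\n",
"    noise = torch.zeros(8, model.embedding_dim, device=DEVICE)\n",
"    alpha = torch.full((8, 1), 0.5, device=DEVICE)\n",
"    loss, y_hat = model(ids, noise, alpha)\n",
"loss.item()"
] }, { "cell_type": "code", "execution_count": null, "id":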
"53425637-6146-41d2-b59e-4617ae1f8521", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:deep]", "language": "python", "name": "conda-env-deep-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 }