123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "a50443d6-fe09-4905-b913-1be5f88c8c03",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "from tqdm import tqdm\n",
- "from sklearn.model_selection import train_test_split\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from transformers import T5Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "4e677034-dc27-4939-8ea2-71fcbb2da57d",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "np_rng = np.random.default_rng(seed=42)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "3d139e0a-b8e3-427b-a537-44bc0f14ba46",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([[ 0.09141512, -0.31199523],\n",
- " [ 0.22513536, 0.28216941],\n",
- " [-0.58531056, -0.39065385],\n",
- " [ 0.03835212, -0.09487278],\n",
- " [-0.00504035, -0.25591318],\n",
- " [ 0.26381939, 0.23333758],\n",
- " [ 0.01980921, 0.33817236],\n",
- " [ 0.1402528 , -0.25778774],\n",
- " [ 0.11062524, -0.28766478],\n",
- " [ 0.26353509, -0.01497777],\n",
- " [-0.05545871, -0.20427886],\n",
- " [ 0.3667624 , -0.04635884],\n",
- " [-0.12849835, -0.10564007],\n",
- " [ 0.15969276, 0.10963322],\n",
- " [ 0.12381978, 0.1292463 ],\n",
- " [ 0.64249428, -0.1219245 ],\n",
- " [-0.15367282, -0.24413182],\n",
- " [ 0.18479383, 0.33869169],\n",
- " [-0.03418424, -0.25204694],\n",
- " [-0.24734436, 0.19517784],\n",
- " [ 0.22297625, 0.16294628],\n",
- " [-0.19965291, 0.0696484 ],\n",
- " [ 0.03500574, 0.06560658],\n",
- " [ 0.26142863, 0.06707866],\n",
- " [ 0.20367407, 0.02027372],\n",
- " [ 0.08673582, 0.18938647],\n",
- " [-0.43714675, -0.09590136],\n",
- " [-0.1411118 , -0.19166335],\n",
- " [-0.08254268, 0.44848239],\n",
- " [-0.25974933, 0.29048351],\n",
- " [-0.50486093, -0.10046551],\n",
- " [ 0.04882592, 0.1758667 ]])"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np_rng.normal(loc=0, scale=0.3, size=(32, 2))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "544207bc-37fc-4376-9c63-bff44c72b32f",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# BOTTLENECK_SIZE = 128\n",
- "TRAIN_BATCH_SIZE = 8192\n",
- "VALID_BATCH_SIZE = 8192\n",
- "RANDOM_SEED = 42\n",
- "\n",
- "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "37d2d256-a348-402b-999d-1a4edce360c5",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def train_valid_test_split(total_range, random_seed=RANDOM_SEED):\n",
- " train, testvalid = train_test_split(total_range, random_state=RANDOM_SEED, test_size=0.2)\n",
- " test, valid = train_test_split(testvalid, random_state=RANDOM_SEED, test_size=0.5)\n",
- " return train, valid, test\n",
- "\n",
- "def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED):\n",
- " np_rng = np.random.default_rng(seed=random_seed)\n",
- " while True:\n",
- " word_ids = np_rng.choice(words_ids, size=(batch_size, 2))\n",
- " additive_noise = np_rng.normal(loc=0, scale=0.1, size=(batch_size, emb_dim))\n",
- " alpha = np_rng.uniform(size=(batch_size, 1))\n",
- " yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)\n",
- " \n",
- "class FakeEpoch:\n",
- " def __init__(self, dataloader, each_epoch_size):\n",
- " self.dataloader_iter = iter(dataloader)\n",
- " self.each_epoch_size = each_epoch_size\n",
- " \n",
- " def __len__(self):\n",
- " return self.each_epoch_size\n",
- " \n",
- " def __iter__(self):\n",
- " for _ in range(self.each_epoch_size):\n",
- " yield next(self.dataloader_iter)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "644ae479-3f9a-426a-bd0b-4ec7694bc675",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "\n",
- "def ez_freeze(module):\n",
- " for param in module.parameters():\n",
- " param.requires_grad = False\n",
- " \n",
- "def ez_mlp(linear_dims, last_layer_bias=False):\n",
- " layers = []\n",
- " pairs_count = len(linear_dims) - 1\n",
- " for idx in range(pairs_count):\n",
- " in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]\n",
- " if idx == pairs_count - 1:\n",
- " layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))\n",
- " else:\n",
- " layers.append(nn.Linear(in_dim, out_dim, bias=True))\n",
- " layers.append(nn.ReLU())\n",
- " return nn.Sequential(*layers)\n",
- "\n",
- "def auto_encoder_model(linear_dims):\n",
- " return nn.Sequential(\n",
- " ez_mlp(linear_dims, last_layer_bias=False),\n",
- " nn.LayerNorm(linear_dims[-1]),\n",
- " ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)\n",
- " )\n",
- "\n",
- "class AutoEncoderModel(nn.Module):\n",
- " def __init__(self, pretrained_name, bottleneck_sizes):\n",
- " super().__init__()\n",
- " \n",
- " self.bottleneck_size = bottleneck_sizes\n",
- " \n",
- " model = T5Model.from_pretrained(pretrained_name)\n",
- " self.emb_layer = model.get_encoder().get_input_embeddings()\n",
- " ez_freeze(self.emb_layer)\n",
- " \n",
- " self.auto_encoder = auto_encoder_model([\n",
- " self.embedding_dim,\n",
- " *bottleneck_sizes\n",
- " ])\n",
- " \n",
- " self.loss_fn = nn.MSELoss()\n",
- " \n",
- " def forward(self, word_ids, additive_noise, alpha):\n",
- " # word_ids.shape = (batch_size, 2)\n",
- " # additive_noise.shape = (batch_size, embedding_dim)\n",
- " # alpha.shape = (batch_size, 1)\n",
- " \n",
- " word_embs = self.emb_layer(word_ids)\n",
- " # word_embs.shape = (batch_size, 2, embedding_dim)\n",
- " \n",
- " word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)\n",
- " # word_combs.shape = (batch_size, embedding_dim)\n",
- " \n",
- " y_hat = self.auto_encoder(word_combs + additive_noise)\n",
- " loss = self.loss_fn(word_combs, y_hat)\n",
- " return loss, y_hat\n",
- " \n",
- " @property\n",
- " def embedding_dim(self):\n",
- " return self.emb_layer.embedding_dim\n",
- " \n",
- " @property\n",
- " def num_embeddings(self):\n",
- " return self.emb_layer.num_embeddings "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "aba28049-20bf-4ae6-9445-2f7c294686d8",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 768, 512, 512, 256, 256, 128, 128])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "cac6bc39-ba12-4052-bd5f-8834f57cfa15",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor(96.9082)"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "(model.emb_layer.weight**2).mean()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "afe2efbf-e703-4c43-8f7b-a87d303ea89e",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))\n",
- "train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)\n",
- "valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "c24ccc1c-4cbe-4373-871e-9090dceb69a1",
- "metadata": {},
- "outputs": [],
- "source": [
- "train_loader = FakeEpoch(train_loader, 1000)\n",
- "valid_loader = FakeEpoch(valid_loader, 100)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "71936e43-d718-45ef-8115-7fc63999ebd9",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "def _prefix_dict_keys(prefix, input_dict):\n",
- " return {f'{prefix}_{key}': val for key, val in input_dict.items()}\n",
- "\n",
- "def train_loop(model, loader, optimizer, use_tqdm=False):\n",
- " model.train()\n",
- "\n",
- " batch_losses = []\n",
- " \n",
- " if use_tqdm:\n",
- " loader = tqdm(loader, position=2, desc=\"Train Loop\", leave=False)\n",
- " \n",
- " for row in loader:\n",
- " optimizer.zero_grad()\n",
- " \n",
- " out = model(*(item.to(DEVICE) for item in row))\n",
- " loss = out[0]\n",
- " \n",
- " batch_loss_value = loss.item()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " \n",
- " batch_losses.append(batch_loss_value)\n",
- " \n",
- " loss_value = np.mean(batch_losses)\n",
- " return _prefix_dict_keys('train', {\n",
- " 'loss': loss_value\n",
- " })\n",
- "\n",
- "def valid_loop(model, loader, use_tqdm=False):\n",
- " model.eval()\n",
- "\n",
- " batch_losses = []\n",
- " \n",
- " all_true = []\n",
- " all_pred = []\n",
- " \n",
- " if use_tqdm:\n",
- " loader = tqdm(loader, position=2, desc=\"Valid Loop\", leave=False)\n",
- " \n",
- " with torch.no_grad():\n",
- " for row in loader:\n",
- " out = model(*(item.to(DEVICE) for item in row))\n",
- " loss = out[0]\n",
- " \n",
- " batch_loss_value = loss.item()\n",
- "\n",
- " batch_losses.append(batch_loss_value)\n",
- "\n",
- " loss_value = np.mean(batch_losses)\n",
- " \n",
- " return_value = {\n",
- " 'loss': loss_value,\n",
- " }\n",
- " \n",
- " return _prefix_dict_keys('valid', return_value)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "082b5384-827f-48b3-aa8e-40483668bbc0",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[9], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1000\u001b[39m):\n\u001b[1;32m 5\u001b[0m epoch_results \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 7\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[0;32m----> 8\u001b[0m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 16\u001b[0m epoch_results\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 17\u001b[0m valid_loop(\n\u001b[1;32m 18\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 21\u001b[0m )\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mprint\u001b[39m(epoch_results)\n",
- "Cell \u001b[0;32mIn[8], line 12\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(model, loader, optimizer, use_tqdm)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm:\n\u001b[1;32m 10\u001b[0m loader \u001b[38;5;241m=\u001b[39m tqdm(loader, position\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrain Loop\u001b[39m\u001b[38;5;124m\"\u001b[39m, leave\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m loader:\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(\u001b[38;5;241m*\u001b[39m(item\u001b[38;5;241m.\u001b[39mto(DEVICE) \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m row))\n",
- "Cell \u001b[0;32mIn[3], line 24\u001b[0m, in \u001b[0;36mFakeEpoch.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meach_epoch_size):\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n",
- "Cell \u001b[0;32mIn[3], line 10\u001b[0m, in \u001b[0;36mcustom_dataloader\u001b[0;34m(words_ids, batch_size, emb_dim, random_seed)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m word_ids \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39mchoice(words_ids, size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m2\u001b[39m))\n\u001b[0;32m---> 10\u001b[0m additive_noise \u001b[38;5;241m=\u001b[39m \u001b[43mnp_rng\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43memb_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m alpha \u001b[38;5;241m=\u001b[39m np_rng\u001b[38;5;241m.\u001b[39muniform(size\u001b[38;5;241m=\u001b[39m(batch_size, \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(word_ids), torch\u001b[38;5;241m.\u001b[39mTensor(additive_noise), torch\u001b[38;5;241m.\u001b[39mTensor(alpha)\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "model.to(DEVICE)\n",
- "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)\n",
- "\n",
- "for epoch in range(1000):\n",
- " epoch_results = {}\n",
- "\n",
- " epoch_results.update(\n",
- " train_loop(\n",
- " model=model,\n",
- " loader=train_loader,\n",
- " optimizer=optimizer,\n",
- " use_tqdm=False\n",
- " )\n",
- " )\n",
- "\n",
- " epoch_results.update(\n",
- " valid_loop(\n",
- " model=model,\n",
- " loader=valid_loader,\n",
- " use_tqdm=False\n",
- " )\n",
- " )\n",
- " print(epoch_results)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "53425637-6146-41d2-b59e-4617ae1f8521",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python [conda env:deep]",
- "language": "python",
- "name": "conda-env-deep-py"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
|