|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- #!/usr/bin/env python
- # coding: utf-8
-
- # In[1]:
-
-
- import numpy as np
- from tqdm import tqdm
- from sklearn.model_selection import train_test_split
- import torch
- import torch.nn as nn
- from transformers import T5Model
-
-
- # In[2]:
-
-
# Hyper-parameters for the embedding denoising auto-encoder.
# BOTTLENECK_SIZE = 128
TRAIN_BATCH_SIZE = 8192  # word pairs per training batch
VALID_BATCH_SIZE = 8192  # word pairs per validation batch
NOISE_SCALE = 1  # std-dev of the Gaussian noise added to mixed embeddings
RANDOM_SEED = 42  # base seed for the data split and batch sampling
SEED_SHIFT = 0  # extra offset applied to the dataloader seed
DROP_OUT = 0.2  # dropout probability inside the encoder MLP

# Train on the first GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
- # In[3]:
-
-
def train_valid_test_split(total_range, random_seed=RANDOM_SEED):
    """Split *total_range* into 80% train / 10% valid / 10% test.

    The 20% holdout from the first split is halved to produce the test
    and validation subsets.  Both splits are seeded for reproducibility.
    """
    train, holdout = train_test_split(
        total_range, random_state=random_seed, test_size=0.2
    )
    test, valid = train_test_split(
        holdout, random_state=random_seed, test_size=0.5
    )
    return train, valid, test
-
def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED+SEED_SHIFT):
    """Endlessly yield batches for the denoising auto-encoder.

    Each yielded tuple contains:
      word_ids:       (batch_size, 2) integer tensor — two ids sampled
                      (with replacement) per row from *words_ids*,
      additive_noise: (batch_size, emb_dim) float32 Gaussian noise with
                      std-dev ``NOISE_SCALE``,
      alpha:          (batch_size, 1) float32 uniform mixing weights.
    """
    np_rng = np.random.default_rng(seed=random_seed)
    while True:
        word_ids = np_rng.choice(words_ids, size=(batch_size, 2))
        additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim))
        alpha = np_rng.uniform(size=(batch_size, 1))
        # torch.from_numpy shares memory with the numpy buffer; cast the
        # float64 arrays to float32 explicitly instead of relying on the
        # legacy torch.Tensor(ndarray) constructor.
        yield (
            torch.from_numpy(word_ids),
            torch.from_numpy(additive_noise).float(),
            torch.from_numpy(alpha).float(),
        )
-
class FakeEpoch:
    """Present a slice of an endless dataloader as a fixed-length epoch.

    Wraps an (infinite) iterable and serves exactly ``each_epoch_size``
    items per iteration pass, so epoch loops and progress bars see a
    finite length while the underlying iterator keeps its position
    across epochs.
    """

    def __init__(self, dataloader, each_epoch_size):
        self.dataloader_iter = iter(dataloader)
        self.each_epoch_size = each_epoch_size

    def __len__(self):
        return self.each_epoch_size

    def __iter__(self):
        remaining = self.each_epoch_size
        while remaining > 0:
            remaining -= 1
            yield next(self.dataloader_iter)
-
-
- # In[4]:
-
-
def ez_freeze(module):
    """Disable gradient tracking on every parameter of *module*."""
    for weight in module.parameters():
        weight.requires_grad = False
-
class ResLinear(nn.Module):
    """Two Linear+ReLU layers joined by a residual connection.

    forward(x) = relu(linear1(x)) + relu(linear2(relu(linear1(x)))).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim)
        self.linear2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        hidden = nn.functional.relu(self.linear1(x))
        return hidden + nn.functional.relu(self.linear2(hidden))
-
def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None):
    """Build an MLP from consecutive dimension pairs in *linear_dims*.

    Every pair except the last becomes a ResLinear block; the final pair
    is a plain Linear whose bias is governed by *last_layer_bias*.  When
    *drop_out* is given, a Dropout follows every layer — including the
    final Linear.
    """
    last_idx = len(linear_dims) - 2
    layers = []
    for idx, (in_dim, out_dim) in enumerate(zip(linear_dims, linear_dims[1:])):
        if idx == last_idx:
            layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))
        else:
            layers.append(ResLinear(in_dim, out_dim))
        if drop_out is not None:
            layers.append(nn.Dropout(drop_out))
    return nn.Sequential(*layers)
-
def auto_encoder_model(linear_dims):
    """Build a symmetric encoder/decoder MLP with a LayerNorm bottleneck.

    The encoder follows *linear_dims* down to the bottleneck width (with
    dropout); the decoder mirrors the dims back up, with a bias on its
    final layer.
    """
    encoder = ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT)
    bottleneck_norm = nn.LayerNorm(linear_dims[-1])
    decoder = ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)
    return nn.Sequential(encoder, bottleneck_norm, decoder)
-
class AutoEncoderModel(nn.Module):
    """Denoising auto-encoder over a frozen T5 input-embedding table.

    A random convex combination of two word embeddings is corrupted with
    additive noise; the auto-encoder is trained (MSE) to reconstruct the
    clean combination.
    """

    def __init__(self, pretrained_name, bottleneck_sizes):
        """
        pretrained_name:  HuggingFace model name whose encoder input
                          embeddings supply the (frozen) embedding table.
        bottleneck_sizes: encoder hidden dims, ending at the bottleneck.
        """
        super().__init__()

        self.bottleneck_size = bottleneck_sizes

        # Only the input-embedding table of the pretrained model is kept;
        # it is frozen so training updates just the auto-encoder.
        pretrained = T5Model.from_pretrained(pretrained_name)
        self.emb_layer = pretrained.get_encoder().get_input_embeddings()
        ez_freeze(self.emb_layer)

        self.auto_encoder = auto_encoder_model(
            [self.embedding_dim, *bottleneck_sizes]
        )

        self.loss_fn = nn.MSELoss()

    def forward(self, word_ids, additive_noise, alpha):
        """Return (MSE loss, reconstruction) for one noisy batch.

        word_ids:       (batch, 2) ids of the two words to mix.
        additive_noise: (batch, embedding_dim) noise added to the input.
        alpha:          (batch, 1) per-row mixing coefficient.
        """
        word_embs = self.emb_layer(word_ids)  # (batch, 2, embedding_dim)

        # Convex combination of the two embeddings per row.
        word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)

        # Reconstruct the clean combination from its noisy version.
        y_hat = self.auto_encoder(word_combs + additive_noise)
        loss = self.loss_fn(word_combs, y_hat)
        return loss, y_hat

    @property
    def embedding_dim(self):
        # Width of each embedding vector in the frozen table.
        return self.emb_layer.embedding_dim

    @property
    def num_embeddings(self):
        # Vocabulary size of the frozen table.
        return self.emb_layer.num_embeddings
-
-
- # In[5]:
-
-
# Build the model: frozen T5 embeddings plus an auto-encoder whose
# encoder is 768 -> 512 -> 256 -> 128 (128 = bottleneck width).
model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 512, 256, 128])
print(model)

# In[6]:


# 80/10/10 split over the vocabulary ids; train/valid each get an
# endless batch generator over their own id subset.
train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)


# In[7]:


# Cap the endless generators so each "epoch" is a fixed batch count.
train_loader = FakeEpoch(train_loader, 1000)
valid_loader = FakeEpoch(valid_loader, 100)
-
-
- # In[8]:
-
-
- def _prefix_dict_keys(prefix, input_dict):
- return {f'{prefix}_{key}': val for key, val in input_dict.items()}
-
def train_loop(model, loader, optimizer, use_tqdm=False):
    """Run one training epoch and return {'train_loss': mean batch loss}."""
    model.train()

    if use_tqdm:
        loader = tqdm(loader, position=2, desc="Train Loop", leave=False)

    batch_losses = []
    for batch in loader:
        optimizer.zero_grad()

        # The model returns (loss, y_hat); only the loss drives training.
        loss = model(*(tensor.to(DEVICE) for tensor in batch))[0]
        batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()

    return _prefix_dict_keys('train', {'loss': np.mean(batch_losses)})
-
def valid_loop(model, loader, use_tqdm=False):
    """Run one evaluation pass (no gradients).

    Returns {'valid_loss': mean batch loss}.  The never-populated
    ``all_true``/``all_pred`` accumulators from the original version have
    been removed — they were dead code.
    """
    model.eval()

    batch_losses = []

    if use_tqdm:
        loader = tqdm(loader, position=2, desc="Valid Loop", leave=False)

    with torch.no_grad():
        for row in loader:
            out = model(*(item.to(DEVICE) for item in row))
            batch_losses.append(out[0].item())

    return _prefix_dict_keys('valid', {'loss': np.mean(batch_losses)})
-
-
- # In[9]:
-
-
# Move the model to the training device and build the optimizer.
model.to(DEVICE)

# model.load_state_dict(torch.load('./ae_file/snap_72.pt'))

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) # was 0.001

# Main loop: one train pass + one validation pass per epoch, snapshotting
# the weights after every epoch.
# NOTE(review): assumes './ae_file4_res_mlp' already exists — torch.save
# does not create missing directories; confirm before long runs.
for epoch in tqdm(range(1000), position=1):
    epoch_results = {}

    epoch_results.update(
        train_loop(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            use_tqdm=True
        )
    )

    epoch_results.update(
        valid_loop(
            model=model,
            loader=valid_loader,
            use_tqdm=True
        )
    )
    torch.save(model.state_dict(), f'./ae_file4_res_mlp/snap_{epoch}.pt')
    print(epoch_results)
-
-
- # In[ ]:
-
-
-
|