#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from transformers import T5Model


# In[2]:


# BOTTLENECK_SIZE = 128
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 64
NOISE_SCALE = 0.5

RANDOM_SEED = 42
SEED_SHIFT = 0
DROP_OUT = 0.5

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# In[3]:


def train_valid_test_split(total_range, random_seed=RANDOM_SEED):
    """Split the vocabulary ids into 80% train, 10% valid, 10% test."""
    train, testvalid = train_test_split(total_range, random_state=random_seed, test_size=0.2)
    test, valid = train_test_split(testvalid, random_state=random_seed, test_size=0.5)
    return train, valid, test


def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED + SEED_SHIFT):
    """Infinite generator yielding (token-id pairs, additive Gaussian noise, interpolation alpha)."""
    np_rng = np.random.default_rng(seed=random_seed)
    while True:
        word_ids = np_rng.choice(words_ids, size=(batch_size, 2))
        additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim))
        alpha = np_rng.uniform(size=(batch_size, 1))
        yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)


class FakeEpoch:
    """Wraps an infinite generator so it looks like a finite, sized 'epoch' to the training loop."""

    def __init__(self, dataloader, each_epoch_size):
        self.dataloader_iter = iter(dataloader)
        self.each_epoch_size = each_epoch_size

    def __len__(self):
        return self.each_epoch_size

    def __iter__(self):
        for _ in range(self.each_epoch_size):
            yield next(self.dataloader_iter)


# In[4]:


def ez_freeze(module):
    """Freeze all parameters of a module."""
    for param in module.parameters():
        param.requires_grad = False


def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None):
    """Build an MLP over consecutive dims; hidden layers get ReLU (plus optional dropout)."""
    layers = []
    pairs_count = len(linear_dims) - 1
    for idx in range(pairs_count):
        in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]
        if idx == pairs_count - 1:
            # Final layer: no activation; bias controlled by last_layer_bias.
            layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))
        else:
            layers.append(nn.Linear(in_dim, out_dim, bias=True))
            layers.append(nn.ReLU())
            if drop_out is not None:
                layers.append(nn.Dropout(drop_out))
    return nn.Sequential(*layers)


def auto_encoder_model(linear_dims):
    """Symmetric encoder/decoder MLP: dims go down the list, then back up in reverse."""
    return nn.Sequential(
        ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT),
        nn.ReLU(),
        nn.Dropout(DROP_OUT),
        # nn.LayerNorm(linear_dims[-1]),
        ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)
    )


class AutoEncoderModel(nn.Module):
    """Denoising autoencoder over frozen T5 input embeddings.

    Each example is a convex combination of two token embeddings plus Gaussian
    noise; the reconstruction target is the clean (noise-free) combination.
    """

    def __init__(self, pretrained_name, bottleneck_sizes):
        super().__init__()

        self.bottleneck_sizes = bottleneck_sizes

        # Only the (frozen) input embedding table of the T5 encoder is used.
        model = T5Model.from_pretrained(pretrained_name)
        self.emb_layer = model.get_encoder().get_input_embeddings()
        ez_freeze(self.emb_layer)

        self.auto_encoder = auto_encoder_model([
            self.embedding_dim,
            *bottleneck_sizes
        ])

        self.loss_fn = nn.MSELoss()

    def forward(self, word_ids, additive_noise, alpha):
        # word_ids.shape = (batch_size, 2)
        # additive_noise.shape = (batch_size, embedding_dim)
        # alpha.shape = (batch_size, 1)

        word_embs = self.emb_layer(word_ids)
        # word_embs.shape = (batch_size, 2, embedding_dim)

        word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)
        # word_combs.shape = (batch_size, embedding_dim)

        y_hat = self.auto_encoder(word_combs + additive_noise)
        loss = self.loss_fn(word_combs, y_hat)

        return loss, y_hat

    @property
    def embedding_dim(self):
        return self.emb_layer.embedding_dim

    @property
    def num_embeddings(self):
        return self.emb_layer.num_embeddings


# In[5]:


model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[4096])
print(model)


# In[6]:


train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)
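

# In[ ]:


# Optional sanity-check cell (added sketch, not part of the original notebook):
# draw one batch from a fresh generator to confirm the tensor shapes the model
# expects, without consuming anything from `train_loader` itself.
_word_ids, _noise, _alpha = next(custom_dataloader(words_ids=train_ds,
                                                   batch_size=TRAIN_BATCH_SIZE,
                                                   emb_dim=model.embedding_dim))
print(_word_ids.shape)   # (TRAIN_BATCH_SIZE, 2) token-id pairs
print(_noise.shape)      # (TRAIN_BATCH_SIZE, embedding_dim) additive Gaussian noise
print(_alpha.shape)      # (TRAIN_BATCH_SIZE, 1) interpolation weights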


# In[7]:


train_loader = FakeEpoch(train_loader, 2000)
valid_loader = FakeEpoch(valid_loader, 100)


# In[8]:


def _prefix_dict_keys(prefix, input_dict):
    return {f'{prefix}_{key}': val for key, val in input_dict.items()}


def train_loop(model, loader, optimizer, use_tqdm=False):
    model.train()

    batch_losses = []

    if use_tqdm:
        loader = tqdm(loader, position=2, desc="Train Loop", leave=False)

    for row in loader:
        optimizer.zero_grad()

        out = model(*(item.to(DEVICE) for item in row))
        loss = out[0]

        batch_loss_value = loss.item()
        loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss_value)

    loss_value = np.mean(batch_losses)
    return _prefix_dict_keys('train', {
        'loss': loss_value
    })


def valid_loop(model, loader, use_tqdm=False):
    model.eval()

    batch_losses = []

    if use_tqdm:
        loader = tqdm(loader, position=2, desc="Valid Loop", leave=False)

    with torch.no_grad():
        for row in loader:
            out = model(*(item.to(DEVICE) for item in row))
            loss = out[0]

            batch_loss_value = loss.item()
            batch_losses.append(batch_loss_value)

    loss_value = np.mean(batch_losses)

    return_value = {
        'loss': loss_value,
    }

    return _prefix_dict_keys('valid', return_value)


# In[9]:


model.to(DEVICE)
# model.load_state_dict(torch.load('./ae_file/snap_72.pt'))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)  # was 0.001

for epoch in tqdm(range(1000), position=1):
    epoch_results = {}

    epoch_results.update(
        train_loop(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            use_tqdm=True
        )
    )

    epoch_results.update(
        valid_loop(
            model=model,
            loader=valid_loader,
            use_tqdm=True
        )
    )

    # Save a checkpoint after every epoch.
    torch.save(model.state_dict(), f'/disks/ssd/ae_file4/snap_{epoch}.pt')

    print(epoch_results)


# In[ ]:
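

# Hedged usage sketch (added, not part of the original run): evaluate the trained
# autoencoder on one validation batch. The commented-out checkpoint path is an
# assumption based on the save pattern above, not a file the notebook guarantees.
# state = torch.load('/disks/ssd/ae_file4/snap_0.pt', map_location=DEVICE)
# model.load_state_dict(state)
model.eval()
with torch.no_grad():
    word_ids, additive_noise, alpha = next(iter(valid_loader))
    loss, y_hat = model(word_ids.to(DEVICE), additive_noise.to(DEVICE), alpha.to(DEVICE))
print('reconstruction MSE on one validation batch:', loss.item())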