You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

06_emb_ae.py 6.2KB

3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. import numpy as np
  5. from tqdm import tqdm
  6. from sklearn.model_selection import train_test_split
  7. import torch
  8. import torch.nn as nn
  9. from transformers import T5Model
  10. # In[2]:
# BOTTLENECK_SIZE = 128
TRAIN_BATCH_SIZE = 64  # samples per training batch
VALID_BATCH_SIZE = 64  # samples per validation batch
NOISE_SCALE = 0.5      # std-dev of the Gaussian noise added to embedding mixes
RANDOM_SEED = 42       # base seed for the train/valid/test split
SEED_SHIFT = 0         # extra offset applied to the dataloader's sampling seed
DROP_OUT = 0.5         # dropout probability used inside the auto-encoder MLP
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  19. # In[3]:
  20. def train_valid_test_split(total_range, random_seed=RANDOM_SEED):
  21. train, testvalid = train_test_split(total_range, random_state=random_seed, test_size=0.2)
  22. test, valid = train_test_split(testvalid, random_state=random_seed, test_size=0.5)
  23. return train, valid, test
  24. def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED+SEED_SHIFT):
  25. np_rng = np.random.default_rng(seed=random_seed)
  26. while True:
  27. word_ids = np_rng.choice(words_ids, size=(batch_size, 2))
  28. additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim))
  29. alpha = np_rng.uniform(size=(batch_size, 1))
  30. yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)
  31. class FakeEpoch:
  32. def __init__(self, dataloader, each_epoch_size):
  33. self.dataloader_iter = iter(dataloader)
  34. self.each_epoch_size = each_epoch_size
  35. def __len__(self):
  36. return self.each_epoch_size
  37. def __iter__(self):
  38. for _ in range(self.each_epoch_size):
  39. yield next(self.dataloader_iter)
  40. # In[4]:
  41. def ez_freeze(module):
  42. for param in module.parameters():
  43. param.requires_grad = False
  44. def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None):
  45. layers = []
  46. pairs_count = len(linear_dims) - 1
  47. for idx in range(pairs_count):
  48. in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]
  49. if idx == pairs_count - 1:
  50. layers.append(nn.Linear(in_dim, out_dim, bias=True))
  51. else:
  52. layers.append(nn.Linear(in_dim, out_dim, bias=True))
  53. layers.append(nn.ReLU())
  54. if drop_out is not None:
  55. layers.append(nn.Dropout(drop_out))
  56. return nn.Sequential(*layers)
  57. def auto_encoder_model(linear_dims):
  58. return nn.Sequential(
  59. ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT),
  60. nn.ReLU(),
  61. nn.Dropout(0.5),
  62. # nn.LayerNorm(linear_dims[-1]),
  63. ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)
  64. )
  65. class AutoEncoderModel(nn.Module):
  66. def __init__(self, pretrained_name, bottleneck_sizes):
  67. super().__init__()
  68. self.bottleneck_size = bottleneck_sizes
  69. model = T5Model.from_pretrained(pretrained_name)
  70. self.emb_layer = model.get_encoder().get_input_embeddings()
  71. ez_freeze(self.emb_layer)
  72. self.auto_encoder = auto_encoder_model([
  73. self.embedding_dim,
  74. *bottleneck_sizes
  75. ])
  76. self.loss_fn = nn.MSELoss()
  77. def forward(self, word_ids, additive_noise, alpha):
  78. # word_ids.shape = (batch_size, 2)
  79. # additive_noise.shape = (batch_size, embedding_dim)
  80. # alpha.shape = (batch_size, 1)
  81. word_embs = self.emb_layer(word_ids)
  82. # word_embs.shape = (batch_size, 2, embedding_dim)
  83. word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)
  84. # word_combs.shape = (batch_size, embedding_dim)
  85. y_hat = self.auto_encoder(word_combs + additive_noise)
  86. loss = self.loss_fn(word_combs, y_hat)
  87. return loss, y_hat
  88. @property
  89. def embedding_dim(self):
  90. return self.emb_layer.embedding_dim
  91. @property
  92. def num_embeddings(self):
  93. return self.emb_layer.num_embeddings
# In[5]:
# Build the auto-encoder on top of T5-large embeddings (single 4096-wide bottleneck).
model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[4096])
print(model)
# In[6]:
# Split vocabulary ids 80/10/10, then build endless samplers over train/valid ids.
train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)
# In[7]:
# 2000 train / 100 valid batches per "epoch" (the underlying streams are infinite).
train_loader = FakeEpoch(train_loader, 2000)
valid_loader = FakeEpoch(valid_loader, 100)
  104. # In[8]:
  105. def _prefix_dict_keys(prefix, input_dict):
  106. return {f'{prefix}_{key}': val for key, val in input_dict.items()}
  107. def train_loop(model, loader, optimizer, use_tqdm=False):
  108. model.train()
  109. batch_losses = []
  110. if use_tqdm:
  111. loader = tqdm(loader, position=2, desc="Train Loop", leave=False)
  112. for row in loader:
  113. optimizer.zero_grad()
  114. out = model(*(item.to(DEVICE) for item in row))
  115. loss = out[0]
  116. batch_loss_value = loss.item()
  117. loss.backward()
  118. optimizer.step()
  119. batch_losses.append(batch_loss_value)
  120. loss_value = np.mean(batch_losses)
  121. return _prefix_dict_keys('train', {
  122. 'loss': loss_value
  123. })
  124. def valid_loop(model, loader, use_tqdm=False):
  125. model.eval()
  126. batch_losses = []
  127. if use_tqdm:
  128. loader = tqdm(loader, position=2, desc="Valid Loop", leave=False)
  129. with torch.no_grad():
  130. for row in loader:
  131. out = model(*(item.to(DEVICE) for item in row))
  132. loss = out[0]
  133. batch_loss_value = loss.item()
  134. batch_losses.append(batch_loss_value)
  135. loss_value = np.mean(batch_losses)
  136. return_value = {
  137. 'loss': loss_value,
  138. }
  139. return _prefix_dict_keys('valid', return_value)
# In[9]:
model.to(DEVICE)
# model.load_state_dict(torch.load('./ae_file/snap_72.pt'))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) # was 0.001
# Train/validate for up to 1000 epochs; a checkpoint is written after every epoch.
for epoch in tqdm(range(1000), position=1):
    epoch_results = {}
    epoch_results.update(
        train_loop(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            use_tqdm=True
        )
    )
    epoch_results.update(
        valid_loop(
            model=model,
            loader=valid_loader,
            use_tqdm=True
        )
    )
    # NOTE(review): hard-coded absolute checkpoint path — confirm the directory
    # exists and has space before a long run; torch.save does not create it.
    torch.save(model.state_dict(), f'/disks/ssd/ae_file4/snap_{epoch}.pt')
    print(epoch_results)
  163. # In[ ]: