You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

06_emb_ae_res_mlp.py 6.6KB

3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. import numpy as np
  5. from tqdm import tqdm
  6. from sklearn.model_selection import train_test_split
  7. import torch
  8. import torch.nn as nn
  9. from transformers import T5Model
# In[2]:
# Hyperparameters and runtime configuration for the embedding autoencoder.
# BOTTLENECK_SIZE = 128
TRAIN_BATCH_SIZE = 8192  # word-pair samples per training batch
VALID_BATCH_SIZE = 8192  # word-pair samples per validation batch
NOISE_SCALE = 1          # std-dev of the Gaussian noise added to inputs
RANDOM_SEED = 42         # base seed for the data splits and samplers
SEED_SHIFT = 0           # extra offset applied to the dataloader seed
DROP_OUT = 0.2           # dropout probability inside the encoder MLP
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first GPU when available
  20. def train_valid_test_split(total_range, random_seed=RANDOM_SEED):
  21. train, testvalid = train_test_split(total_range, random_state=random_seed, test_size=0.2)
  22. test, valid = train_test_split(testvalid, random_state=random_seed, test_size=0.5)
  23. return train, valid, test
  24. def custom_dataloader(words_ids, batch_size, emb_dim, random_seed=RANDOM_SEED+SEED_SHIFT):
  25. np_rng = np.random.default_rng(seed=random_seed)
  26. while True:
  27. word_ids = np_rng.choice(words_ids, size=(batch_size, 2))
  28. additive_noise = np_rng.normal(loc=0, scale=NOISE_SCALE, size=(batch_size, emb_dim))
  29. alpha = np_rng.uniform(size=(batch_size, 1))
  30. yield torch.from_numpy(word_ids), torch.Tensor(additive_noise), torch.Tensor(alpha)
  31. class FakeEpoch:
  32. def __init__(self, dataloader, each_epoch_size):
  33. self.dataloader_iter = iter(dataloader)
  34. self.each_epoch_size = each_epoch_size
  35. def __len__(self):
  36. return self.each_epoch_size
  37. def __iter__(self):
  38. for _ in range(self.each_epoch_size):
  39. yield next(self.dataloader_iter)
  40. # In[4]:
  41. def ez_freeze(module):
  42. for param in module.parameters():
  43. param.requires_grad = False
  44. class ResLinear(nn.Module):
  45. def __init__(self, in_dim, out_dim):
  46. super().__init__()
  47. self.linear1 = nn.Linear(in_dim, out_dim)
  48. self.linear2 = nn.Linear(out_dim, out_dim)
  49. def forward(self, x):
  50. out1 = nn.functional.relu(self.linear1(x))
  51. out2 = nn.functional.relu(self.linear2(out1))
  52. return out1 + out2
def ez_mlp(linear_dims, last_layer_bias=False, drop_out=None):
    """Build an MLP over consecutive dimension pairs in *linear_dims*.

    Hidden pairs become ResLinear blocks; the final pair is a plain
    nn.Linear whose bias is controlled by *last_layer_bias*.
    """
    layers = []
    pairs_count = len(linear_dims) - 1
    for idx in range(pairs_count):
        in_dim, out_dim = linear_dims[idx], linear_dims[idx + 1]
        if idx == pairs_count - 1:
            # Last pair: plain projection, bias off by default.
            layers.append(nn.Linear(in_dim, out_dim, bias=last_layer_bias))
        else:
            layers.append(ResLinear(in_dim, out_dim))
            # NOTE(review): the original indentation was lost in extraction;
            # reconstructed here with dropout following hidden ResLinear layers
            # only (not the final Linear) — confirm against the source notebook.
            if drop_out is not None:
                layers.append(nn.Dropout(drop_out))
    return nn.Sequential(*layers)
  65. def auto_encoder_model(linear_dims):
  66. return nn.Sequential(
  67. ez_mlp(linear_dims, last_layer_bias=False, drop_out=DROP_OUT),
  68. nn.LayerNorm(linear_dims[-1]),
  69. ez_mlp(list(reversed(linear_dims)), last_layer_bias=True)
  70. )
  71. class AutoEncoderModel(nn.Module):
  72. def __init__(self, pretrained_name, bottleneck_sizes):
  73. super().__init__()
  74. self.bottleneck_size = bottleneck_sizes
  75. model = T5Model.from_pretrained(pretrained_name)
  76. self.emb_layer = model.get_encoder().get_input_embeddings()
  77. ez_freeze(self.emb_layer)
  78. self.auto_encoder = auto_encoder_model([
  79. self.embedding_dim,
  80. *bottleneck_sizes
  81. ])
  82. self.loss_fn = nn.MSELoss()
  83. def forward(self, word_ids, additive_noise, alpha):
  84. # word_ids.shape = (batch_size, 2)
  85. # additive_noise.shape = (batch_size, embedding_dim)
  86. # alpha.shape = (batch_size, 1)
  87. word_embs = self.emb_layer(word_ids)
  88. # word_embs.shape = (batch_size, 2, embedding_dim)
  89. word_combs = word_embs[:, 0] * alpha + word_embs[:, 1] * (1 - alpha)
  90. # word_combs.shape = (batch_size, embedding_dim)
  91. y_hat = self.auto_encoder(word_combs + additive_noise)
  92. loss = self.loss_fn(word_combs, y_hat)
  93. return loss, y_hat
  94. @property
  95. def embedding_dim(self):
  96. return self.emb_layer.embedding_dim
  97. @property
  98. def num_embeddings(self):
  99. return self.emb_layer.num_embeddings
# In[5]:
# NOTE(review): this downloads/loads the full T5 checkpoint just to extract
# its input embedding table — needs network/disk on first run.
model = AutoEncoderModel('google/t5-large-lm-adapt', bottleneck_sizes=[768, 512, 256, 128])
print(model)
# In[6]:
# Split the vocabulary id range 80/10/10; loaders sample id pairs endlessly.
train_ds, valid_ds, test_ds = train_valid_test_split(range(model.num_embeddings))
train_loader = custom_dataloader(words_ids=train_ds, batch_size=TRAIN_BATCH_SIZE, emb_dim=model.embedding_dim)
valid_loader = custom_dataloader(words_ids=valid_ds, batch_size=VALID_BATCH_SIZE, emb_dim=model.embedding_dim)
# In[7]:
# Cap each "epoch" at a fixed number of batches from the infinite loaders.
train_loader = FakeEpoch(train_loader, 1000)
valid_loader = FakeEpoch(valid_loader, 100)
# In[8]:
  111. def _prefix_dict_keys(prefix, input_dict):
  112. return {f'{prefix}_{key}': val for key, val in input_dict.items()}
  113. def train_loop(model, loader, optimizer, use_tqdm=False):
  114. model.train()
  115. batch_losses = []
  116. if use_tqdm:
  117. loader = tqdm(loader, position=2, desc="Train Loop", leave=False)
  118. for row in loader:
  119. optimizer.zero_grad()
  120. out = model(*(item.to(DEVICE) for item in row))
  121. loss = out[0]
  122. batch_loss_value = loss.item()
  123. loss.backward()
  124. optimizer.step()
  125. batch_losses.append(batch_loss_value)
  126. loss_value = np.mean(batch_losses)
  127. return _prefix_dict_keys('train', {
  128. 'loss': loss_value
  129. })
  130. def valid_loop(model, loader, use_tqdm=False):
  131. model.eval()
  132. batch_losses = []
  133. all_true = []
  134. all_pred = []
  135. if use_tqdm:
  136. loader = tqdm(loader, position=2, desc="Valid Loop", leave=False)
  137. with torch.no_grad():
  138. for row in loader:
  139. out = model(*(item.to(DEVICE) for item in row))
  140. loss = out[0]
  141. batch_loss_value = loss.item()
  142. batch_losses.append(batch_loss_value)
  143. loss_value = np.mean(batch_losses)
  144. return_value = {
  145. 'loss': loss_value,
  146. }
  147. return _prefix_dict_keys('valid', return_value)
  148. # In[9]:
  149. model.to(DEVICE)
  150. # model.load_state_dict(torch.load('./ae_file/snap_72.pt'))
  151. optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) # was 0.001
  152. for epoch in tqdm(range(1000), position=1):
  153. epoch_results = {}
  154. epoch_results.update(
  155. train_loop(
  156. model=model,
  157. loader=train_loader,
  158. optimizer=optimizer,
  159. use_tqdm=True
  160. )
  161. )
  162. epoch_results.update(
  163. valid_loop(
  164. model=model,
  165. loader=valid_loader,
  166. use_tqdm=True
  167. )
  168. )
  169. torch.save(model.state_dict(), f'./ae_file4_res_mlp/snap_{epoch}.pt')
  170. print(epoch_results)
  171. # In[ ]: