
train.py

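"""Training utilities for fine-tuning a causal language model, with optional
differentially private training via Opacus: a Trainer class (optimizer, LR
scheduling, DP engine setup), per-input loss computation for validation, and
beam-search generation helpers for evaluation."""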
import logging

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import wandb
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from transformers import get_linear_schedule_with_warmup

from model import save_model


class Trainer:
    def __init__(self, cfg, model, train_loader, checkpoint=None, second_trainer=False):
        # A second trainer (e.g. for a later fine-tuning stage) gets its own
        # epoch count, learning rate and weight decay from the config.
        if second_trainer:
            self.epochs = cfg.epochs_two
            self.lr = cfg.lr_two
            self.weight_decay = cfg.weight_decay_two
        else:
            self.epochs = cfg.epochs
            self.lr = cfg.lr
            self.weight_decay = cfg.weight_decay

        self.optimizer = AdamW(model.parameters(), lr=self.lr,
                               weight_decay=self.weight_decay, eps=cfg.optimizer_eps)
        self.gradient_accumulation_steps = cfg.virtual_batch_size // cfg.batch_size
        total_steps = len(train_loader) * self.gradient_accumulation_steps * self.epochs
        if cfg.scheduler:
            if cfg.scheduler_type == "linear":
                # An explicit warmup step count takes precedence over the ratio.
                warmup_steps = (cfg.scheduler_warmup_steps if cfg.scheduler_warmup_steps
                                else int(cfg.scheduler_warmup_ratio * total_steps))
                self.scheduler = get_linear_schedule_with_warmup(
                    self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
            elif cfg.scheduler_type == "steplr":
                self.scheduler = StepLR(self.optimizer, step_size=cfg.scheduler_step_size,
                                        gamma=cfg.scheduler_gamma)

        self.dp = cfg.dp
        self.model = model
        self.cfg = cfg
        self.save_path = f"{cfg.media_path}generation_saved_models/{cfg.dataset}/{cfg.peft_mode}"
        self.model_name = self.cfg.run_name if self.cfg.run_name else "best_model"

        if cfg.dp:
            self.model.train()
            self.privacy_engine = PrivacyEngine(accountant="rdp")
            if checkpoint:
                self.privacy_engine.load_checkpoint(path=checkpoint, module=self.model)
            self.model, self.optimizer, _ = self.privacy_engine.make_private_with_epsilon(
                module=self.model,
                optimizer=self.optimizer,
                data_loader=train_loader,
                target_epsilon=cfg.epsilon,
                target_delta=cfg.delta,
                epochs=self.epochs,
                max_grad_norm=cfg.clipping_threshold,
            )
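
    # Note on the Opacus setup above: make_private_with_epsilon wraps the model
    # in a GradSampleModule and the optimizer in a DPOptimizer, calibrating the
    # noise multiplier so the RDP accountant reaches (epsilon, delta) after the
    # planned number of epochs. It also returns a Poisson-sampled DataLoader,
    # which is discarded here ("_"); training below iterates the caller's loader
    # instead, so the sampling scheme assumed by the accountant is approximated
    # rather than reproduced exactly.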
    def train_step(self, train_loader):
        train_loss = 0
        self.model.train()
        self.optimizer.zero_grad()
        if self.dp:
            # BatchMemoryManager splits each logical batch into physical batches
            # of at most cfg.batch_size; optimizer.step() only applies an update
            # once the full logical batch has been accumulated.
            with BatchMemoryManager(data_loader=train_loader,
                                    max_physical_batch_size=self.cfg.batch_size,
                                    optimizer=self.optimizer) as new_data_loader:
                for batch_number, batch in tqdm(enumerate(new_data_loader, 1), total=len(new_data_loader)):
                    # Move batch tensors to the same device as the model
                    batch = prepare_inputs(batch)
                    batch = {k: v.to(self.cfg.device) for k, v in batch.items()}
                    # Forward pass
                    outputs = self.model(**batch)
                    loss = outputs.loss
                    loss.backward()
                    train_loss += loss.item()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.cfg.scheduler and self.cfg.scheduler_type in ("linear", "steplr"):
                        self.scheduler.step()
        else:
            for batch_number, batch in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
                # Move batch tensors to the same device as the model
                batch = prepare_inputs(batch)
                batch = {k: v.to(self.cfg.device) for k, v in batch.items()}
                # Forward pass
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                train_loss += loss.item()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.cfg.scheduler and self.cfg.scheduler_type in ("linear", "steplr"):
                    self.scheduler.step()
        # Average over the batches actually processed (physical batches under DP,
        # where the loader length can differ from len(train_loader)).
        return train_loss / batch_number
    def evaluate_step(self, val_loader):
        # Evaluation loop
        val_loss = 0
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(val_loader):
                # Move batch tensors to the same device as the model
                batch = prepare_inputs(batch)
                batch = {k: v.to(self.cfg.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                # Per-example loss, averaged within the batch
                loss = compute_loss_per_input(outputs, batch)
                val_loss += loss.mean().item()
        return val_loss / len(val_loader)
    def train_and_evaluate(self, epochs, train_loader, val_loader):
        best_validation_loss = None
        best_epoch = 0
        wandb_log = []
        for epoch in range(epochs):
            log_data = {}
            train_loss = self.train_step(train_loader)
            log_data["train_loss"] = train_loss
            logging.info(f"Epoch {epoch+1} Training loss: {train_loss}")
            val_loss = self.evaluate_step(val_loader=val_loader)
            log_data["validation_loss"] = val_loss
            logging.info(f"Epoch {epoch+1} Validation loss: {val_loss}")
            if best_validation_loss is None or val_loss < best_validation_loss:
                best_validation_loss = val_loss
                best_epoch = epoch
                save_model(self.model, self.cfg.peft_mode, self.save_path, self.model_name)
                logging.info(f"Model improved and saved for epoch {epoch+1}")
            wandb_log.append(log_data)
        logging.info("Best results:")
        if self.cfg.dp:
            # Privacy budget actually spent, as reported by the RDP accountant
            logging.info(self.privacy_engine.accountant.get_epsilon(delta=self.cfg.delta))
        logging.info(f"Best validation loss: {best_validation_loss} for Epoch: {best_epoch+1}")
        if self.cfg.use_wandb:
            for epoch_data in wandb_log:
                wandb.log(epoch_data)


def prepare_inputs(batch):
    # Drop auxiliary keys that the model's forward() does not accept
    batch.pop('src_attn', None)
    batch.pop('tgt_attn', None)
    batch.pop('src', None)
    return batch


def compute_loss_per_input(outputs, batch):
    # Shift so that logits at position t are scored against the token at t+1:
    # logits[..., :-1, :] predicts labels[..., 1:]. Positions labelled -100 are
    # ignored by cross_entropy (they contribute zero loss with reduction="none")
    # and are excluded from the per-sequence length used for averaging.
    logits = outputs.logits
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = batch["labels"][..., 1:].contiguous()
    seq_lens = (shift_labels != -100).sum(dim=1)
    # cross_entropy expects (batch, classes, seq), hence the permute
    loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
    loss = loss.sum(dim=1) / seq_lens
    return loss


def save_evaluation_output(outputs, path):
    # The with-block closes the file; no explicit close() is needed.
    with open(path, "w") as file:
        for strings in outputs:
            for string in strings:
                file.write(string + "\n")
            # file.write("\n")


def generate_evaluation_output(model, tokenizer, data, device, max_length,
                               beam_size=5, do_sample=False, num_return_sequences=1):
    generated_texts = []
    prev = None
    for entry in tqdm(data):
        # Generate only once per unique meaning representation (the data may
        # contain several reference texts for the same input)
        if prev != entry["meaning_representation"]:
            prev = entry["meaning_representation"]
            prompt = f"{entry['meaning_representation']} {tokenizer.eos_token}"
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(**inputs,
                                         num_beams=beam_size,
                                         max_length=max_length,
                                         do_sample=do_sample,
                                         early_stopping=True,
                                         min_length=5,
                                         num_return_sequences=num_return_sequences,
                                         # 198/628: single/double newline tokens
                                         # (assuming a GPT-2 tokenizer)
                                         bad_words_ids=[[628], [198], [tokenizer.pad_token_id]],
                                         pad_token_id=tokenizer.eos_token_id,
                                         repetition_penalty=1.0,
                                         top_k=0,
                                         top_p=0.9)
            temp_generated_texts = []
            for output in outputs:
                # Decode only the continuation, i.e. strip the prompt tokens
                generated_text = tokenizer.decode(output[len(inputs["input_ids"][0]):],
                                                  skip_special_tokens=True)
                temp_generated_texts.append(generated_text.strip())
            generated_texts.append(temp_generated_texts)
    return generated_texts
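

# ---------------------------------------------------------------------------
# Illustrative wiring (not part of the original entry point). This sketch only
# shows how the Trainer above might be constructed; the cfg values below are
# placeholders, and `model`, `train_loader` and `val_loader` are assumed to be
# built elsewhere (e.g. in a main script alongside model.save_model).
# ---------------------------------------------------------------------------
# from types import SimpleNamespace
#
# cfg = SimpleNamespace(
#     epochs=3, lr=1e-4, weight_decay=0.01, optimizer_eps=1e-8,
#     batch_size=8, virtual_batch_size=64,
#     scheduler=True, scheduler_type="linear",
#     scheduler_warmup_steps=0, scheduler_warmup_ratio=0.1,
#     dp=True, epsilon=8.0, delta=1e-5, clipping_threshold=1.0,
#     device="cuda", media_path="./", dataset="e2e_nlg", peft_mode="lora",
#     run_name=None, use_wandb=False,
# )
# trainer = Trainer(cfg, model, train_loader)
# trainer.train_and_evaluate(trainer.epochs, train_loader, val_loader)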