|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- from transformers import get_linear_schedule_with_warmup
- import logging
- import torch
- import torch.nn.functional as F
- from torch.optim import AdamW
- from tqdm import tqdm
- from opacus import PrivacyEngine
- from opacus.utils.batch_memory_manager import BatchMemoryManager
-
-
- import wandb
- import math
- from model import save_model
- from torch.optim.lr_scheduler import StepLR
-
-
class Trainer:
    """Train and validate a model, optionally under differential privacy (Opacus).

    Hyperparameters are read from ``cfg``. When ``second_trainer`` is true the
    ``epochs_two`` / ``lr_two`` / ``weight_decay_two`` settings are used instead
    of ``epochs`` / ``lr`` / ``weight_decay``.
    """

    def __init__(self, cfg, model, train_loader, checkpoint=None, second_trainer=False):
        """Build the optimizer, the optional LR scheduler and the optional PrivacyEngine.

        Args:
            cfg: configuration object. Besides the hyperparameters above, this
                reads ``optimizer_eps``, ``virtual_batch_size``, ``batch_size``,
                ``scheduler``, ``scheduler_type``, ``scheduler_warmup_steps``,
                ``scheduler_warmup_ratio``, ``scheduler_step_size``,
                ``scheduler_gamma``, ``dp``, ``epsilon``, ``delta``,
                ``clipping_threshold``, ``media_path``, ``dataset``,
                ``peft_mode`` and ``run_name``.
            model: the module to train.
            train_loader: training DataLoader. Its length sizes the LR schedule
                and, under DP, is passed to Opacus to calibrate the noise.
            checkpoint: optional path of an Opacus checkpoint. It is loaded
                before the model is made private (DP only).
            second_trainer: select the ``*_two`` hyperparameters.

        Raises:
            ValueError: if ``cfg.scheduler`` is enabled with a
                ``cfg.scheduler_type`` other than ``"linear"`` or ``"steplr"``.
        """
        if second_trainer:
            self.epochs = cfg.epochs_two
            self.lr = cfg.lr_two
            self.weight_decay = cfg.weight_decay_two
        else:
            self.epochs = cfg.epochs
            self.lr = cfg.lr
            self.weight_decay = cfg.weight_decay
        self.optimizer = AdamW(model.parameters(), lr=self.lr, weight_decay=self.weight_decay, eps=cfg.optimizer_eps)
        self.gradient_accumulation_steps = cfg.virtual_batch_size // cfg.batch_size
        # The linear scheduler ticks once per batch seen in _run_epoch. The
        # accumulation factor presumably converts logical (virtual) batches into
        # the physical batches BatchMemoryManager yields -- TODO confirm for
        # the non-DP path.
        total_steps = len(train_loader) * self.gradient_accumulation_steps * self.epochs

        # Always defined, so a disabled scheduler never leaves the attribute missing.
        self.scheduler = None
        if cfg.scheduler:
            if cfg.scheduler_type == "linear":
                # An explicit warmup step count wins; otherwise derive it from
                # the ratio. The HF helper expects an integer step count.
                warmup_steps = (
                    cfg.scheduler_warmup_steps
                    if cfg.scheduler_warmup_steps
                    else int(cfg.scheduler_warmup_ratio * total_steps)
                )
                self.scheduler = get_linear_schedule_with_warmup(
                    self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
                )
            elif cfg.scheduler_type == "steplr":
                self.scheduler = StepLR(self.optimizer, step_size=cfg.scheduler_step_size, gamma=cfg.scheduler_gamma)
            else:
                # Fail fast here rather than with an AttributeError on the first step().
                raise ValueError(f"Unsupported scheduler_type: {cfg.scheduler_type!r}")

        self.dp = cfg.dp
        self.model = model
        self.cfg = cfg
        self.save_path = f"{cfg.media_path}generation_saved_models/{cfg.dataset}/{cfg.peft_mode}"
        self.model_name = self.cfg.run_name if self.cfg.run_name else "best_model"

        if cfg.dp:
            self.model.train()
            self.privacy_engine = PrivacyEngine(
                accountant="rdp",
            )
            if checkpoint:
                self.privacy_engine.load_checkpoint(path=checkpoint, module=self.model)
            # NOTE(review): make_private_with_epsilon also returns a
            # Poisson-sampling DataLoader. It is discarded here, and training
            # iterates the caller's loader instead. Opacus' privacy accounting
            # assumes the returned loader is the one used -- confirm intent.
            self.model, self.optimizer, _ = self.privacy_engine.make_private_with_epsilon(
                module=self.model,
                optimizer=self.optimizer,
                data_loader=train_loader,
                target_epsilon=cfg.epsilon,
                target_delta=cfg.delta,
                epochs=self.epochs,
                max_grad_norm=cfg.clipping_threshold,
            )

    def _run_epoch(self, data_loader):
        """Run forward/backward/step over every batch; return ``(loss_sum, batch_count)``."""
        loss_sum = 0.0
        batch_count = 0
        for batch in tqdm(data_loader, total=len(data_loader)):
            # Drop auxiliary fields, then move the tensors to the model's device.
            batch = prepare_inputs(batch)
            batch = {k: v.to(self.cfg.device) for k, v in batch.items()}
            loss = self.model(**batch).loss
            loss.backward()
            loss_sum += loss.item()
            batch_count += 1
            # Under DP, BatchMemoryManager makes the DP optimizer skip this step
            # until a full logical batch has been accumulated (per Opacus docs).
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self.cfg.scheduler and self.cfg.scheduler_type == "linear":
                self.scheduler.step()
        return loss_sum, batch_count

    def train_step(self, train_loader):
        """Train for one epoch and return the mean loss per processed batch.

        Under DP the loader is wrapped in ``BatchMemoryManager`` so physical
        batches never exceed ``cfg.batch_size``. The mean is taken over the
        batches actually processed, not over ``len(train_loader)``, because
        under DP those counts differ.
        """
        self.model.train()
        self.optimizer.zero_grad()

        if self.dp:
            with BatchMemoryManager(
                data_loader=train_loader,
                max_physical_batch_size=self.cfg.batch_size,
                optimizer=self.optimizer,
            ) as memory_safe_loader:
                loss_sum, batch_count = self._run_epoch(memory_safe_loader)
        else:
            # Previously this branch iterated an undefined `new_data_loader`.
            loss_sum, batch_count = self._run_epoch(train_loader)

        # StepLR is an epoch-level scheduler.
        if self.cfg.scheduler and self.cfg.scheduler_type == "steplr":
            self.scheduler.step()

        return loss_sum / batch_count

    def evaluate_step(self, val_loader):
        """Return the validation loss averaged over batches.

        Each batch contributes the mean of its per-sequence losses, as computed
        by ``compute_loss_per_input``.
        """
        val_loss = 0
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(val_loader):
                # Drop auxiliary fields, then move the tensors to the model's device.
                batch = prepare_inputs(batch)
                batch = {k: v.to(self.cfg.device) for k, v in batch.items()}

                outputs = self.model(**batch)
                loss = compute_loss_per_input(outputs, batch)

                val_loss += loss.mean().item()

        return val_loss / len(val_loader)

    def train_and_evaluate(self, epochs, train_loader, val_loader):
        """Train for ``epochs`` epochs, saving the model whenever validation loss improves.

        Per-epoch losses are logged via ``logging`` as they are produced. They
        are sent to wandb after the last epoch when ``cfg.use_wandb`` is set.
        Under DP, the accountant's epsilon at ``cfg.delta`` is logged at the end.
        """
        best_validation_loss = None
        best_epoch = 0

        wandb_log = []

        for epoch in range(epochs):
            train_loss = self.train_step(train_loader)
            logging.info(f"Epoch {epoch+1} Training loss: {train_loss}")
            val_loss = self.evaluate_step(val_loader=val_loader)
            logging.info(f"Epoch {epoch+1} Validation loss: {val_loss}")
            if best_validation_loss is None or val_loss < best_validation_loss:
                best_validation_loss = val_loss
                best_epoch = epoch
                save_model(self.model, self.cfg.peft_mode, self.save_path, self.model_name)
                logging.info(f"Model improved and saved for epoch {epoch+1}")

            wandb_log.append({"train_loss": train_loss, "validation_loss": val_loss})

        logging.info("Best results:")
        if self.cfg.dp:
            epsilon = self.privacy_engine.accountant.get_epsilon(delta=self.cfg.delta)
            logging.info(f"Epsilon spent: {epsilon}")
        logging.info(f"Best validation loss: {best_validation_loss} for Epoch: {best_epoch+1}")

        if self.cfg.use_wandb:
            for epoch_data in wandb_log:
                wandb.log(epoch_data)
-
-
def prepare_inputs(batch):
    """Remove the keys that must not be forwarded as model keyword arguments.

    ``src_attn``, ``tgt_attn`` and ``src`` are deleted from ``batch`` in place.
    Keys that are absent are ignored. The same dict object is returned.
    """
    for unused_key in ("src_attn", "tgt_attn", "src"):
        batch.pop(unused_key, None)
    return batch
-
def compute_loss_per_input(outputs, batch):
    """Return the mean token cross-entropy of each sequence in the batch.

    Logits at position t are scored against the label at t + 1 (causal-LM
    shift). Label positions equal to -100 are ignored by ``F.cross_entropy``,
    so they contribute zero loss. They are also excluded from the per-sequence
    token count.

    Args:
        outputs: model output exposing ``.logits`` of shape (batch, seq, vocab).
        batch: mapping with ``"labels"`` of shape (batch, seq).

    Returns:
        A 1-D tensor holding one loss value per sequence. A sequence with no
        supervised tokens gets 0 instead of NaN.
    """
    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = batch["labels"][..., 1:].contiguous()
    # Clamp so a fully masked row yields 0 rather than 0/0 = NaN. A NaN would
    # poison every batch-level average built on top of this.
    seq_lens = (shift_labels != -100).sum(dim=1).clamp(min=1)
    # cross_entropy wants class scores on dim 1: (batch, vocab, seq).
    token_loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
    return token_loss.sum(dim=1) / seq_lens
-
def save_evaluation_output(outputs, path):
    """Write every string in ``outputs`` to ``path``, one per line.

    ``outputs`` is an iterable of groups of strings, such as the list of lists
    produced by ``generate_evaluation_output``. Groups are flattened in order,
    with no separator line between them. The file is overwritten and encoded as
    UTF-8, so generated non-ASCII text does not depend on the platform's
    default encoding.
    """
    with open(path, "w", encoding="utf-8") as file:
        for strings in outputs:
            file.writelines(string + "\n" for string in strings)
-
def generate_evaluation_output(model, tokenizer, data, device, max_length, beam_size=5, do_sample=False, num_return_sequences=1):
    """Generate continuations for each new meaning representation in ``data``.

    Entries are read in order. An entry whose ``"meaning_representation"``
    equals the previous entry's is skipped, so a run of consecutive duplicates
    is generated only once. The prompt is the meaning representation followed
    by a space and ``tokenizer.eos_token``.

    The model is put in eval mode for generation, so dropout is inactive, and
    its previous train/eval mode is restored afterwards.

    Args:
        model: a model exposing ``generate`` and the ``nn.Module`` train/eval API.
        tokenizer: tokenizer used for encoding and decoding.
        data: iterable of mappings with a ``"meaning_representation"`` key.
        device: device the encoded prompt is moved to.
        max_length: maximum total length passed to ``generate``.
        beam_size: number of beams.
        do_sample: whether to sample instead of using beam search.
        num_return_sequences: sequences returned per prompt.

    Returns:
        One list per generated prompt. Each list holds ``num_return_sequences``
        decoded, stripped continuations, with the prompt tokens removed.
    """
    # Token ids that must never be generated. 628 and 198 are presumably the
    # GPT-2 ids for "\n\n" and "\n" -- confirm for other tokenizers.
    bad_words_ids = [[628], [198]]
    # A tokenizer without a pad token reports None, which is not a valid token
    # id for bad_words_ids.
    if tokenizer.pad_token_id is not None:
        bad_words_ids.append([tokenizer.pad_token_id])

    generated_texts = []
    prev = None

    was_training = model.training
    model.eval()
    try:
        for entry in tqdm(data):
            meaning = entry["meaning_representation"]
            if meaning == prev:
                continue
            prev = meaning

            prompt = f"{meaning} {tokenizer.eos_token}"
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            # Generated sequences start with the prompt; strip it when decoding.
            prompt_len = len(inputs["input_ids"][0])

            with torch.no_grad():
                outputs = model.generate(**inputs,
                                         num_beams=beam_size,
                                         max_length=max_length,
                                         do_sample=do_sample,
                                         early_stopping=True,
                                         min_length=5,
                                         num_return_sequences=num_return_sequences,
                                         bad_words_ids=bad_words_ids,
                                         pad_token_id=tokenizer.eos_token_id,
                                         repetition_penalty=1,
                                         top_k=0,
                                         top_p=0.9)

            generated_texts.append([
                tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
                for output in outputs
            ])
    finally:
        model.train(was_training)
    return generated_texts
-
|