
main.py 5.6KB

from config import Config
import os
import random
import numpy as np
import torch
import wandb
import logging
import transformers
import warnings
import subprocess
from model import load_model, prepare_model, get_number_of_trainable_parameters, load_model_weights, get_number_of_parameters
from data import load_dataset, load_dataloaders
from train import Trainer, generate_evaluation_output, save_evaluation_output
from utils import clean_hyperparameters, copy_model_weights

warnings.filterwarnings("ignore", "Using a non-full backward hook when the forward contains multiple autograd Nodes")
transformers.logging.set_verbosity_error()
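

# Seed every source of randomness (Python's hash seed and random module,
# NumPy, PyTorch CPU/CUDA, and Hugging Face transformers) so runs are
# reproducible.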
def set_seeds(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # needed when using multi-GPU
    transformers.set_seed(seed)
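

# Run the E2E NLG Challenge scorer on a generated output file and return the
# lines of its stdout that carry the metric scores; the slice assumes the
# scorer prints the scores just before a few trailing lines.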
def run_metric_script(file_path):
    result = subprocess.run(["e2e/measure_scores.py", "-p", "e2e_ref.txt", file_path], stdout=subprocess.PIPE)
    output = result.stdout.decode('utf-8')
    lines = output.split('\n')
    return lines[-7:-2]
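

# Full pipeline: load and prepare the model, train it (optionally in two
# phases under differential privacy), then generate and score test-set
# outputs.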
def main(cfg):
    set_seeds(cfg.seed)
    model, tokenizer = load_model(cfg.model_name, cache_dir=cfg.model_cache_path)
    model = prepare_model(model, cfg)
    num_of_all_params = get_number_of_parameters(model)
    num_of_trainable_params = get_number_of_trainable_parameters(model)
    percentage = round(100 * num_of_trainable_params / num_of_all_params, 2)
    logging.info(f"Model loaded successfully and number of trainable params is: {num_of_trainable_params} out of {num_of_all_params}")
    logging.info(f"Percentage of trainable parameters: {percentage} %")
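    # Load the dataset and build the train/validation dataloaders; cfg.dp is
    # passed through, presumably so the loaders can use DP-compatible batch
    # sampling.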
    dataset = load_dataset(cfg.dataset, cfg.media_path, cfg.toy_example)
    cfg.train_data_size = len(dataset["train"])
    # dataset = tokenize_dataset(tokenizer, dataset, cfg.dataset, cfg.seq_length)
    train_loader, validation_loader = load_dataloaders(dataset, cfg.dataset, cfg.batch_size, cfg.virtual_batch_size, tokenizer, cfg.seq_length, cfg.dp)
    logging.info("Dataset loaded and tokenized")
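    # First (and possibly only) training phase.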
    trainer = Trainer(cfg, model, train_loader)
    trainer.train_and_evaluate(cfg.epochs, train_loader, validation_loader)
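    # Two-step training under DP: save a privacy-engine checkpoint, copy the
    # trained weights into a freshly prepared model, and continue training in
    # a second phase.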
    if cfg.two_step_training and cfg.dp:
        trainer.privacy_engine.save_checkpoint(path="temp.pth", module=model)
        model_two, _ = load_model(cfg.model_name, cache_dir=cfg.model_cache_path)
        model_two = prepare_model(model_two, cfg)
        copy_model_weights(model, model_two)
        del model
        model = model_two
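        # The second phase trains only the bias terms outside the adapters;
        # every other parameter is frozen.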
        for name, param in model.named_parameters():
            if 'bias' in name and 'adapter' not in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        logging.info("New model adjusted")
        num_of_all_params = get_number_of_parameters(model)
        num_of_trainable_params = get_number_of_trainable_parameters(model)
        percentage = round(100 * num_of_trainable_params / num_of_all_params, 2)
        logging.info(f"New model loaded successfully and number of trainable params is: {num_of_trainable_params} out of {num_of_all_params}")
        logging.info(f"Percentage of trainable parameters: {percentage} %")
        trainer_two = Trainer(cfg, model, train_loader, second_trainer=True)
        trainer_two.train_and_evaluate(cfg.epochs_two, train_loader, validation_loader)
    # Evaluate the model on the test data using the saved checkpoint.
    model.eval()
    model = load_model_weights(model, cfg.peft_mode, f"{trainer.save_path}/{trainer.model_name}.pth")
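    # Generate two test-set outputs: v1 decodes with beam search alone, while
    # v2 additionally enables sampling (do_sample=True).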
    evaluation_output = generate_evaluation_output(model, tokenizer, dataset["test"], cfg.device, cfg.seq_length, cfg.beam_size)
    output_path = f"{cfg.media_path}generation_eval_outputs/{cfg.dataset}/{cfg.peft_mode}"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_name = cfg.run_name if cfg.run_name else "generation_output"
    save_evaluation_output(evaluation_output, f"{output_path}/{output_name}-v1.txt")
    evaluation_output = generate_evaluation_output(model, tokenizer, dataset["test"], cfg.device, cfg.seq_length, cfg.beam_size, do_sample=True)
    save_evaluation_output(evaluation_output, f"{output_path}/{output_name}-v2.txt")
    logging.info("Generation for test data saved")
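    # Score both output files with the E2E metrics script and log the results.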
    metrics = run_metric_script(f"{output_path}/{output_name}-v1.txt")
    logging.info("Metrics without sampling:")
    for metric in metrics:
        logging.info(metric)
    metrics = run_metric_script(f"{output_path}/{output_name}-v2.txt")
    logging.info("Metrics with sampling:")
    for metric in metrics:
        logging.info(metric)
    if cfg.use_wandb:
        wandb.finish()
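

# Entry point: parse the config, set up file logging and optional Weights &
# Biases tracking, record the hyperparameters, and run the pipeline.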
if __name__ == "__main__":
    cfg = Config().args
    log_path = f"logs/{cfg.dataset}/{cfg.peft_mode}/"
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    log_file_name = f"{cfg.run_name}.log" if cfg.run_name else "logs.log"
    if cfg.use_wandb:
        wandb.login(key="YOUR_KEY")  # placeholder; replace with your wandb API key
        if cfg.run_name:
            wandb.init(config=cfg, project=f"{cfg.wandb_project_name}-{cfg.dataset}", name=cfg.run_name)
        else:
            wandb.init(config=cfg, project=f"{cfg.wandb_project_name}-{cfg.dataset}")
            log_file_name = wandb.run.name
    logging.basicConfig(filename=f"{log_path}{log_file_name}", level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
    logging.info("Start of the logging")
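    # Log the full resolved configuration so each run's hyperparameters are
    # recorded alongside its results.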
    hyperparameters = dict(vars(cfg))  # shallow copy of the config namespace
    hyperparameters = clean_hyperparameters(hyperparameters)
    hyperparameters_str = "\n".join(f"{key}: {value}" for key, value in hyperparameters.items())
    logging.info("config:\n" + hyperparameters_str)
    main(cfg)
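
# Example invocation (a sketch only; the actual flag names depend on how
# Config in config.py defines its command-line arguments):
#   python main.py --dataset e2e --peft_mode adapter --seed 42 --run_name demo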