
config.py
import argparse

import torch

from media import media_path


class Config:
    """Parses command-line arguments and post-processes them into a ready-to-use namespace."""

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.add_arguments()
        self.args = self.parse()
        self.post_process()

    def parse(self):
        return self.parser.parse_args()

    def add_arguments(self):
        # Hardware and reproducibility
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument('--gpu_count', type=int, default=1, help='Number of GPUs available')
        self.parser.add_argument('--seed', type=int, default=1234, help='Seed for reproducibility')
        # Optimization
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--virtual_batch_size', type=int, default=16, help='Effective batch size for updating model parameters')
        self.parser.add_argument('--epochs', type=int, default=5, help='Number of epochs for training')
        self.parser.add_argument('--lr', type=float, default=2e-3, help='Learning rate')
        self.parser.add_argument('--weight_decay', type=float, default=0.1, help='Weight decay for the optimizer')
        self.parser.add_argument('--optimizer_eps', type=float, default=1e-8, help='Optimizer epsilon')
        self.parser.add_argument('--scheduler', type=int, default=1, help='Uses a scheduler if 1')
        self.parser.add_argument('--scheduler_warmup_ratio', type=float, default=0.06, help='Warmup steps = warmup ratio * total steps')
        # Tokenization
        self.parser.add_argument('--max_length', type=int, default=128, help='Max length for tokenization')
        # Parameter-efficient fine-tuning
        self.parser.add_argument('--peft_mode', type=str, default='lora', choices=['lora', 'bitfit', 'full', 'lorabitfit'], help='PEFT mode for fine-tuning')
        self.parser.add_argument('--rank', type=int, default=8, help='Rank for LoRA')
        self.parser.add_argument('--alpha', type=int, default=16, help='Alpha for LoRA')
        # Data
        self.parser.add_argument('--dataset', type=str, default='sst2', choices=['sst2', 'mnli', 'qqp', 'qnli'], help='Dataset name')
        self.parser.add_argument('--toy_example', type=int, default=0, help='If 1, only the first 1024 examples of the train dataset are used for training')
        # Differential privacy
        self.parser.add_argument('--dp', type=int, default=0, help='Fine-tunes with differential privacy if 1')
        self.parser.add_argument('--epsilon', type=int, default=3, help='Epsilon in the privacy budget')
        self.parser.add_argument('--delta', type=float, default=1e-5, help='Delta in the privacy budget')
        self.parser.add_argument('--clipping_mode', type=str, default='default', choices=['default', 'ghost'], help='Clipping mode for DP fine-tuning')
        self.parser.add_argument('--clipping_threshold', type=float, default=0.1, help='Max grad norm')
        # Logging
        self.parser.add_argument('--use_wandb', type=int, default=0, help='Uses wandb if 1')
        self.parser.add_argument('--wandb_project_name', type=str, default='Project-DP', help='Wandb project name')
        self.parser.add_argument('--run_name', type=str, default=None, help='Run name')
        self.parser.add_argument('--two_step_training', type=int, default=0, help='If 1, first fine-tunes LoRA, then BitFit')

    def post_process(self):
        assert self.args.virtual_batch_size % self.args.batch_size == 0, 'virtual_batch_size must be divisible by batch_size'
        self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else 'cpu')
        self.args.media_path = media_path
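
For reference, a minimal usage sketch. It assumes this file is saved as config.py and that a media module exposing media_path (imported above) is on the path; the script name train.py is hypothetical.

# Hypothetical usage: construct Config once and read the parsed namespace.
# Assumes config.py (above) and a media.py exposing `media_path` are importable.
from config import Config

if __name__ == '__main__':
    cfg = Config()
    args = cfg.args
    print(args.device)     # torch.device, e.g. cuda:0 if available, else cpu
    print(args.peft_mode)  # 'lora' by default
    print(args.lr)         # 0.002

Invoked as, for example: python train.py --dataset qnli --peft_mode bitfit --dp 1 --epsilon 3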
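
The divisibility assert in post_process suggests virtual_batch_size is realized via gradient accumulation: the optimizer steps once per virtual_batch_size // batch_size micro-batches. This repo's actual training loop is not shown here, so the following is only a sketch of that presumed logic; the model, data, and hyperparameters are toy placeholders.

import torch
from torch import nn

# Sketch of the gradient accumulation implied by the divisibility check.
batch_size, virtual_batch_size = 16, 64
accumulation_steps = virtual_batch_size // batch_size  # requires divisibility

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.1, eps=1e-8)
loss_fn = nn.CrossEntropyLoss()

optimizer.zero_grad()
for step in range(accumulation_steps * 2):  # two optimizer updates
    inputs = torch.randn(batch_size, 8)
    labels = torch.randint(0, 2, (batch_size,))
    # Scale each micro-batch loss so the accumulated gradient averages
    # over the full virtual batch rather than summing micro-batches.
    loss = loss_fn(model(inputs), labels) / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()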