
main.py

from config import Config
from src.model import prepare_model
from src.data import prepare_data
from src.train import Trainer
import os
import random
import numpy as np
import torch
import wandb
import logging
import transformers
import warnings

warnings.filterwarnings("ignore", "Using a non-full backward hook when the forward contains multiple autograd Nodes ")
transformers.logging.set_verbosity_error()


def set_seeds(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    transformers.set_seed(seed)


# Copies the weights of model1 into model2 in place
def copy_model_weights(model1, model2):
    model1.eval()
    model2.eval()
    params1 = model1.parameters()
    params2 = model2.parameters()
    with torch.no_grad():
        for param1, param2 in zip(params1, params2):
            param2.data.copy_(param1.data)


# Returns the number of trainable parameters of the model
def get_number_of_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Returns the total number of parameters of the model
def get_number_of_parameters(model):
    return sum(p.numel() for p in model.parameters())


def main(cfg):
    set_seeds(cfg.seed)
    model, tokenizer = prepare_model(cfg)
    num_of_all_params = get_number_of_parameters(model)
    num_of_trainable_params = get_number_of_trainable_parameters(model)
    percentage = round(100 * num_of_trainable_params / num_of_all_params, 2)
    logging.info(f"Model loaded successfully and number of trainable params is: {num_of_trainable_params} out of {num_of_all_params}")
    logging.info(f"Percentage of trainable parameters: {percentage} %")
    train_loader, val_loader_one, val_loader_two = prepare_data(cfg, tokenizer)
    logging.info("Data is ready")
    trainer = Trainer(cfg, model, train_loader)
    trainer.train_and_evaluate(cfg.epochs, train_loader, val_loader_one, val_loader_two)
    if cfg.two_step_training:
        if cfg.dp:
            trainer.privacy_engine.save_checkpoint(path="temp.pth", module=model)
        model_two, _ = prepare_model(cfg)
        copy_model_weights(model, model_two)
        del model
        model = model_two
        # Second step: train only the bias terms of the encoder, freeze everything else
        for name, param in model.roberta.named_parameters():
            if 'bias' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        logging.info("New Model adjusted")
        num_of_all_params = get_number_of_parameters(model)
        num_of_trainable_params = get_number_of_trainable_parameters(model)
        percentage = round(100 * num_of_trainable_params / num_of_all_params, 2)
        logging.info(f"New Model loaded successfully and number of trainable params is: {num_of_trainable_params} out of {num_of_all_params}")
        logging.info(f"Percentage of trainable parameters: {percentage} %")
        trainer_two = Trainer(cfg, model, train_loader, checkpoint="temp.pth")
        trainer_two.train_and_evaluate(cfg.epochs, train_loader, val_loader_one, val_loader_two)
    if cfg.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    cfg = Config().args
    log_path = "logs/"
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    log_file_name = f"{cfg.run_name}.log" if cfg.run_name else "logs.log"
    if cfg.use_wandb:
        wandb.login(key="YOUR_KEY")
        if cfg.run_name:
            wandb.init(config=cfg, project=f"{cfg.wandb_project_name}-{cfg.dataset}", name=cfg.run_name)
        else:
            wandb.init(config=cfg, project=f"{cfg.wandb_project_name}-{cfg.dataset}")
        log_file_name = wandb.run.name
    logging.basicConfig(filename=f"{log_path}{log_file_name}", level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
    logging.info("Start of the logging")
    hyperparameters = {key: value for key, value in vars(cfg).items()}
    hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items()])
    logging.info("config:\n" + hyperparameters_str)
    main(cfg)
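
For reference, here is a minimal standalone sketch of the parameter-counting helpers and the bias-only freezing applied in the two-step branch of main.py, run on a toy torch module. The toy model and the printed numbers are illustrative assumptions, not part of this repository; only torch is required.

import torch
import torch.nn as nn

def get_number_of_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_number_of_parameters(model):
    return sum(p.numel() for p in model.parameters())

# Hypothetical toy model standing in for the RoBERTa encoder
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze everything except the bias terms, mirroring the second training step
for name, param in toy.named_parameters():
    param.requires_grad = 'bias' in name

total = get_number_of_parameters(toy)                  # 8*4 + 4 + 4*2 + 2 = 46
trainable = get_number_of_trainable_parameters(toy)    # 4 + 2 = 6
print(f"{trainable} / {total} trainable ({round(100 * trainable / total, 2)} %)")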