You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 721B

123456789101112131415161718192021222324252627
  1. from tqdm import tqdm
  2. import torch
  3. import os
  4. import sys
  5. sys.path.insert(1, os.path.join(sys.path[0], '..'))
  6. from _config import load_config
  7. from _utils import print_system_info, sp_encode
  8. from train_single import run_experminent
  9. if __name__ == '__main__':
  10. print_system_info()
  11. configs = load_config(sys.argv[1])
  12. run_configs = tqdm(configs.run_configs, position=0, desc="Experiment")
  13. for run_config in run_configs:
  14. tasks = tqdm(run_config.tasks, position=1, desc="Task:", leave=False)
  15. for task_name in tasks:
  16. tasks.set_description(f'Task: {task_name}')
  17. torch.cuda.empty_cache()
  18. run_experminent(run_config, task_name)