Meta Byte Track

train.py

from loguru import logger
import torch
import torch.backends.cudnn as cudnn

from yolox.core import Trainer, launch, MetaTrainer
from yolox.exp import get_exp

import argparse
import random
import warnings


def make_parser():
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-t", "--task", type=str, default="metamot")
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-d", "--devices", default=None, type=int, help="device for training"
    )
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for dist training"
    )
    parser.add_argument(
        "--adaptation_period",
        default=4,
        type=int,
        help="if 4, then adapts to one batch in four batches",
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="resume training"
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
    parser.add_argument(
        "-e",
        "--start_epoch",
        default=None,
        type=int,
        help="resume training start epoch",
    )
    parser.add_argument(
        "--num_machines", default=1, type=int, help="num of nodes for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mixed-precision training.",
    )
    parser.add_argument(
        "-o",
        "--occupy",
        dest="occupy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main(exp, args):
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    cudnn.benchmark = True

    # Use the meta-learning trainer for the "metamot" task, otherwise the
    # standard YOLOX trainer.
    if args.task == "metamot":
        trainer = MetaTrainer(exp, args)
    else:
        trainer = Trainer(exp, args)
    print("Trainer Created")
    trainer.train()


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=args.dist_url,
        args=(exp, args),
    )
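For reference, here is a minimal, hypothetical sketch of what the --adaptation_period flag suggests. The names and the modulo gate below are assumptions based only on the help string ("if 4, then adapts to one batch in four batches"); the actual logic lives inside MetaTrainer and is not reproduced here.

# Hypothetical illustration only -- not the MetaTrainer implementation.
adaptation_period = 4
for batch_idx in range(8):
    run_adaptation = (batch_idx % adaptation_period == 0)
    # True on batches 0 and 4: one adaptation step per four training batches.
    print(batch_idx, run_adaptation)

As a usage note, an illustrative launch using only the flags defined in make_parser() (the experiment file and checkpoint paths are placeholders, not paths shipped with this repository) would look like: python train.py -f <exp_file> -d 8 -b 48 --fp16 -o -c <ckpt>.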