Meta Byte Track

trt.py (2.1 KB)

import argparse
import os
import shutil

import tensorrt as trt
import torch
from loguru import logger
from torch2trt import torch2trt

from yolox.exp import get_exp


def make_parser():
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser


@logger.catch
def main():
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")

    model.eval()
    model.cuda()
    # skip box decoding inside the network; outputs are decoded after inference
    model.head.decode_in_inference = False

    # dummy input matching the experiment's test resolution
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")

    # serialize the raw engine and copy it into the C++ demo directory
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    shutil.copyfile(engine_file, engine_file_demo)
    logger.info("Converted TensorRT model engine file is saved for C++ inference.")


if __name__ == "__main__":
    main()
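The conversion leaves two artifacts in the experiment's output directory: model_trt.pth, a state dict that torch2trt can restore on the Python side, and model_trt.engine, the serialized engine consumed by the C++ demo. Below is a minimal sketch of reloading the Python-side model via torch2trt's TRTModule; the checkpoint path and the 640x640 input size are assumptions that depend on the experiment used for conversion.

import torch
from torch2trt import TRTModule

# Restore the TensorRT-backed module from the saved state dict.
# The path is an example; it depends on exp.output_dir and the
# experiment name passed to trt.py.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load("YOLOX_outputs/yolox_s/model_trt.pth"))

# The input shape must match exp.test_size used during conversion
# (640x640 is assumed here).
x = torch.ones(1, 3, 640, 640).cuda()
y = model_trt(x)  # raw head outputs; decode_in_inference was disabled
                  # above, so boxes must still be decoded afterwards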