|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from loguru import logger
-
- import tensorrt as trt
- import torch
- from torch2trt import torch2trt
-
- from yolox.exp import get_exp
-
- import argparse
- import os
- import shutil
-
-
def make_parser():
    """Build the command-line parser for the YOLOX TensorRT conversion script.

    All options are optional and default to ``None``:

    * ``-expn`` / ``--experiment-name``: experiment name (``experiment_name``).
    * ``-n`` / ``--name``: model name.
    * ``-f`` / ``--exp_file``: experiment description file.
    * ``-c`` / ``--ckpt``: checkpoint path.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # The first positional argument of ArgumentParser is ``prog``, the name
    # shown in usage/help output. This script builds a TensorRT engine, so the
    # previous "ncnn" label was misleading.
    parser = argparse.ArgumentParser("YOLOX TensorRT deploy")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your expriment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser
-
-
@logger.catch
def main():
    """Convert a YOLOX checkpoint to a torch2trt module and a TensorRT engine.

    Steps:
      1. Parse the command line and load the experiment via ``get_exp``.
      2. Load the checkpoint's ``"model"`` state dict into the experiment's
         model (defaults to ``best_ckpt.pth.tar`` in the experiment's output
         directory when ``--ckpt`` is not given).
      3. Run ``torch2trt`` in FP16 mode on a dummy input of the experiment's
         test size.
      4. Save the converted module's state dict as ``model_trt.pth`` and the
         serialized engine as ``model_trt.engine`` in the output directory,
         then copy the engine to ``deploy/TensorRT/cpp`` (relative to the
         current working directory).

    Exceptions are intercepted by the ``logger.catch`` decorator.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # NOTE: torch.load unpickles the file, so only pass checkpoints you trust.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    # NOTE(review): presumably keeps the head from decoding its outputs inside
    # the traced graph so decoding happens outside the engine -- confirm
    # against the YOLOX head implementation.
    model.head.decode_in_inference = False
    # Dummy input used to trace the network: batch of 1, 3 channels, test size.
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # shutil.copyfile does not create missing parent directories; without this
    # the copy fails when the script is run from a directory that has no
    # deploy/TensorRT/cpp folder, after the expensive conversion has finished.
    os.makedirs(os.path.dirname(engine_file_demo), exist_ok=True)
    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
-
-
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
|