from loguru import logger

import torch
from torch import nn

from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module

import argparse
import os


def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ONNX export script.

    Returns:
        An ``argparse.ArgumentParser`` exposing the output file name, the
        ONNX input/output node names, the opset version, the onnxsim toggle,
        the experiment selection options, the checkpoint path and a trailing
        ``opts`` list that captures all remaining command-line tokens.
    """
    parser = argparse.ArgumentParser("YOLOX onnx deploy")
    parser.add_argument(
        "--output-name",
        type=str,
        default="bytetrack_s.onnx",
        help="output name of models",
    )
    parser.add_argument(
        "--input", default="images", type=str, help="input node name of onnx model"
    )
    parser.add_argument(
        "--output", default="output", type=str, help="output node name of onnx model"
    )
    parser.add_argument(
        "-o", "--opset", default=11, type=int, help="onnx opset version"
    )
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="expriment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    # REMAINDER collects every token after the recognised options; the list is
    # handed to ``exp.merge`` unchanged in ``main``.
    parser.add_argument(
        "opts",
        help="Modify config_files options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main() -> None:
    """Export a model described by an experiment to ONNX.

    Parses the command line, builds the model from the selected experiment,
    loads the checkpoint weights on CPU, exports the model to
    ``args.output_name`` and, unless ``--no-onnxsim`` is given, overwrites that
    file with the onnxsim-simplified model.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed its
            validation check. Because of the ``@logger.catch`` decorator this
            (like any other exception raised here) is handed to loguru rather
            than propagating to the caller.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))

    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()

    # Without an explicit checkpoint, fall back to the best checkpoint inside
    # the experiment's own output directory.
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # Load the model state dict.
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only export checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # Checkpoints may wrap the weights under a "model" key; unwrap when present.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)

    # Swap torch's nn.SiLU for the yolox SiLU implementation (presumably so the
    # activation exports with the chosen opset -- confirm against yolox docs).
    model = replace_module(model, nn.SiLU, SiLU)
    # NOTE(review): presumably makes the head emit raw, undecoded predictions
    # in the exported graph -- confirm against the yolox head implementation.
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")

    # Batch of one 3-channel random input sized from the experiment's test_size.
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])

    # Use the public export entry point rather than the private
    # ``torch.onnx._export`` helper; the return value was never used.
    torch.onnx.export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # Imported lazily so onnx/onnxsim are only required when simplifying.
        import onnx
        from onnxsim import simplify

        # Use onnx-simplifier to strip redundant nodes from the exported model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would silently save an unvalidated model.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info(
            "generated simplified onnx model named {}".format(args.output_name)
        )


if __name__ == "__main__":
    main()