123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from loguru import logger
- import torch
- from torch import nn
- from yolox.exp import get_exp
- from yolox.models.network_blocks import SiLU
- from yolox.utils import replace_module
- import argparse
- import os
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the YOLOX ONNX export script.

    Returns:
        An ``argparse.ArgumentParser`` exposing the output file name, the
        ONNX input/output node names, the opset version, the onnxsim
        switch, the experiment file / experiment name / model name, the
        checkpoint path, and a trailing ``opts`` list that captures every
        remaining command-line token.
    """
    parser = argparse.ArgumentParser("YOLOX onnx deploy")

    # Where the exported model is written and how its graph nodes are named.
    parser.add_argument(
        "--output-name",
        type=str,
        default="bytetrack_s.onnx",
        help="output name of models",
    )
    parser.add_argument(
        "--input",
        default="images",
        type=str,
        help="input node name of onnx model",
    )
    parser.add_argument(
        "--output",
        default="output",
        type=str,
        help="output node name of onnx model",
    )
    parser.add_argument(
        "-o", "--opset", default=11, type=int, help="onnx opset version"
    )

    # The flag *disables* simplification, so the help text says so explicitly.
    parser.add_argument(
        "--no-onnxsim",
        action="store_true",
        help="skip the onnxsim simplification step",
    )

    # Experiment / model selection.
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")

    # Everything left on the command line is collected verbatim into `opts`.
    parser.add_argument(
        "opts",
        help="Modify config_files options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
@logger.catch
def main():
    """Export the model described by the command-line arguments to ONNX.

    Steps, in order: parse arguments, build the experiment via ``get_exp``
    and merge the trailing ``opts`` into it, create the model, load the
    checkpoint on CPU, export the model to ``args.output_name`` and, unless
    ``--no-onnxsim`` was given, overwrite that file with the output of
    ``onnxsim.simplify``.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model could not
            be validated. Because of the ``@logger.catch`` decorator the
            exception is handed to loguru rather than raised to the caller
            (presumably logged with its traceback -- confirm against the
            loguru configuration in use).
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))

    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was given.
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()

    # Default checkpoint: <output_dir>/<experiment_name>/best_ckpt.pth.tar
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # A checkpoint may wrap the weights under a "model" key; unwrap if so.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)

    # Swap torch's nn.SiLU modules for the yolox SiLU implementation
    # (presumably the export-friendly variant -- confirm in yolox).
    model = replace_module(model, nn.SiLU, SiLU)
    # NOTE(review): appears to leave box decoding out of the exported graph;
    # confirm against the yolox head implementation.
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")

    # A single random image-shaped tensor drives the export trace.
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])

    # Use the public export entry point instead of the private
    # torch.onnx._export; the arguments are the same and the return value
    # was never used.
    torch.onnx.export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # Imported lazily so the dependencies are only needed when
        # simplification is actually requested.
        import onnx
        from onnxsim import simplify

        # use onnxsimplify to reduce redundant model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently save an unvalidated model.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
# Run the export only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|