# export_onnx.py (3.0 KB)

  1. from loguru import logger
  2. import torch
  3. from torch import nn
  4. from yolox.exp import get_exp
  5. from yolox.models.network_blocks import SiLU
  6. from yolox.utils import replace_module
  7. import argparse
  8. import os
  9. def make_parser():
  10. parser = argparse.ArgumentParser("YOLOX onnx deploy")
  11. parser.add_argument(
  12. "--output-name", type=str, default="bytetrack_s.onnx", help="output name of models"
  13. )
  14. parser.add_argument(
  15. "--input", default="images", type=str, help="input node name of onnx model"
  16. )
  17. parser.add_argument(
  18. "--output", default="output", type=str, help="output node name of onnx model"
  19. )
  20. parser.add_argument(
  21. "-o", "--opset", default=11, type=int, help="onnx opset version"
  22. )
  23. parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
  24. parser.add_argument(
  25. "-f",
  26. "--exp_file",
  27. default=None,
  28. type=str,
  29. help="expriment description file",
  30. )
  31. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  32. parser.add_argument("-n", "--name", type=str, default=None, help="model name")
  33. parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
  34. parser.add_argument(
  35. "opts",
  36. help="Modify config_files options using the command-line",
  37. default=None,
  38. nargs=argparse.REMAINDER,
  39. )
  40. return parser
  41. @logger.catch
  42. def main():
  43. args = make_parser().parse_args()
  44. logger.info("args value: {}".format(args))
  45. exp = get_exp(args.exp_file, args.name)
  46. exp.merge(args.opts)
  47. if not args.experiment_name:
  48. args.experiment_name = exp.exp_name
  49. model = exp.get_model()
  50. if args.ckpt is None:
  51. file_name = os.path.join(exp.output_dir, args.experiment_name)
  52. ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
  53. else:
  54. ckpt_file = args.ckpt
  55. # load the model state dict
  56. ckpt = torch.load(ckpt_file, map_location="cpu")
  57. model.eval()
  58. if "model" in ckpt:
  59. ckpt = ckpt["model"]
  60. model.load_state_dict(ckpt)
  61. model = replace_module(model, nn.SiLU, SiLU)
  62. model.head.decode_in_inference = False
  63. logger.info("loading checkpoint done.")
  64. dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
  65. torch.onnx._export(
  66. model,
  67. dummy_input,
  68. args.output_name,
  69. input_names=[args.input],
  70. output_names=[args.output],
  71. opset_version=args.opset,
  72. )
  73. logger.info("generated onnx model named {}".format(args.output_name))
  74. if not args.no_onnxsim:
  75. import onnx
  76. from onnxsim import simplify
  77. # use onnxsimplify to reduce reduent model.
  78. onnx_model = onnx.load(args.output_name)
  79. model_simp, check = simplify(onnx_model)
  80. assert check, "Simplified ONNX model could not be validated"
  81. onnx.save(model_simp, args.output_name)
  82. logger.info("generated simplified onnx model named {}".format(args.output_name))
  83. if __name__ == "__main__":
  84. main()