export_model.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
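
"""Export a trained PaddleDetection model for deployment.

Builds the detector described by a config file, loads trained weights, and
saves a static inference model (model.pdmodel / model.pdiparams) under
--output_dir. Optionally converts the result into Paddle Serving
server/client formats.
"""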
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

# ignore warning log
import warnings
warnings.filterwarnings('ignore')

import paddle

from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model

from ppdet.utils.logger import setup_logger
logger = setup_logger('export_model')


def parse_args():
    parser = ArgsParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_inference",
        help="Directory for storing the output model files.")
    parser.add_argument(
        "--export_serving_model",
        type=bool,
        default=False,
        help="Whether to export a Paddle Serving model or not.")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    args = parser.parse_args()
    return args


def run(FLAGS, cfg):
    # build detector
    trainer = Trainer(cfg, mode='test')

    # load weights
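    # SDE-style trackers (separate detection and embedding) load detector
    # and ReID weights individually rather than from a single checkpoint.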
    if cfg.architecture in ['DeepSORT', 'ByteTrack']:
        trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
    else:
        trainer.load_weights(cfg.weights)

    # export model
    trainer.export(FLAGS.output_dir)

    if FLAGS.export_serving_model:
        from paddle_serving_client.io import inference_model_to_serving
        model_name = os.path.splitext(os.path.split(cfg.filename)[-1])[0]
        inference_model_to_serving(
            dirname="{}/{}".format(FLAGS.output_dir, model_name),
            serving_server="{}/{}/serving_server".format(FLAGS.output_dir,
                                                         model_name),
            serving_client="{}/{}/serving_client".format(FLAGS.output_dir,
                                                         model_name),
            model_filename="model.pdmodel",
            params_filename="model.pdiparams")
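        # The conversion reads the inference model exported above and writes
        # serving_server/ and serving_client/ directories next to
        # model.pdmodel and model.pdiparams under {output_dir}/{model_name}.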


def main():
    paddle.set_device("cpu")
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)

    # TODO: to be refined in the future
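    # sync_bn only makes sense across multiple training devices; the exported
    # single-device inference graph uses plain bn instead.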
    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
        FLAGS.opt['norm_type'] = 'bn'
    merge_config(FLAGS.opt)

    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

    # FIXME: Temporarily solve the priority problem of FLAGS.opt
    merge_config(FLAGS.opt)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    run(FLAGS, cfg)


if __name__ == '__main__':
    main()
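
# Example usage (a minimal sketch; the config and weight paths below are
# illustrative, substitute your own):
#   python tools/export_model.py \
#       -c configs/yolov3/yolov3_darknet53_270e_coco.yml \
#       -o weights=output/yolov3_darknet53_270e_coco/model_final \
#       --output_dir=output_inference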