post_quant.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# add python path of PaddleDetection to sys.path
# os.path.join(__file__, '..', '..') resolves to the parent of the directory
# that contains this script. Inserting it at index 0 makes it the first entry
# searched, presumably so the repository's own `ppdet` package is picked up
# ahead of any installed copy.
# NOTE: this must stay above the `ppdet` imports below.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

# ignore warning log
# Suppresses all Python `warnings` for the whole process. It is placed before
# the paddle/ppdet imports so warnings emitted at import time are silenced too.
import warnings
warnings.filterwarnings('ignore')

# NOTE(review): `paddle` is not referenced anywhere else in this file; it may be
# imported for its side effects - confirm before removing.
import paddle

from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.slim import build_slim_model
from ppdet.utils.logger import setup_logger

# Module-level logger for this tool, created under the name 'post_quant'.
logger = setup_logger('post_quant')
  33. def parse_args():
  34. parser = ArgsParser()
  35. parser.add_argument(
  36. "--output_dir",
  37. type=str,
  38. default="output_inference",
  39. help="Directory for storing the output model files.")
  40. parser.add_argument(
  41. "--slim_config",
  42. default=None,
  43. type=str,
  44. help="Configuration file of slim method.")
  45. args = parser.parse_args()
  46. return args
  47. def run(FLAGS, cfg):
  48. # build detector
  49. trainer = Trainer(cfg, mode='eval')
  50. # load weights
  51. if cfg.architecture in ['DeepSORT']:
  52. if cfg.det_weights != 'None':
  53. trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
  54. else:
  55. trainer.load_weights_sde(None, cfg.reid_weights)
  56. else:
  57. trainer.load_weights(cfg.weights)
  58. # post quant model
  59. trainer.post_quant(FLAGS.output_dir)
  60. def main():
  61. FLAGS = parse_args()
  62. cfg = load_config(FLAGS.config)
  63. # TODO: to be refined in the future
  64. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn':
  65. FLAGS.opt['norm_type'] = 'bn'
  66. merge_config(FLAGS.opt)
  67. if FLAGS.slim_config:
  68. cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')
  69. # FIXME: Temporarily solve the priority problem of FLAGS.opt
  70. merge_config(FLAGS.opt)
  71. check_config(cfg)
  72. check_gpu(cfg.use_gpu)
  73. check_version()
  74. run(FLAGS, cfg)
  75. if __name__ == '__main__':
  76. main()