3
0

quant.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from paddle.utils import try_import
  18. from ppdet.core.workspace import register, serializable
  19. from ppdet.utils.logger import setup_logger
# Module-level logger obtained from ppdet's setup_logger, named after this
# module; used by the quantization wrappers below for optional model dumps.
logger = setup_logger(__name__)
  21. @register
  22. @serializable
  23. class QAT(object):
  24. def __init__(self, quant_config, print_model):
  25. super(QAT, self).__init__()
  26. self.quant_config = quant_config
  27. self.print_model = print_model
  28. def __call__(self, model):
  29. paddleslim = try_import('paddleslim')
  30. self.quanter = paddleslim.dygraph.quant.QAT(config=self.quant_config)
  31. if self.print_model:
  32. logger.info("Model before quant:")
  33. logger.info(model)
  34. self.quanter.quantize(model)
  35. if self.print_model:
  36. logger.info("Quantized model:")
  37. logger.info(model)
  38. return model
  39. def save_quantized_model(self, layer, path, input_spec=None, **config):
  40. self.quanter.save_quantized_model(
  41. model=layer, path=path, input_spec=input_spec, **config)
  42. @register
  43. @serializable
  44. class PTQ(object):
  45. def __init__(self,
  46. ptq_config,
  47. quant_batch_num=10,
  48. output_dir='output_inference',
  49. fuse=True,
  50. fuse_list=None):
  51. super(PTQ, self).__init__()
  52. self.ptq_config = ptq_config
  53. self.quant_batch_num = quant_batch_num
  54. self.output_dir = output_dir
  55. self.fuse = fuse
  56. self.fuse_list = fuse_list
  57. def __call__(self, model):
  58. paddleslim = try_import('paddleslim')
  59. self.ptq = paddleslim.PTQ(**self.ptq_config)
  60. model.eval()
  61. quant_model = self.ptq.quantize(
  62. model, fuse=self.fuse, fuse_list=self.fuse_list)
  63. return quant_model
  64. def save_quantized_model(self,
  65. quant_model,
  66. quantize_model_path,
  67. input_spec=None):
  68. self.ptq.save_quantized_model(quant_model, quantize_model_path,
  69. input_spec)