3
0

prune.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle.utils import try_import
  19. from ppdet.core.workspace import register, serializable
  20. from ppdet.utils.logger import setup_logger
# Module-level logger used by print_prune_params and Pruner below.
logger = setup_logger(__name__)
  22. def print_prune_params(model):
  23. model_dict = model.state_dict()
  24. for key in model_dict.keys():
  25. weight_name = model_dict[key].name
  26. logger.info('Parameter name: {}, shape: {}'.format(
  27. weight_name, model_dict[key].shape))
  28. @register
  29. @serializable
  30. class Pruner(object):
  31. def __init__(self,
  32. criterion,
  33. pruned_params,
  34. pruned_ratios,
  35. print_params=False):
  36. super(Pruner, self).__init__()
  37. assert criterion in ['l1_norm', 'fpgm'], \
  38. "unsupported prune criterion: {}".format(criterion)
  39. self.criterion = criterion
  40. self.pruned_params = pruned_params
  41. self.pruned_ratios = pruned_ratios
  42. self.print_params = print_params
  43. def __call__(self, model):
  44. # FIXME: adapt to network graph when Training and inference are
  45. # inconsistent, now only supports prune inference network graph.
  46. model.eval()
  47. paddleslim = try_import('paddleslim')
  48. from paddleslim.analysis import dygraph_flops as flops
  49. input_spec = [{
  50. "image": paddle.ones(
  51. shape=[1, 3, 640, 640], dtype='float32'),
  52. "im_shape": paddle.full(
  53. [1, 2], 640, dtype='float32'),
  54. "scale_factor": paddle.ones(
  55. shape=[1, 2], dtype='float32')
  56. }]
  57. if self.print_params:
  58. print_prune_params(model)
  59. ori_flops = flops(model, input_spec) / (1000**3)
  60. logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops))
  61. if self.criterion == 'fpgm':
  62. pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec)
  63. elif self.criterion == 'l1_norm':
  64. pruner = paddleslim.dygraph.L1NormFilterPruner(model, input_spec)
  65. logger.info("pruned params: {}".format(self.pruned_params))
  66. pruned_ratios = [float(n) for n in self.pruned_ratios]
  67. ratios = {}
  68. for i, param in enumerate(self.pruned_params):
  69. ratios[param] = pruned_ratios[i]
  70. pruner.prune_vars(ratios, [0])
  71. pruned_flops = flops(model, input_spec) / (1000**3)
  72. logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
  73. pruned_flops, (ori_flops - pruned_flops) / ori_flops))
  74. return model