__init__.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from . import prune
  15. from . import quant
  16. from . import distill
  17. from . import unstructured_prune
  18. from .prune import *
  19. from .quant import *
  20. from .distill import *
  21. from .unstructured_prune import *
  22. from .ofa import *
  23. import yaml
  24. from ppdet.core.workspace import load_config
  25. from ppdet.utils.checkpoint import load_pretrain_weight
  26. def build_slim_model(cfg, slim_cfg, mode='train'):
  27. with open(slim_cfg) as f:
  28. slim_load_cfg = yaml.load(f, Loader=yaml.Loader)
  29. if mode != 'train' and slim_load_cfg['slim'] == 'Distill':
  30. return cfg
  31. if slim_load_cfg['slim'] == 'Distill':
  32. model = DistillModel(cfg, slim_cfg)
  33. cfg['model'] = model
  34. cfg['slim_type'] = cfg.slim
  35. elif slim_load_cfg['slim'] == 'OFA':
  36. load_config(slim_cfg)
  37. model = create(cfg.architecture)
  38. load_pretrain_weight(model, cfg.weights)
  39. slim = create(cfg.slim)
  40. cfg['slim'] = slim
  41. cfg['model'] = slim(model, model.state_dict())
  42. cfg['slim_type'] = cfg.slim
  43. elif slim_load_cfg['slim'] == 'DistillPrune':
  44. if mode == 'train':
  45. model = DistillModel(cfg, slim_cfg)
  46. pruner = create(cfg.pruner)
  47. pruner(model.student_model)
  48. else:
  49. model = create(cfg.architecture)
  50. weights = cfg.weights
  51. load_config(slim_cfg)
  52. pruner = create(cfg.pruner)
  53. model = pruner(model)
  54. load_pretrain_weight(model, weights)
  55. cfg['model'] = model
  56. cfg['slim_type'] = cfg.slim
  57. elif slim_load_cfg['slim'] == 'PTQ':
  58. model = create(cfg.architecture)
  59. load_config(slim_cfg)
  60. load_pretrain_weight(model, cfg.weights)
  61. slim = create(cfg.slim)
  62. cfg['slim'] = slim
  63. cfg['model'] = slim(model)
  64. cfg['slim_type'] = cfg.slim
  65. elif slim_load_cfg['slim'] == 'UnstructuredPruner':
  66. load_config(slim_cfg)
  67. slim = create(cfg.slim)
  68. cfg['slim_type'] = cfg.slim
  69. cfg['slim'] = slim
  70. cfg['unstructured_prune'] = True
  71. else:
  72. load_config(slim_cfg)
  73. model = create(cfg.architecture)
  74. if mode == 'train':
  75. load_pretrain_weight(model, cfg.pretrain_weights)
  76. slim = create(cfg.slim)
  77. cfg['slim_type'] = cfg.slim
  78. # TODO: fix quant export model in framework.
  79. if mode == 'test' and slim_load_cfg['slim'] == 'QAT':
  80. slim.quant_config['activation_preprocess_type'] = None
  81. cfg['model'] = slim(model)
  82. cfg['slim'] = slim
  83. if mode != 'train':
  84. load_pretrain_weight(cfg['model'], cfg.weights)
  85. return cfg