export_utils.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
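
# Utilities for exporting a trained PaddleDetection (ppdet) model:
# dump_infer_config writes the deployment configuration (infer_cfg.yml) and
# save_infer_model saves the inference program and its parameters.
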
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import yaml
from collections import OrderedDict

import logging
logger = logging.getLogger(__name__)

import paddle.fluid as fluid

__all__ = ['dump_infer_config', 'save_infer_model']

# Minimum TensorRT subgraph size recorded in the exported inference config
# for each supported architecture.
TRT_MIN_SUBGRAPH = {
    'YOLO': 3,
    'SSD': 3,
    'RCNN': 40,
    'RetinaNet': 40,
    'S2ANet': 40,
    'EfficientDet': 40,
    'Face': 3,
    'TTFNet': 3,
    'FCOS': 33,
    'SOLOv2': 60,
}

# Architectures whose Resize step uses the shorter configured image side as
# target_size and caps the longer side with max_size.
RESIZE_SCALE_SET = {
    'RCNN',
    'RetinaNet',
    'S2ANet',
    'FCOS',
    'SOLOv2',
}

def parse_reader(reader_cfg, metric, arch):
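    """
    Translate the TestReader configuration into the preprocessing pipeline
    used at inference time.

    Returns a tuple (with_background, preprocess_list, label_list):
    `with_background` mirrors the dataset flag, `preprocess_list` holds the
    preprocessing ops (Resize, Normalize, ..., plus PadStride when the reader
    uses PadBatch), and `label_list` holds the category names obtained from
    `get_category_info`.
    """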
    preprocess_list = []

    image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
    has_shape_def = None not in image_shape

    dataset = reader_cfg['dataset']
    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label

    if metric == 'COCO':
        from ppdet.utils.coco_eval import get_category_info
    elif metric == 'VOC':
        from ppdet.utils.voc_eval import get_category_info
    elif metric == 'WIDERFACE':
        from ppdet.utils.widerface_eval_utils import get_category_info
    elif metric == 'OID':
        from ppdet.utils.oid_eval import get_category_info
    else:
        raise ValueError(
            "metric only supports COCO, VOC, WIDERFACE and OID, "
            "but received {}".format(metric))
    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)
    label_list = [str(cat) for cat in catid2name.values()]

    sample_transforms = reader_cfg['sample_transforms']
    for st in sample_transforms[1:]:
        method = st.__class__.__name__
        p = {'type': method.replace('Image', '')}
        params = st.__dict__
        params.pop('_id')
        if p['type'] == 'Resize' and has_shape_def:
            params['target_size'] = min(image_shape[
                1:]) if arch in RESIZE_SCALE_SET else image_shape[1]
            params['max_size'] = max(image_shape[
                1:]) if arch in RESIZE_SCALE_SET else 0
            params['image_shape'] = image_shape[1:]
            if 'target_dim' in params:
                params.pop('target_dim')
        if p['type'] == 'ResizeAndPad':
            assert has_shape_def, "missing input shape"
            p['type'] = 'Resize'
            p['target_size'] = params['target_dim']
            p['max_size'] = params['target_dim']
            p['interp'] = params['interp']
            p['image_shape'] = image_shape[1:]
            preprocess_list.append(p)
            continue
        p.update(params)
        preprocess_list.append(p)

    batch_transforms = reader_cfg.get('batch_transforms', None)
    if batch_transforms:
        for bt in batch_transforms:
            method = bt.__class__.__name__
            if method == 'PadBatch':
                preprocess_list.append({'type': 'PadStride'})
                params = bt.__dict__
                preprocess_list[-1].update({'stride': params['pad_to_stride']})
                break

    return with_background, preprocess_list, label_list


def dump_infer_config(FLAGS, config):
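    """
    Dump the inference deployment configuration to
    `<output_dir>/<config_name>/infer_cfg.yml`.

    The YAML records, among other fields, the architecture name and its
    `min_subgraph_size` looked up in TRT_MIN_SUBGRAPH, the draw threshold,
    the metric, and the preprocessing pipeline and label list produced by
    `parse_reader`.
    """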
    arch_state = False
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    from ppdet.core.config.yaml_helpers import setup_orderdict
    setup_orderdict()
    infer_cfg = OrderedDict({
        'use_python_inference': False,
        'mode': 'fluid',
        'draw_threshold': 0.5,
        'metric': config['metric']
    })
    infer_arch = config['architecture']

    for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
        if arch in infer_arch:
            infer_cfg['arch'] = arch
            infer_cfg['min_subgraph_size'] = min_subgraph_size
            arch_state = True
            break
    if not arch_state:
        logger.error(
            'Architecture: {} is not supported for exporting model now'.format(
                infer_arch))
        os._exit(0)

    # support landmark output
    if 'with_lmk' in config and config['with_lmk'] == True:
        infer_cfg['with_lmk'] = True
    if 'Mask' in config['architecture']:
        infer_cfg['mask_resolution'] = config['MaskHead']['resolution']
    infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
        'label_list'] = parse_reader(config['TestReader'], config['metric'],
                                     infer_cfg['arch'])

    with open(os.path.join(save_dir, 'infer_cfg.yml'), 'w') as f:
        yaml.dump(infer_cfg, f)
    logger.info("Export inference config file to {}".format(
        os.path.join(save_dir, 'infer_cfg.yml')))


def prune_feed_vars(feeded_var_names, target_vars, prog):
    """
    Filter out feed variables that are not used in the program.
    Pruned feed variables are only needed for post-processing on the model
    output (e.g. im_id to identify image order, im_shape to clip bboxes to
    the image), not by the program itself.
    """
    exist_var_names = []
    prog = prog.clone()
    prog = prog._prune(targets=target_vars)
    global_block = prog.global_block()
    for name in feeded_var_names:
        try:
            v = global_block.var(name)
            exist_var_names.append(str(v.name))
        except Exception:
            logger.info('save_inference_model pruned unused feed '
                        'variables {}'.format(name))
    return exist_var_names


def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
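    """
    Save the inference program and its parameters (as a single `__params__`
    file) with `fluid.io.save_inference_model`, after pruning feed variables
    the program does not actually consume.
    """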
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(FLAGS.output_dir, cfg_name)
    feed_var_names = [var.name for var in feed_vars.values()]
    fetch_list = sorted(test_fetches.items(), key=lambda i: i[0])
    target_vars = [var[1] for var in fetch_list]
    feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog)
    logger.info("Export inference model to {}, input: {}, output: "
                "{}...".format(save_dir, feed_var_names,
                               [str(var.name) for var in target_vars]))
    fluid.io.save_inference_model(
        save_dir,
        feeded_var_names=feed_var_names,
        target_vars=target_vars,
        executor=exe,
        main_program=infer_prog,
        params_filename="__params__")
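
# Illustrative call order from an export script (a sketch only, not part of
# this module's API): `FLAGS` is assumed to carry `config` and `output_dir`,
# `exe` is a fluid.Executor, `feed_vars` an OrderedDict of input Variables,
# `test_fetches` a dict of output Variables, and `infer_prog` the inference
# Program built beforehand by the export tooling.
#
#     dump_infer_config(FLAGS, config)
#     save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)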