# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# add python path of PaddleDetection to sys.path so the `ppdet` package
# resolves when this script is run from inside the repository
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)
import glob
import numpy as np
from PIL import Image
from paddle import fluid
from paddleslim.prune import Pruner
from paddleslim.analysis import flops
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.data.reader import create_reader
import logging
# module-level logger configured once at import time
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
  40. def get_save_image_name(output_dir, image_path):
  41. """
  42. Get save image name from source image path.
  43. """
  44. if not os.path.exists(output_dir):
  45. os.makedirs(output_dir)
  46. image_name = os.path.split(image_path)[-1]
  47. name, ext = os.path.splitext(image_name)
  48. return os.path.join(output_dir, "{}".format(name)) + ext
  49. def get_test_images(infer_dir, infer_img):
  50. """
  51. Get image path list in TEST mode
  52. """
  53. assert infer_img is not None or infer_dir is not None, \
  54. "--infer_img or --infer_dir should be set"
  55. assert infer_img is None or os.path.isfile(infer_img), \
  56. "{} is not a file".format(infer_img)
  57. assert infer_dir is None or os.path.isdir(infer_dir), \
  58. "{} is not a directory".format(infer_dir)
  59. # infer_img has a higher priority
  60. if infer_img and os.path.isfile(infer_img):
  61. return [infer_img]
  62. images = set()
  63. infer_dir = os.path.abspath(infer_dir)
  64. assert os.path.isdir(infer_dir), \
  65. "infer_dir {} is not a directory".format(infer_dir)
  66. exts = ['jpg', 'jpeg', 'png', 'bmp']
  67. exts += [ext.upper() for ext in exts]
  68. for ext in exts:
  69. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  70. images = list(images)
  71. assert len(images) > 0, "no image found in {}".format(infer_dir)
  72. logger.info("Found {} inference images in total.".format(len(images)))
  73. return images
  74. def main():
  75. cfg = load_config(FLAGS.config)
  76. merge_config(FLAGS.opt)
  77. check_config(cfg)
  78. # check if set use_gpu=True in paddlepaddle cpu version
  79. check_gpu(cfg.use_gpu)
  80. # check if paddlepaddle version is satisfied
  81. check_version()
  82. main_arch = cfg.architecture
  83. dataset = cfg.TestReader['dataset']
  84. test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
  85. dataset.set_images(test_images)
  86. place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
  87. exe = fluid.Executor(place)
  88. model = create(main_arch)
  89. startup_prog = fluid.Program()
  90. infer_prog = fluid.Program()
  91. with fluid.program_guard(infer_prog, startup_prog):
  92. with fluid.unique_name.guard():
  93. inputs_def = cfg['TestReader']['inputs_def']
  94. inputs_def['iterable'] = True
  95. feed_vars, loader = model.build_inputs(**inputs_def)
  96. test_fetches = model.test(feed_vars)
  97. infer_prog = infer_prog.clone(True)
  98. pruned_params = FLAGS.pruned_params
  99. assert (
  100. FLAGS.pruned_params is not None
  101. ), "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
  102. pruned_params = FLAGS.pruned_params.strip().split(",")
  103. logger.info("pruned params: {}".format(pruned_params))
  104. pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(",")]
  105. logger.info("pruned ratios: {}".format(pruned_ratios))
  106. assert (len(pruned_params) == len(pruned_ratios)
  107. ), "The length of pruned params and pruned ratios should be equal."
  108. assert (pruned_ratios > [0] * len(pruned_ratios) and
  109. pruned_ratios < [1] * len(pruned_ratios)
  110. ), "The elements of pruned ratios should be in range (0, 1)."
  111. base_flops = flops(infer_prog)
  112. pruner = Pruner()
  113. infer_prog, _, _ = pruner.prune(
  114. infer_prog,
  115. fluid.global_scope(),
  116. params=pruned_params,
  117. ratios=pruned_ratios,
  118. place=place,
  119. only_graph=True)
  120. pruned_flops = flops(infer_prog)
  121. logger.info("pruned FLOPS: {}".format(
  122. float(base_flops - pruned_flops) / base_flops))
  123. reader = create_reader(cfg.TestReader, devices_num=1)
  124. loader.set_sample_list_generator(reader, place)
  125. exe.run(startup_prog)
  126. if cfg.weights:
  127. checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
  128. # parse infer fetches
  129. assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \
  130. "unknown metric type {}".format(cfg.metric)
  131. extra_keys = []
  132. if cfg['metric'] in ['COCO', 'OID']:
  133. extra_keys = ['im_info', 'im_id', 'im_shape']
  134. if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE':
  135. extra_keys = ['im_id', 'im_shape']
  136. keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
  137. # parse dataset category
  138. if cfg.metric == 'COCO':
  139. from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
  140. if cfg.metric == 'OID':
  141. from ppdet.utils.oid_eval import bbox2out, get_category_info
  142. if cfg.metric == "VOC":
  143. from ppdet.utils.voc_eval import bbox2out, get_category_info
  144. if cfg.metric == "WIDERFACE":
  145. from ppdet.utils.widerface_eval_utils import bbox2out, get_category_info
  146. anno_file = dataset.get_anno()
  147. with_background = dataset.with_background
  148. use_default_label = dataset.use_default_label
  149. clsid2catid, catid2name = get_category_info(anno_file, with_background,
  150. use_default_label)
  151. # whether output bbox is normalized in model output layer
  152. is_bbox_normalized = False
  153. if hasattr(model, 'is_bbox_normalized') and \
  154. callable(model.is_bbox_normalized):
  155. is_bbox_normalized = model.is_bbox_normalized()
  156. imid2path = dataset.get_imid2path()
  157. for iter_id, data in enumerate(loader()):
  158. outs = exe.run(infer_prog,
  159. feed=data,
  160. fetch_list=values,
  161. return_numpy=False)
  162. res = {
  163. k: (np.array(v), v.recursive_sequence_lengths())
  164. for k, v in zip(keys, outs)
  165. }
  166. logger.info('Infer iter {}'.format(iter_id))
  167. bbox_results = None
  168. mask_results = None
  169. if 'bbox' in res:
  170. bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
  171. if 'mask' in res:
  172. mask_results = mask2out([res], clsid2catid,
  173. model.mask_head.resolution)
  174. # visualize result
  175. im_ids = res['im_id'][0]
  176. for im_id in im_ids:
  177. image_path = imid2path[int(im_id)]
  178. image = Image.open(image_path).convert('RGB')
  179. image = visualize_results(image,
  180. int(im_id), catid2name,
  181. FLAGS.draw_threshold, bbox_results,
  182. mask_results)
  183. save_name = get_save_image_name(FLAGS.output_dir, image_path)
  184. logger.info("Detection bbox results save in {}".format(save_name))
  185. image.save(save_name, quality=95)
if __name__ == '__main__':
    enable_static_mode()
    # extend PaddleDetection's ArgsParser (provides --config / -o opt)
    # with inference-specific options
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "-p",
        "--pruned_params",
        default=None,
        type=str,
        help="The parameters to be pruned when calculating sensitivities.")
    parser.add_argument(
        "--pruned_ratios",
        default=None,
        type=str,
        help="The ratios pruned iteratively for each parameter when calculating sensitivities."
    )
    # FLAGS is read as a module-level global by main()
    FLAGS = parser.parse_args()
    main()