infer.py 11 KB


  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PaddleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. if parent_path not in sys.path:
  22. sys.path.append(parent_path)
  23. import glob
  24. import numpy as np
  25. import six
  26. from PIL import Image, ImageOps
  27. from paddle import fluid
  28. import logging
  29. FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
  30. logging.basicConfig(level=logging.INFO, format=FORMAT)
  31. logger = logging.getLogger(__name__)
  32. try:
  33. from ppdet.core.workspace import load_config, merge_config, create
  34. from ppdet.utils.eval_utils import parse_fetches
  35. from ppdet.utils.cli import ArgsParser
  36. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config, enable_static_mode
  37. from ppdet.utils.visualizer import visualize_results
  38. import ppdet.utils.checkpoint as checkpoint
  39. from ppdet.data.reader import create_reader
  40. except ImportError as e:
  41. if sys.argv[0].find('static') >= 0:
  42. logger.error("Importing ppdet failed when running static model "
  43. "with error: {}\n"
  44. "please try:\n"
  45. "\t1. run static model under PaddleDetection/static "
  46. "directory\n"
  47. "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
  48. "dynamic version firstly.".format(e))
  49. sys.exit(-1)
  50. else:
  51. raise e
  52. def get_save_image_name(output_dir, image_path):
  53. """
  54. Get save image name from source image path.
  55. """
  56. if not os.path.exists(output_dir):
  57. os.makedirs(output_dir)
  58. image_name = os.path.split(image_path)[-1]
  59. name, ext = os.path.splitext(image_name)
  60. return os.path.join(output_dir, "{}".format(name)) + ext
  61. def get_test_images(infer_dir, infer_img):
  62. """
  63. Get image path list in TEST mode
  64. """
  65. assert infer_img is not None or infer_dir is not None, \
  66. "--infer_img or --infer_dir should be set"
  67. assert infer_img is None or os.path.isfile(infer_img), \
  68. "{} is not a file".format(infer_img)
  69. assert infer_dir is None or os.path.isdir(infer_dir), \
  70. "{} is not a directory".format(infer_dir)
  71. # infer_img has a higher priority
  72. if infer_img and os.path.isfile(infer_img):
  73. return [infer_img]
  74. images = set()
  75. infer_dir = os.path.abspath(infer_dir)
  76. assert os.path.isdir(infer_dir), \
  77. "infer_dir {} is not a directory".format(infer_dir)
  78. exts = ['jpg', 'jpeg', 'png', 'bmp']
  79. exts += [ext.upper() for ext in exts]
  80. for ext in exts:
  81. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  82. images = list(images)
  83. assert len(images) > 0, "no image found in {}".format(infer_dir)
  84. logger.info("Found {} inference images in total.".format(len(images)))
  85. return images
  86. def main():
  87. env = os.environ
  88. cfg = load_config(FLAGS.config)
  89. merge_config(FLAGS.opt)
  90. check_config(cfg)
  91. # check if set use_gpu=True in paddlepaddle cpu version
  92. check_gpu(cfg.use_gpu)
  93. # disable npu in config by default and check use_npu
  94. if 'use_npu' not in cfg:
  95. cfg.use_npu = False
  96. check_npu(cfg.use_npu)
  97. # disable xpu in config by default and check use_xpu
  98. if 'use_xpu' not in cfg:
  99. cfg.use_xpu = False
  100. check_xpu(cfg.use_xpu)
  101. # check if paddlepaddle version is satisfied
  102. check_version()
  103. main_arch = cfg.architecture
  104. dataset = cfg.TestReader['dataset']
  105. test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
  106. dataset.set_images(test_images)
  107. if cfg.use_gpu and 'FLAGS_selected_gpus' in env:
  108. device_id = int(env['FLAGS_selected_gpus'])
  109. elif cfg.use_npu and 'FLAGS_selected_npus' in env:
  110. device_id = int(env['FLAGS_selected_npus'])
  111. elif cfg.use_xpu and 'FLAGS_selected_xpus' in env:
  112. device_id = int(env['FLAGS_selected_xpus'])
  113. else:
  114. device_id = 0
  115. # define executor
  116. if cfg.use_gpu:
  117. place = fluid.CUDAPlace(device_id)
  118. elif cfg.use_npu:
  119. place = fluid.NPUPlace(device_id)
  120. elif cfg.use_xpu:
  121. place = fluid.XPUPlace(device_id)
  122. else:
  123. place = fluid.CPUPlace()
  124. exe = fluid.Executor(place)
  125. model = create(main_arch)
  126. startup_prog = fluid.Program()
  127. infer_prog = fluid.Program()
  128. with fluid.program_guard(infer_prog, startup_prog):
  129. with fluid.unique_name.guard():
  130. inputs_def = cfg['TestReader']['inputs_def']
  131. inputs_def['iterable'] = True
  132. feed_vars, loader = model.build_inputs(**inputs_def)
  133. test_fetches = model.test(feed_vars)
  134. infer_prog = infer_prog.clone(True)
  135. reader = create_reader(cfg.TestReader, devices_num=1)
  136. loader.set_sample_list_generator(reader, place)
  137. exe.run(startup_prog)
  138. if cfg.weights:
  139. checkpoint.load_params(exe, infer_prog, cfg.weights)
  140. # parse infer fetches
  141. assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \
  142. "unknown metric type {}".format(cfg.metric)
  143. extra_keys = []
  144. if cfg['metric'] in ['COCO', 'OID']:
  145. extra_keys = ['im_info', 'im_id', 'im_shape']
  146. if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE':
  147. extra_keys = ['im_id', 'im_shape']
  148. keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
  149. # parse dataset category
  150. if cfg.metric == 'COCO':
  151. from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
  152. if cfg.metric == 'OID':
  153. from ppdet.utils.oid_eval import bbox2out, get_category_info
  154. if cfg.metric == "VOC":
  155. from ppdet.utils.voc_eval import bbox2out, get_category_info
  156. if cfg.metric == "WIDERFACE":
  157. from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info
  158. anno_file = dataset.get_anno()
  159. with_background = dataset.with_background
  160. use_default_label = dataset.use_default_label
  161. clsid2catid, catid2name = get_category_info(anno_file, with_background,
  162. use_default_label)
  163. # whether output bbox is normalized in model output layer
  164. is_bbox_normalized = False
  165. if hasattr(model, 'is_bbox_normalized') and \
  166. callable(model.is_bbox_normalized):
  167. is_bbox_normalized = model.is_bbox_normalized()
  168. # use VisualDL to log image
  169. if FLAGS.use_vdl:
  170. assert six.PY3, "VisualDL requires Python >= 3.5"
  171. from visualdl import LogWriter
  172. vdl_writer = LogWriter(FLAGS.vdl_log_dir)
  173. vdl_image_step = 0
  174. vdl_image_frame = 0 # each frame can display ten pictures at most.
  175. imid2path = dataset.get_imid2path()
  176. for iter_id, data in enumerate(loader()):
  177. outs = exe.run(infer_prog,
  178. feed=data,
  179. fetch_list=values,
  180. return_numpy=False)
  181. res = {
  182. k: (np.array(v), v.recursive_sequence_lengths())
  183. for k, v in zip(keys, outs)
  184. }
  185. logger.info('Infer iter {}'.format(iter_id))
  186. if 'TTFNet' in cfg.architecture:
  187. res['bbox'][1].append([len(res['bbox'][0])])
  188. if 'CornerNet' in cfg.architecture:
  189. from ppdet.utils.post_process import corner_post_process
  190. post_config = getattr(cfg, 'PostProcess', None)
  191. corner_post_process(res, post_config, cfg.num_classes)
  192. bbox_results = None
  193. mask_results = None
  194. segm_results = None
  195. lmk_results = None
  196. if 'bbox' in res:
  197. bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
  198. if 'mask' in res:
  199. mask_results = mask2out([res], clsid2catid,
  200. model.mask_head.resolution)
  201. if 'segm' in res:
  202. segm_results = segm2out([res], clsid2catid)
  203. if 'landmark' in res:
  204. lmk_results = lmk2out([res], is_bbox_normalized)
  205. # visualize result
  206. im_ids = res['im_id'][0]
  207. for im_id in im_ids:
  208. image_path = imid2path[int(im_id)]
  209. image = Image.open(image_path).convert('RGB')
  210. image = ImageOps.exif_transpose(image)
  211. # use VisualDL to log original image
  212. if FLAGS.use_vdl:
  213. original_image_np = np.array(image)
  214. vdl_writer.add_image(
  215. "original/frame_{}".format(vdl_image_frame),
  216. original_image_np, vdl_image_step)
  217. image = visualize_results(image,
  218. int(im_id), catid2name,
  219. FLAGS.draw_threshold, bbox_results,
  220. mask_results, segm_results, lmk_results)
  221. # use VisualDL to log image with bbox
  222. if FLAGS.use_vdl:
  223. infer_image_np = np.array(image)
  224. vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame),
  225. infer_image_np, vdl_image_step)
  226. vdl_image_step += 1
  227. if vdl_image_step % 10 == 0:
  228. vdl_image_step = 0
  229. vdl_image_frame += 1
  230. save_name = get_save_image_name(FLAGS.output_dir, image_path)
  231. logger.info("Detection bbox results save in {}".format(save_name))
  232. image.save(save_name, quality=95)
  233. if __name__ == '__main__':
  234. enable_static_mode()
  235. parser = ArgsParser()
  236. parser.add_argument(
  237. "--infer_dir",
  238. type=str,
  239. default=None,
  240. help="Directory for images to perform inference on.")
  241. parser.add_argument(
  242. "--infer_img",
  243. type=str,
  244. default=None,
  245. help="Image path, has higher priority over --infer_dir")
  246. parser.add_argument(
  247. "--output_dir",
  248. type=str,
  249. default="output",
  250. help="Directory for storing the output visualization files.")
  251. parser.add_argument(
  252. "--draw_threshold",
  253. type=float,
  254. default=0.5,
  255. help="Threshold to reserve the result for visualization.")
  256. parser.add_argument(
  257. "--use_vdl",
  258. type=bool,
  259. default=False,
  260. help="whether to record the data to VisualDL.")
  261. parser.add_argument(
  262. '--vdl_log_dir',
  263. type=str,
  264. default="vdl_log_dir/image",
  265. help='VisualDL logging directory for image.')
  266. FLAGS = parser.parse_args()
  267. main()