infer.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import glob
  26. import paddle
  27. from ppdet.core.workspace import load_config, merge_config
  28. from ppdet.engine import Trainer
  29. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
  30. from ppdet.utils.cli import ArgsParser
  31. from ppdet.slim import build_slim_model
  32. from ppdet.utils.logger import setup_logger
# Module-level logger; get_test_images reports the number of images found
# through it.
# NOTE(review): the logger is named 'train' even though this is the inference
# entry point -- presumably copied from the training script. Confirm before
# renaming, since the name may show up in emitted log records.
logger = setup_logger('train')
  34. def parse_args():
  35. parser = ArgsParser()
  36. parser.add_argument(
  37. "--infer_dir",
  38. type=str,
  39. default=None,
  40. help="Directory for images to perform inference on.")
  41. parser.add_argument(
  42. "--infer_img",
  43. type=str,
  44. default=None,
  45. help="Image path, has higher priority over --infer_dir")
  46. parser.add_argument(
  47. "--output_dir",
  48. type=str,
  49. default="output",
  50. help="Directory for storing the output visualization files.")
  51. parser.add_argument(
  52. "--draw_threshold",
  53. type=float,
  54. default=0.5,
  55. help="Threshold to reserve the result for visualization.")
  56. parser.add_argument(
  57. "--slim_config",
  58. default=None,
  59. type=str,
  60. help="Configuration file of slim method.")
  61. parser.add_argument(
  62. "--use_vdl",
  63. type=bool,
  64. default=False,
  65. help="Whether to record the data to VisualDL.")
  66. parser.add_argument(
  67. '--vdl_log_dir',
  68. type=str,
  69. default="vdl_log_dir/image",
  70. help='VisualDL logging directory for image.')
  71. parser.add_argument(
  72. "--save_results",
  73. type=bool,
  74. default=False,
  75. help="Whether to save inference results to output_dir.")
  76. args = parser.parse_args()
  77. return args
  78. def get_test_images(infer_dir, infer_img):
  79. """
  80. Get image path list in TEST mode
  81. """
  82. assert infer_img is not None or infer_dir is not None, \
  83. "--infer_img or --infer_dir should be set"
  84. assert infer_img is None or os.path.isfile(infer_img), \
  85. "{} is not a file".format(infer_img)
  86. assert infer_dir is None or os.path.isdir(infer_dir), \
  87. "{} is not a directory".format(infer_dir)
  88. # infer_img has a higher priority
  89. if infer_img and os.path.isfile(infer_img):
  90. return [infer_img]
  91. images = set()
  92. infer_dir = os.path.abspath(infer_dir)
  93. assert os.path.isdir(infer_dir), \
  94. "infer_dir {} is not a directory".format(infer_dir)
  95. exts = ['jpg', 'jpeg', 'png', 'bmp']
  96. exts += [ext.upper() for ext in exts]
  97. for ext in exts:
  98. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  99. images = list(images)
  100. assert len(images) > 0, "no image found in {}".format(infer_dir)
  101. logger.info("Found {} inference images in total.".format(len(images)))
  102. return images
  103. def run(FLAGS, cfg):
  104. # build trainer
  105. trainer = Trainer(cfg, mode='test')
  106. # load weights
  107. trainer.load_weights(cfg.weights)
  108. # get inference images
  109. images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
  110. # inference
  111. trainer.predict(
  112. images,
  113. draw_threshold=FLAGS.draw_threshold,
  114. output_dir=FLAGS.output_dir,
  115. save_results=FLAGS.save_results)
  116. def main():
  117. FLAGS = parse_args()
  118. cfg = load_config(FLAGS.config)
  119. cfg['use_vdl'] = FLAGS.use_vdl
  120. cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
  121. merge_config(FLAGS.opt)
  122. # disable npu in config by default
  123. if 'use_npu' not in cfg:
  124. cfg.use_npu = False
  125. # disable xpu in config by default
  126. if 'use_xpu' not in cfg:
  127. cfg.use_xpu = False
  128. if cfg.use_gpu:
  129. place = paddle.set_device('gpu')
  130. elif cfg.use_npu:
  131. place = paddle.set_device('npu')
  132. elif cfg.use_xpu:
  133. place = paddle.set_device('xpu')
  134. else:
  135. place = paddle.set_device('cpu')
  136. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
  137. cfg['norm_type'] = 'bn'
  138. if FLAGS.slim_config:
  139. cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')
  140. check_config(cfg)
  141. check_gpu(cfg.use_gpu)
  142. check_npu(cfg.use_npu)
  143. check_xpu(cfg.use_xpu)
  144. check_version()
  145. run(FLAGS, cfg)
  146. if __name__ == '__main__':
  147. main()