# eval.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import paddle
  26. from ppdet.core.workspace import load_config, merge_config
  27. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
  28. from ppdet.utils.cli import ArgsParser
  29. from ppdet.engine import Trainer, init_parallel_env
  30. from ppdet.metrics.coco_utils import json_eval_results
  31. from ppdet.slim import build_slim_model
  32. from ppdet.utils.logger import setup_logger
  33. logger = setup_logger('eval')
  34. def parse_args():
  35. parser = ArgsParser()
  36. parser.add_argument(
  37. "--output_eval",
  38. default=None,
  39. type=str,
  40. help="Evaluation directory, default is current directory.")
  41. parser.add_argument(
  42. '--json_eval',
  43. action='store_true',
  44. default=False,
  45. help='Whether to re eval with already exists bbox.json or mask.json')
  46. parser.add_argument(
  47. "--slim_config",
  48. default=None,
  49. type=str,
  50. help="Configuration file of slim method.")
  51. # TODO: bias should be unified
  52. parser.add_argument(
  53. "--bias",
  54. action="store_true",
  55. help="whether add bias or not while getting w and h")
  56. parser.add_argument(
  57. "--classwise",
  58. action="store_true",
  59. help="whether per-category AP and draw P-R Curve or not.")
  60. parser.add_argument(
  61. '--save_prediction_only',
  62. action='store_true',
  63. default=False,
  64. help='Whether to save the evaluation results only')
  65. args = parser.parse_args()
  66. return args
  67. def run(FLAGS, cfg):
  68. if FLAGS.json_eval:
  69. logger.info(
  70. "In json_eval mode, PaddleDetection will evaluate json files in "
  71. "output_eval directly. And proposal.json, bbox.json and mask.json "
  72. "will be detected by default.")
  73. json_eval_results(
  74. cfg.metric,
  75. json_directory=FLAGS.output_eval,
  76. dataset=cfg['EvalDataset'])
  77. return
  78. # init parallel environment if nranks > 1
  79. init_parallel_env()
  80. # build trainer
  81. trainer = Trainer(cfg, mode='eval')
  82. # load weights
  83. trainer.load_weights(cfg.weights)
  84. # training
  85. trainer.evaluate()
  86. def main():
  87. FLAGS = parse_args()
  88. cfg = load_config(FLAGS.config)
  89. # TODO: bias should be unified
  90. cfg['bias'] = 1 if FLAGS.bias else 0
  91. cfg['classwise'] = True if FLAGS.classwise else False
  92. cfg['output_eval'] = FLAGS.output_eval
  93. cfg['save_prediction_only'] = FLAGS.save_prediction_only
  94. merge_config(FLAGS.opt)
  95. # disable npu in config by default
  96. if 'use_npu' not in cfg:
  97. cfg.use_npu = False
  98. # disable xpu in config by default
  99. if 'use_xpu' not in cfg:
  100. cfg.use_xpu = False
  101. if cfg.use_gpu:
  102. place = paddle.set_device('gpu')
  103. elif cfg.use_npu:
  104. place = paddle.set_device('npu')
  105. elif cfg.use_xpu:
  106. place = paddle.set_device('xpu')
  107. else:
  108. place = paddle.set_device('cpu')
  109. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
  110. cfg['norm_type'] = 'bn'
  111. if FLAGS.slim_config:
  112. cfg = build_slim_model(cfg, FLAGS.slim_config, mode='eval')
  113. check_config(cfg)
  114. check_gpu(cfg.use_gpu)
  115. check_npu(cfg.use_npu)
  116. check_xpu(cfg.use_xpu)
  117. check_version()
  118. run(FLAGS, cfg)
  119. if __name__ == '__main__':
  120. main()