check.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import sys
  18. import paddle
  19. import paddle.fluid as fluid
  20. import logging
  21. import six
  22. import paddle.version as fluid_version
# Module-level logger, named after this module so that log records can be
# filtered per module by the logging configuration.
logger = logging.getLogger(__name__)

# Names exported by ``from <this module> import *``.
# NOTE(review): ``enable_static_mode`` is defined at the end of this file but
# is not listed here -- confirm whether that omission is intentional.
__all__ = [
    'check_gpu',
    'check_xpu',
    'check_npu',
    'check_version',
    'check_config',
    'check_py_func',
]
  32. def check_xpu(use_xpu):
  33. """
  34. Log error and exit when set use_xpu=true in paddlepaddle
  35. cpu/gpu version.
  36. """
  37. err = "Config use_xpu cannot be set as true while you are " \
  38. "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
  39. "\t1. Install paddlepaddle-xpu to run model on XPU \n" \
  40. "\t2. Set use_xpu as false in config file to run " \
  41. "model on CPU/GPU"
  42. try:
  43. if use_xpu and not fluid.is_compiled_with_xpu():
  44. logger.error(err)
  45. sys.exit(1)
  46. except Exception as e:
  47. pass
  48. def check_npu(use_npu):
  49. """
  50. Log error and exit when set use_npu=true in paddlepaddle
  51. cpu/gpu/xpu version.
  52. """
  53. err = "Config use_npu cannot be set as true while you are " \
  54. "using paddlepaddle cpu/gpu/xpu version ! \nPlease try: \n" \
  55. "\t1. Install paddlepaddle-npu to run model on NPU \n" \
  56. "\t2. Set use_npu as false in config file to run " \
  57. "model on CPU/GPU/XPU"
  58. try:
  59. if use_npu and not fluid.is_compiled_with_npu():
  60. logger.error(err)
  61. sys.exit(1)
  62. except Exception as e:
  63. pass
  64. def check_gpu(use_gpu):
  65. """
  66. Log error and exit when set use_gpu=true in paddlepaddle
  67. cpu version.
  68. """
  69. err = "Config use_gpu cannot be set as true while you are " \
  70. "using paddlepaddle cpu version ! \nPlease try: \n" \
  71. "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
  72. "\t2. Set use_gpu as false in config file to run " \
  73. "model on CPU"
  74. try:
  75. if use_gpu and not fluid.is_compiled_with_cuda():
  76. logger.error(err)
  77. sys.exit(1)
  78. except Exception as e:
  79. pass
  80. def check_version(version='1.7.0'):
  81. """
  82. Log error and exit when the installed version of paddlepaddle is
  83. not satisfied.
  84. """
  85. err = "PaddlePaddle version {} or higher is required, " \
  86. "or a suitable develop version is satisfied as well. \n" \
  87. "Please make sure the version is good with your code.".format(version)
  88. version_installed = [
  89. fluid_version.major, fluid_version.minor, fluid_version.patch,
  90. fluid_version.rc
  91. ]
  92. if version_installed == ['0', '0', '0', '0']:
  93. return
  94. version_split = version.split('.')
  95. length = min(len(version_installed), len(version_split))
  96. for i in six.moves.range(length):
  97. if version_installed[i] > version_split[i]:
  98. return
  99. if len(version_installed[i]) == 1 and len(version_split[i]) > 1:
  100. return
  101. if version_installed[i] < version_split[i]:
  102. raise Exception(err)
  103. def check_config(cfg):
  104. """
  105. Check the correctness of the configuration file. Log error and exit
  106. when Config is not compliant.
  107. """
  108. err = "'{}' not specified in config file. Please set it in config file."
  109. check_list = ['architecture', 'num_classes']
  110. try:
  111. for var in check_list:
  112. if not var in cfg:
  113. logger.error(err.format(var))
  114. sys.exit(1)
  115. except Exception as e:
  116. pass
  117. if 'log_iter' not in cfg:
  118. cfg.log_iter = 20
  119. train_dataset = cfg['TrainReader']['dataset']
  120. eval_dataset = cfg['EvalReader']['dataset']
  121. test_dataset = cfg['TestReader']['dataset']
  122. assert train_dataset.with_background == eval_dataset.with_background, \
  123. "'with_background' of TrainReader is not equal to EvalReader."
  124. assert train_dataset.with_background == test_dataset.with_background, \
  125. "'with_background' of TrainReader is not equal to TestReader."
  126. actual_num_classes = int(cfg.num_classes) - int(
  127. train_dataset.with_background)
  128. logger.debug("The 'num_classes'(number of classes) you set is {}, " \
  129. "and 'with_background' in 'dataset' sets {}.\n" \
  130. "So please note the actual number of categories is {}."
  131. .format(cfg.num_classes, train_dataset.with_background,
  132. actual_num_classes))
  133. return cfg
  134. def check_py_func(program):
  135. for block in program.blocks:
  136. for op in block.ops:
  137. if op.type == 'py_func':
  138. input_arg = op.input_arg_names
  139. output_arg = op.output_arg_names
  140. err = "The program contains py_func with input: {}, "\
  141. "output: {}. It is not supported in Paddle inference "\
  142. "engine. please replace it by paddle ops. For example, "\
  143. "if you use MultiClassSoftNMS, better to replace it by "\
  144. "MultiClassNMS.".format(input_arg, output_arg)
  145. raise Exception(err)
  146. def enable_static_mode():
  147. try:
  148. paddle.enable_static()
  149. except:
  150. pass