3
0

configure.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import os
  16. import sys
  17. from argparse import ArgumentParser, RawDescriptionHelpFormatter
  18. # add python path of PadleDetection to sys.path
  19. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  20. if parent_path not in sys.path:
  21. sys.path.append(parent_path)
  22. import logging
  23. FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
  24. logging.basicConfig(level=logging.INFO, format=FORMAT)
  25. logger = logging.getLogger(__name__)
  26. import yaml
  27. try:
  28. from ppdet.core.workspace import get_registered_modules, load_config, dump_value
  29. from ppdet.utils.cli import ColorTTY, print_total_cfg
  30. except ImportError as e:
  31. if sys.argv[0].find('static') >= 0:
  32. logger.error("Importing ppdet failed when running static model "
  33. "with error: {}\n"
  34. "please try:\n"
  35. "\t1. run static model under PaddleDetection/static "
  36. "directory\n"
  37. "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
  38. "dynamic version firstly.".format(e))
  39. sys.exit(-1)
  40. else:
  41. raise e
  42. color_tty = ColorTTY()
  43. try:
  44. from docstring_parser import parse as doc_parse
  45. except Exception:
  46. message = "docstring_parser is not installed, " \
  47. + "argument description is not available"
  48. print(color_tty.yellow(message))
  49. try:
  50. from typeguard import check_type
  51. except Exception:
  52. message = "typeguard is not installed," \
  53. + "type checking is not available"
  54. print(color_tty.yellow(message))
  55. MISC_CONFIG = {
  56. "architecture": "<value>",
  57. "max_iters": "<value>",
  58. "train_feed": "<value>",
  59. "eval_feed": "<value>",
  60. "test_feed": "<value>",
  61. "pretrain_weights": "<value>",
  62. "save_dir": "<value>",
  63. "weights": "<value>",
  64. "metric": "<value>",
  65. "map_type": "11point",
  66. "snapshot_iter": 10000,
  67. "log_iter": 20,
  68. "use_gpu": True,
  69. "finetune_exclude_pretrained_params": "<value>",
  70. }
  71. def dump_config(module, minimal=False):
  72. args = module.schema.values()
  73. if minimal:
  74. args = [arg for arg in args if not arg.has_default()]
  75. return yaml.dump(
  76. {
  77. module.name: {
  78. arg.name: arg.default if arg.has_default() else "<value>"
  79. for arg in args
  80. }
  81. },
  82. default_flow_style=False,
  83. default_style='')
  84. def list_modules(**kwargs):
  85. target_category = kwargs['category']
  86. module_schema = get_registered_modules()
  87. module_by_category = {}
  88. for schema in module_schema.values():
  89. category = schema.category
  90. if target_category is not None and schema.category != target_category:
  91. continue
  92. if category not in module_by_category:
  93. module_by_category[category] = [schema]
  94. else:
  95. module_by_category[category].append(schema)
  96. for cat, modules in module_by_category.items():
  97. print("Available modules in the category '{}':".format(cat))
  98. print("")
  99. max_len = max([len(mod.name) for mod in modules])
  100. for mod in modules:
  101. print(
  102. color_tty.green(mod.name.ljust(max_len)),
  103. mod.doc.split('\n')[0])
  104. print("")
  105. def help_module(**kwargs):
  106. schema = get_registered_modules()[kwargs['module']]
  107. doc = schema.doc is None and "Not documented" or "{}".format(schema.doc)
  108. func_args = {arg.name: arg.doc for arg in schema.schema.values()}
  109. max_len = max([len(k) for k in func_args.keys()])
  110. opts = "\n".join([
  111. "{} {}".format(color_tty.green(k.ljust(max_len)), v)
  112. for k, v in func_args.items()
  113. ])
  114. template = dump_config(schema)
  115. print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format(
  116. color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")),
  117. doc,
  118. color_tty.bold(color_tty.blue("MODULE OPTIONS:")),
  119. opts,
  120. color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")),
  121. template,
  122. color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), ))
  123. for arg in schema.schema.values():
  124. print("--opt {}.{}={}".format(schema.name, arg.name,
  125. dump_value(arg.default)
  126. if arg.has_default() else "<value>"))
  127. def generate_config(**kwargs):
  128. minimal = kwargs['minimal']
  129. modules = kwargs['modules']
  130. module_schema = get_registered_modules()
  131. visited = []
  132. schema = []
  133. def walk(m):
  134. if m in visited:
  135. return
  136. s = module_schema[m]
  137. schema.append(s)
  138. visited.append(m)
  139. for mod in modules:
  140. walk(mod)
  141. # XXX try to be smart about when to add header,
  142. # if any "architecture" module, is included, head will be added as well
  143. if any([getattr(m, 'category', None) == 'architecture' for m in schema]):
  144. # XXX for ordered printing
  145. header = ""
  146. for k, v in MISC_CONFIG.items():
  147. header += yaml.dump(
  148. {
  149. k: v
  150. }, default_flow_style=False, default_style='')
  151. print(header)
  152. for s in schema:
  153. print(dump_config(s, minimal))
  154. # FIXME this is pretty hackish, maybe implement a custom YAML printer?
  155. def analyze_config(**kwargs):
  156. config = load_config(kwargs['file'])
  157. print_total_cfg(config)
  158. if __name__ == '__main__':
  159. argv = sys.argv[1:]
  160. parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter)
  161. subparsers = parser.add_subparsers(help='Supported Commands')
  162. list_parser = subparsers.add_parser("list", help="list available modules")
  163. help_parser = subparsers.add_parser(
  164. "help", help="show detail options for module")
  165. generate_parser = subparsers.add_parser(
  166. "generate", help="generate configuration template")
  167. analyze_parser = subparsers.add_parser(
  168. "analyze", help="analyze configuration file")
  169. list_parser.set_defaults(func=list_modules)
  170. help_parser.set_defaults(func=help_module)
  171. generate_parser.set_defaults(func=generate_config)
  172. analyze_parser.set_defaults(func=analyze_config)
  173. list_group = list_parser.add_mutually_exclusive_group()
  174. list_group.add_argument(
  175. "-c",
  176. "--category",
  177. type=str,
  178. default=None,
  179. help="list modules for <category>")
  180. help_parser.add_argument(
  181. "module",
  182. help="module to show info for",
  183. choices=list(get_registered_modules().keys()))
  184. generate_parser.add_argument(
  185. "modules",
  186. nargs='+',
  187. help="include these module in generated configuration template",
  188. choices=list(get_registered_modules().keys()))
  189. generate_group = generate_parser.add_mutually_exclusive_group()
  190. generate_group.add_argument(
  191. "--minimal", action='store_true', help="only include required options")
  192. generate_group.add_argument(
  193. "--full",
  194. action='store_false',
  195. dest='minimal',
  196. help="include all options")
  197. analyze_parser.add_argument("file", help="configuration file to analyze")
  198. if len(sys.argv) < 2:
  199. parser.print_help()
  200. sys.exit(1)
  201. args = parser.parse_args(argv)
  202. if hasattr(args, 'func'):
  203. args.func(**vars(args))