train.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# Add the python path of PaddleDetection to sys.path: the parent of the
# directory containing this script. This has to happen before the `ppdet`
# imports below so they resolve when the script is run directly.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# Ignore warning log: with no category given, this filter suppresses every
# Python warning for the whole process. It is installed before `import paddle`,
# presumably so warnings raised at import time are hidden too.
import warnings
warnings.filterwarnings('ignore')
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Trainer, init_parallel_env, set_random_seed, init_fleet_env
from ppdet.slim import build_slim_model
import ppdet.utils.cli as cli
import ppdet.utils.check as check
from ppdet.utils.logger import setup_logger
# Logger for this script, registered under the name 'train'.
logger = setup_logger('train')
  33. def parse_args():
  34. parser = cli.ArgsParser()
  35. parser.add_argument(
  36. "--eval",
  37. action='store_true',
  38. default=False,
  39. help="Whether to perform evaluation in train")
  40. parser.add_argument(
  41. "-r", "--resume", default=None, help="weights path for resume")
  42. parser.add_argument(
  43. "--slim_config",
  44. default=None,
  45. type=str,
  46. help="Configuration file of slim method.")
  47. parser.add_argument(
  48. "--enable_ce",
  49. type=bool,
  50. default=False,
  51. help="If set True, enable continuous evaluation job."
  52. "This flag is only used for internal test.")
  53. parser.add_argument(
  54. "--amp",
  55. action='store_true',
  56. default=False,
  57. help="Enable auto mixed precision training.")
  58. parser.add_argument(
  59. "--fleet", action='store_true', default=False, help="Use fleet or not")
  60. parser.add_argument(
  61. "--use_vdl",
  62. type=bool,
  63. default=False,
  64. help="whether to record the data to VisualDL.")
  65. parser.add_argument(
  66. '--vdl_log_dir',
  67. type=str,
  68. default="vdl_log_dir/scalar",
  69. help='VisualDL logging directory for scalar.')
  70. parser.add_argument(
  71. '--save_prediction_only',
  72. action='store_true',
  73. default=False,
  74. help='Whether to save the evaluation results only')
  75. parser.add_argument(
  76. '--profiler_options',
  77. type=str,
  78. default=None,
  79. help="The option of profiler, which should be in "
  80. "format \"key1=value1;key2=value2;key3=value3\"."
  81. "please see ppdet/utils/profiler.py for detail.")
  82. parser.add_argument(
  83. '--save_proposals',
  84. action='store_true',
  85. default=False,
  86. help='Whether to save the train proposals')
  87. parser.add_argument(
  88. '--proposals_path',
  89. type=str,
  90. default="sniper/proposals.json",
  91. help='Train proposals directory')
  92. args = parser.parse_args()
  93. return args
  94. def run(FLAGS, cfg):
  95. # init fleet environment
  96. if cfg.fleet:
  97. init_fleet_env(cfg.get('find_unused_parameters', False))
  98. else:
  99. # init parallel environment if nranks > 1
  100. init_parallel_env()
  101. if FLAGS.enable_ce:
  102. set_random_seed(0)
  103. # build trainer
  104. trainer = Trainer(cfg, mode='train')
  105. # load weights
  106. if FLAGS.resume is not None:
  107. trainer.resume_weights(FLAGS.resume)
  108. elif 'pretrain_weights' in cfg and cfg.pretrain_weights:
  109. trainer.load_weights(cfg.pretrain_weights)
  110. # training
  111. trainer.train(FLAGS.eval)
  112. def main():
  113. FLAGS = parse_args()
  114. cfg = load_config(FLAGS.config)
  115. cfg['amp'] = FLAGS.amp
  116. cfg['fleet'] = FLAGS.fleet
  117. cfg['use_vdl'] = FLAGS.use_vdl
  118. cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
  119. cfg['save_prediction_only'] = FLAGS.save_prediction_only
  120. cfg['profiler_options'] = FLAGS.profiler_options
  121. cfg['save_proposals'] = FLAGS.save_proposals
  122. cfg['proposals_path'] = FLAGS.proposals_path
  123. merge_config(FLAGS.opt)
  124. # disable npu in config by default
  125. if 'use_npu' not in cfg:
  126. cfg.use_npu = False
  127. # disable xpu in config by default
  128. if 'use_xpu' not in cfg:
  129. cfg.use_xpu = False
  130. if cfg.use_gpu:
  131. place = paddle.set_device('gpu')
  132. elif cfg.use_npu:
  133. place = paddle.set_device('npu')
  134. elif cfg.use_xpu:
  135. place = paddle.set_device('xpu')
  136. else:
  137. place = paddle.set_device('cpu')
  138. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
  139. cfg['norm_type'] = 'bn'
  140. if FLAGS.slim_config:
  141. cfg = build_slim_model(cfg, FLAGS.slim_config)
  142. # FIXME: Temporarily solve the priority problem of FLAGS.opt
  143. merge_config(FLAGS.opt)
  144. check.check_config(cfg)
  145. check.check_gpu(cfg.use_gpu)
  146. check.check_npu(cfg.use_npu)
  147. check.check_version()
  148. run(FLAGS, cfg)
  149. if __name__ == "__main__":
  150. main()