# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import time
import numpy as np
import random
import datetime
import six
from collections import deque

from paddle.fluid import profiler
from paddle import fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
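
# ppdet must resolve to the static-graph copy bundled with this directory; a
# pip-installed (dynamic-graph) ppdet can shadow it, which the handler below
# detects and explains how to fix.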
try:
    from ppdet.experimental import mixed_precision_context
    from ppdet.core.workspace import load_config, merge_config, create
    from ppdet.data.reader import create_reader
    from ppdet.utils import dist_utils
    from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
    from ppdet.utils.stats import TrainingStats
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_xpu, check_npu, check_version, check_config, enable_static_mode
    import ppdet.utils.checkpoint as checkpoint
except ImportError as e:
    if sys.argv[0].find('static') >= 0:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run static model under PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall the "
                     "dynamic version of ppdet first.".format(e))
        sys.exit(-1)
    else:
        raise e


def main():
    env = os.environ
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
                    and 'PADDLE_TRAINERS_NUM' in env \
                    and int(env['PADDLE_TRAINERS_NUM']) > 1
    num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        # give each trainer a distinct seed so random ops differ per process
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)

    if FLAGS.enable_ce:
        random.seed(0)
        np.random.seed(0)

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check that use_gpu=True is not set with a CPU-only paddlepaddle build
    check_gpu(cfg.use_gpu)
    # disable npu in config by default and check use_npu
    if 'use_npu' not in cfg:
        cfg.use_npu = False
    check_npu(cfg.use_npu)
    use_xpu = False
    if hasattr(cfg, 'use_xpu'):
        check_xpu(cfg.use_xpu)
        use_xpu = cfg.use_xpu
    # check whether the installed paddlepaddle version is recent enough
    check_version()

    assert not (use_xpu and cfg.use_gpu), \
        'Can not run on both XPU and GPU'
    assert not (cfg.use_npu and cfg.use_gpu), \
        'Can not run on both NPU and GPU'

    save_only = getattr(cfg, 'save_prediction_only', False)
    if save_only:
        raise NotImplementedError('The config file only supports prediction;'
                                  ' the training stage is not implemented yet')

    main_arch = cfg.architecture
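
    # Device setup: devices_num drives reader sharding below, while the
    # FLAGS_selected_* variables (typically exported per process by
    # paddle.distributed.launch) pick the card this process binds to.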
    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    elif cfg.use_npu:
        devices_num = fluid.core.get_npu_device_count()
    elif use_xpu:
        # ToDo(qingshu): XPU only supports a single card for now
        devices_num = 1
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if cfg.use_gpu and 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    elif cfg.use_npu and 'FLAGS_selected_npus' in env:
        device_id = int(env['FLAGS_selected_npus'])
    elif use_xpu and 'FLAGS_selected_xpus' in env:
        device_id = int(env['FLAGS_selected_xpus'])
    else:
        device_id = 0

    if cfg.use_gpu:
        place = fluid.CUDAPlace(device_id)
    elif cfg.use_npu:
        place = fluid.NPUPlace(device_id)
    elif use_xpu:
        place = fluid.XPUPlace(device_id)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    if FLAGS.enable_ce:
        startup_prog.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            if FLAGS.fp16:
                assert (getattr(model.backbone, 'norm_type', None)
                        != 'affine_channel'), \
                    '--fp16 currently does not support affine channel, ' \
                    ' please modify backbone settings to use batch norm'

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                inputs_def = cfg['TrainReader']['inputs_def']
                feed_vars, train_loader = model.build_inputs(**inputs_def)
                train_fetches = model.train(feed_vars)
                loss = train_fetches['loss']
                # scale the loss before minimize() so fp16 gradients stay in
                # range, then unscale it so the logged loss is unaffected
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)
                optimizer.minimize(loss)
                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()

            if 'use_ema' in cfg and cfg['use_ema']:
                global_steps = _decay_step_counter()
                ema = ExponentialMovingAverage(
                    cfg['ema_decay'], thres_steps=global_steps)
                ema.update()

    # parse train fetches; lr is appended last so that outs[-1] in the
    # logging code below is the current learning rate
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)
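
    # The eval program is built against the same startup program and inside
    # unique_name.guard(), so its parameters get the same names as (and thus
    # share weights with) the training program.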
    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                inputs_def = cfg['EvalReader']['inputs_def']
                feed_vars, eval_loader = model.build_inputs(**inputs_def)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(cfg.EvalReader, devices_num=1)
        # in iterable mode, use set_sample_list_generator(eval_reader, place)
        eval_loader.set_sample_list_generator(eval_reader)

        # parse eval fetches
        extra_keys = []
        if cfg.metric == 'COCO':
            extra_keys = ['im_info', 'im_id', 'im_shape']
        if cfg.metric == 'VOC':
            extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
        if cfg.metric == 'WIDERFACE':
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    # compile program for multi-devices
    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_optimizer_ops = False
    # only enable sync_bn in multi-GPU runs
    sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
    build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
        and cfg.use_gpu

    exec_strategy = fluid.ExecutionStrategy()
    # iteration interval at which CompiledProgram drops local execution
    # scopes. Set to 1 to save memory, so unused variables in local scopes
    # are freed after every iteration.
    exec_strategy.num_iteration_per_drop_scope = 1
    if FLAGS.dist:
        dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
                                             train_prog)
        exec_strategy.num_threads = 1

    exe.run(startup_prog)
    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
    if use_xpu or cfg.use_npu:
        # data-parallel compilation is not supported on XPU/NPU; run the
        # plain Program instead
        compiled_train_prog = train_prog

    if FLAGS.eval:
        compiled_eval_prog = fluid.CompiledProgram(eval_prog)
        if use_xpu or cfg.use_npu:
            compiled_eval_prog = eval_prog
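
    # Weight initialization precedence: an explicit --resume_checkpoint wins;
    # otherwise pretrain_weights are loaded, with BN parameters fused when the
    # backbone uses affine_channel norm and no parameters are excluded.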
    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
    ignore_params = cfg.finetune_exclude_pretrained_params \
        if 'finetune_exclude_pretrained_params' in cfg else []

    start_iter = 0
    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)

    train_reader = create_reader(
        cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
        cfg,
        devices_num=devices_num,
        num_trainers=num_trainers)
    # in iterable mode, use set_sample_list_generator(train_reader, place)
    train_loader.set_sample_list_generator(train_reader)

    # whether the model's output layer emits normalized bbox coordinates
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # default to 11point mAP if map_type is not set; only used in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_iter, train_keys)
    train_loader.start()
    start_time = time.time()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_iter)
    best_box_ap_list = [0.0, 0]  # [mAP, iter]

    # use VisualDL to log data
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_loss_step = 0
        vdl_mAP_step = 0
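
    # Main loop: run one step of the compiled program per iteration, log
    # smoothed timings and losses every log_iter steps, and snapshot the
    # model (optionally followed by evaluation) every snapshot_iter steps.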
    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        # outs[:-1] are the training losses; outs[-1] is the learning rate
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # log losses to VisualDL
        if FLAGS.use_vdl:
            if it % cfg.log_iter == 0:
                for loss_name, loss_value in stats.items():
                    vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
                vdl_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
            ips = float(cfg['TrainReader']['batch_size']) / time_cost
            strs = 'iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'.format(
                it, np.mean(outs[-1]), logs, eta, time_cost, ips)
            logger.info(strs)

        # NOTE: profiler tools, used for benchmark
        if FLAGS.is_profiler and it == 5:
            profiler.start_profiler("All")
        elif FLAGS.is_profiler and it == 10:
            profiler.stop_profiler("total", FLAGS.profiler_path)
            return

        if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
                and (not FLAGS.dist or trainer_id == 0):
            save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
            # temporarily swap in the EMA weights before saving/evaluating
            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.apply_program)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                # evaluation
                resolution = None
                if 'Mask' in cfg.architecture:
                    resolution = model.mask_head.resolution
                results = eval_run(
                    exe,
                    compiled_eval_prog,
                    eval_loader,
                    eval_keys,
                    eval_values,
                    eval_cls,
                    cfg,
                    resolution=resolution)
                box_ap_stats = eval_results(
                    results, cfg.metric, cfg.num_classes, resolution,
                    is_bbox_normalized, FLAGS.output_eval, map_type,
                    cfg['EvalReader']['dataset'])

                # log mAP to VisualDL
                if FLAGS.use_vdl:
                    vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
                    vdl_mAP_step += 1

                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

            # restore the raw (non-EMA) weights before training continues
            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.restore_program)

    train_loader.reset()
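

# ArgsParser (from ppdet.utils.cli) extends plain argparse with the -c/--config
# and -o/--opt options consumed above by load_config() and merge_config().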
if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train.")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    # note: with type=bool, argparse treats any non-empty string (even
    # "False") as True; this applies to --use_vdl and --enable_ce
    parser.add_argument(
        "--use_vdl",
        type=bool,
        default=False,
        help="Whether to record data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')
    parser.add_argument(
        "--enable_ce",
        type=bool,
        default=False,
        help="If set True, enable continuous evaluation job. "
        "This flag is only used for internal test.")
    # NOTE: args for profiler tools, used for benchmark
    parser.add_argument(
        '--is_profiler',
        type=int,
        default=0,
        help='The switch of profiler tools. (used for benchmark)')
    parser.add_argument(
        '--profiler_path',
        type=str,
        default="./detection.profiler",
        help='The profiler output file path. (used for benchmark)')
    FLAGS = parser.parse_args()
    main()
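
# A minimal usage sketch (the config path is illustrative; any YAML under
# configs/ should work, and launcher flag names vary across Paddle versions):
#
#   # single-card training with periodic evaluation
#   python tools/train.py -c configs/yolov3_darknet.yml --eval
#
#   # multi-card training: the launcher sets PADDLE_TRAINER_ID,
#   # PADDLE_TRAINERS_NUM and FLAGS_selected_gpus for each worker process
#   python -m paddle.distributed.launch --selected_gpus 0,1,2,3 \
#       tools/train.py -c configs/yolov3_darknet.yml --eval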