# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

import time
import numpy as np
import random
import datetime
import six
from collections import deque

from paddle.fluid import profiler

from paddle import fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy  # new line 1
from paddle.fluid.incubate.fleet.base import role_maker  # new line 2
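
# Fleet is Paddle's distributed training API; in collective mode it rewrites
# the single-card program with NCCL all-reduce ops so the same script can run
# across multiple GPUs and machines when launched with
# python -m paddle.distributed.launch.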

try:
    from ppdet.experimental import mixed_precision_context
    from ppdet.core.workspace import load_config, merge_config, create
    from ppdet.data.reader import create_reader
    from ppdet.utils import dist_utils
    from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
    from ppdet.utils.stats import TrainingStats
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
    import ppdet.utils.checkpoint as checkpoint
except ImportError as e:
    if sys.argv[0].find('static') >= 0:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run static model under PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall the ppdet "
                     "dynamic version first.".format(e))
        sys.exit(-1)
    else:
        raise e


def main():
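    # PaddleCloudRoleMaker reads the trainer endpoints and current rank from
    # environment variables set by paddle.distributed.launch, and fleet.init()
    # sets up the collective communication context for this process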
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)  # new line 3
    fleet.init(role)  # new line 4

    env = os.environ

    num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 0))
    assert num_trainers != 0, "multi-machine training process must be started using distributed.launch..."
    trainer_id = int(env.get("PADDLE_TRAINER_ID", 0))
    # set different seeds for different trainers
    random.seed(trainer_id)
    np.random.seed(trainer_id)

    if FLAGS.enable_ce:
        random.seed(0)
        np.random.seed(0)

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    save_only = getattr(cfg, 'save_prediction_only', False)
    if save_only:
        raise NotImplementedError('The config file only supports prediction; '
                                  'the training stage is not implemented yet')

    main_arch = cfg.architecture

    assert cfg.use_gpu, "GPU must be supported for multi-machine training..."
    devices_num = fluid.core.get_cuda_device_count()

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0

    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    if FLAGS.enable_ce:
        startup_prog.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            if FLAGS.fp16:
                assert (getattr(model.backbone, 'norm_type', None)
                        != 'affine_channel'), \
                    '--fp16 currently does not support affine channel, ' \
                    'please modify backbone settings to use batch norm'

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                inputs_def = cfg['TrainReader']['inputs_def']
                feed_vars, train_loader = model.build_inputs(**inputs_def)
                train_fetches = model.train(feed_vars)
                loss = train_fetches['loss']
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)
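
                # configure Fleet's collective strategy: synchronized batch
                # norm when the backbone requests it, a single NCCL
                # communicator, and fused all-reduce to cut communication
                # overhead; fleet.distributed_optimizer() wraps the optimizer
                # so minimize() also inserts the gradient communication ops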
                dist_strategy = DistributedStrategy()
                sync_bn = getattr(model.backbone, 'norm_type',
                                  None) == 'sync_bn'
                dist_strategy.sync_batch_norm = sync_bn
                dist_strategy.nccl_comm_num = 1
                exec_strategy = fluid.ExecutionStrategy()
                exec_strategy.num_threads = 3
                exec_strategy.num_iteration_per_drop_scope = 30
                dist_strategy.exec_strategy = exec_strategy
                dist_strategy.fuse_all_reduce_ops = True
                optimizer = fleet.distributed_optimizer(
                    optimizer, strategy=dist_strategy)  # new line 5

                optimizer.minimize(loss)

                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()
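
            # optionally keep an exponential moving average (EMA) of the
            # weights; ema.apply_program / ema.restore_program are run around
            # each checkpoint below so snapshots store the averaged weights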
            if 'use_ema' in cfg and cfg['use_ema']:
                global_steps = _decay_step_counter()
                ema = ExponentialMovingAverage(
                    cfg['ema_decay'], thres_steps=global_steps)
                ema.update()

    # parse train fetches
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)
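    # note: lr is the last fetch, so the train loop can split outs[:-1]
    # (loss stats) from outs[-1] (current learning rate)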

    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                inputs_def = cfg['EvalReader']['inputs_def']
                feed_vars, eval_loader = model.build_inputs(**inputs_def)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(cfg.EvalReader, devices_num=1)
        # in iterable mode, use set_sample_list_generator(eval_reader, place)
        eval_loader.set_sample_list_generator(eval_reader)

        # parse eval fetches
        extra_keys = []
        if cfg.metric == 'COCO':
            extra_keys = ['im_info', 'im_id', 'im_shape']
        if cfg.metric == 'VOC':
            extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
        if cfg.metric == 'WIDERFACE':
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    exe.run(startup_prog)
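
    # fleet.main_program is the transformed train program that already
    # contains the collective communication ops, so it is run directly
    # instead of wrapping train_prog in a fluid.CompiledProgram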
    compiled_train_prog = fleet.main_program

    if FLAGS.eval:
        compiled_eval_prog = fluid.CompiledProgram(eval_prog)

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
    ignore_params = cfg.finetune_exclude_pretrained_params \
        if 'finetune_exclude_pretrained_params' in cfg else []
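
    # weight restore precedence: resume an interrupted run, else load
    # pretrained weights and fuse BN into affine_channel layers when
    # possible, else load pretrained weights minus the excluded params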
    start_iter = 0
    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)
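
    # size the reader for the remaining iterations across all local GPUs:
    # (max_iters - start_iter) batches for each of devices_num devices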
    train_reader = create_reader(
        cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
        cfg,
        devices_num=devices_num)
    # in iterable mode, use set_sample_list_generator(train_reader, place)
    train_loader.set_sample_list_generator(train_reader)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type is not set, use the default '11point'; only used in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_iter, train_keys)
    train_loader.start()
    start_time = time.time()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_iter)
    best_box_ap_list = [0.0, 0]  # [mAP, iter]

    # use VisualDL to log data
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_loss_step = 0
        vdl_mAP_step = 0
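
    # main training loop: track per-iteration time for the ETA estimate,
    # run one step of the distributed program, log stats, and periodically
    # snapshot and (optionally) evaluate on trainer 0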
    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # use VisualDL to log loss
        if FLAGS.use_vdl:
            if it % cfg.log_iter == 0:
                for loss_name, loss_value in stats.items():
                    vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
                vdl_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and trainer_id == 0:
            strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                it, np.mean(outs[-1]), logs, time_cost, eta)
            logger.info(strs)

        # NOTE: profiler tools, used for benchmark
        if FLAGS.is_profiler and it == 5:
            profiler.start_profiler("All")
        elif FLAGS.is_profiler and it == 10:
            profiler.stop_profiler("total", FLAGS.profiler_path)
            return
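
        # snapshot (and optional evaluation) runs on trainer 0 only, so
        # checkpoints are written once rather than by every process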
        if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
                and trainer_id == 0:
            save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.apply_program)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                # evaluation
                resolution = None
                if 'Mask' in cfg.architecture:
                    resolution = model.mask_head.resolution
                results = eval_run(
                    exe,
                    compiled_eval_prog,
                    eval_loader,
                    eval_keys,
                    eval_values,
                    eval_cls,
                    cfg,
                    resolution=resolution)
                box_ap_stats = eval_results(
                    results, cfg.metric, cfg.num_classes, resolution,
                    is_bbox_normalized, FLAGS.output_eval, map_type,
                    cfg['EvalReader']['dataset'])

                # use VisualDL to log mAP
                if FLAGS.use_vdl:
                    vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
                    vdl_mAP_step += 1

                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.restore_program)

    train_loader.reset()


if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "--use_vdl",
        type=bool,
        default=False,
        help="whether to record the data to VisualDL.")
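    # NOTE: argparse's type=bool converts any non-empty string (even 'False')
    # to True, so --use_vdl and --enable_ce below behave as on/off switches;
    # leave them unset to keep them disabled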
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')
    parser.add_argument(
        "--enable_ce",
        type=bool,
        default=False,
        help="If set True, enable continuous evaluation job. "
        "This flag is only used for internal test.")
    # NOTE: args for profiler tools, used for benchmark
    parser.add_argument(
        '--is_profiler',
        type=int,
        default=0,
        help='The switch of profiler tools. (used for benchmark)')
    parser.add_argument(
        '--profiler_path',
        type=str,
        default="./detection.profiler",
        help='The profiler output file path. (used for benchmark)')
    FLAGS = parser.parse_args()
    main()