# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
  21. if parent_path not in sys.path:
  22. sys.path.append(parent_path)
  23. import time
  24. import numpy as np
  25. import datetime
  26. from collections import deque
  27. from paddleslim.prune import Pruner
  28. from paddleslim.analysis import flops
  29. from paddle import fluid
  30. import logging
  31. FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
  32. logging.basicConfig(level=logging.INFO, format=FORMAT)
  33. logger = logging.getLogger(__name__)
  34. try:
  35. from ppdet.experimental import mixed_precision_context
  36. from ppdet.core.workspace import load_config, merge_config, create
  37. from ppdet.data.reader import create_reader
  38. from ppdet.utils import dist_utils
  39. from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
  40. from ppdet.utils.stats import TrainingStats
  41. from ppdet.utils.cli import ArgsParser
  42. from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
  43. import ppdet.utils.checkpoint as checkpoint
  44. except ImportError as e:
  45. if sys.argv[0].find('static') >= 0:
  46. logger.error("Importing ppdet failed when running static model "
  47. "with error: {}\n"
  48. "please try:\n"
  49. "\t1. run static model under PaddleDetection/static "
  50. "directory\n"
  51. "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
  52. "dynamic version firstly.".format(e))
  53. sys.exit(-1)
  54. else:
  55. raise e
  56. def main():
  57. env = os.environ
  58. FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
  59. if FLAGS.dist:
  60. trainer_id = int(env['PADDLE_TRAINER_ID'])
  61. import random
  62. local_seed = (99 + trainer_id)
  63. random.seed(local_seed)
  64. np.random.seed(local_seed)
  65. cfg = load_config(FLAGS.config)
  66. merge_config(FLAGS.opt)
  67. check_config(cfg)
  68. # check if set use_gpu=True in paddlepaddle cpu version
  69. check_gpu(cfg.use_gpu)
  70. # check if paddlepaddle version is satisfied
  71. check_version()
  72. main_arch = cfg.architecture
  73. if cfg.use_gpu:
  74. devices_num = fluid.core.get_cuda_device_count()
  75. else:
  76. devices_num = int(os.environ.get('CPU_NUM', 1))
  77. if 'FLAGS_selected_gpus' in env:
  78. device_id = int(env['FLAGS_selected_gpus'])
  79. else:
  80. device_id = 0
  81. place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
  82. exe = fluid.Executor(place)
  83. lr_builder = create('LearningRate')
  84. optim_builder = create('OptimizerBuilder')
  85. # build program
  86. startup_prog = fluid.Program()
  87. train_prog = fluid.Program()
  88. with fluid.program_guard(train_prog, startup_prog):
  89. with fluid.unique_name.guard():
  90. model = create(main_arch)
  91. if FLAGS.fp16:
  92. assert (getattr(model.backbone, 'norm_type', None)
  93. != 'affine_channel'), \
  94. '--fp16 currently does not support affine channel, ' \
  95. ' please modify backbone settings to use batch norm'
  96. with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
  97. inputs_def = cfg['TrainReader']['inputs_def']
  98. feed_vars, train_loader = model.build_inputs(**inputs_def)
  99. train_fetches = model.train(feed_vars)
  100. loss = train_fetches['loss']
  101. if FLAGS.fp16:
  102. loss *= ctx.get_loss_scale_var()
  103. lr = lr_builder()
  104. optimizer = optim_builder(lr)
  105. optimizer.minimize(loss)
  106. if FLAGS.fp16:
  107. loss /= ctx.get_loss_scale_var()
  108. # parse train fetches
  109. train_keys, train_values, _ = parse_fetches(train_fetches)
  110. train_values.append(lr)
  111. if FLAGS.print_params:
  112. param_delimit_str = '-' * 20 + "All parameters in current graph" + '-' * 20
  113. print(param_delimit_str)
  114. for block in train_prog.blocks:
  115. for param in block.all_parameters():
  116. print("parameter name: {}\tshape: {}".format(param.name,
  117. param.shape))
  118. print('-' * len(param_delimit_str))
  119. return
  120. if FLAGS.eval:
  121. eval_prog = fluid.Program()
  122. with fluid.program_guard(eval_prog, startup_prog):
  123. with fluid.unique_name.guard():
  124. model = create(main_arch)
  125. inputs_def = cfg['EvalReader']['inputs_def']
  126. feed_vars, eval_loader = model.build_inputs(**inputs_def)
  127. fetches = model.eval(feed_vars)
  128. eval_prog = eval_prog.clone(True)
  129. eval_reader = create_reader(cfg.EvalReader)
  130. # When iterable mode, set set_sample_list_generator(eval_reader, place)
  131. eval_loader.set_sample_list_generator(eval_reader)
  132. # parse eval fetches
  133. extra_keys = []
  134. if cfg.metric == 'COCO':
  135. extra_keys = ['im_info', 'im_id', 'im_shape']
  136. if cfg.metric == 'VOC':
  137. extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
  138. if cfg.metric == 'WIDERFACE':
  139. extra_keys = ['im_id', 'im_shape', 'gt_bbox']
  140. eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
  141. extra_keys)
  142. # compile program for multi-devices
  143. build_strategy = fluid.BuildStrategy()
  144. build_strategy.fuse_all_optimizer_ops = False
  145. build_strategy.fuse_elewise_add_act_ops = True
  146. # only enable sync_bn in multi GPU devices
  147. sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
  148. build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
  149. and cfg.use_gpu
  150. exec_strategy = fluid.ExecutionStrategy()
  151. # iteration number when CompiledProgram tries to drop local execution scopes.
  152. # Set it to be 1 to save memory usages, so that unused variables in
  153. # local execution scopes can be deleted after each iteration.
  154. exec_strategy.num_iteration_per_drop_scope = 1
  155. if FLAGS.dist:
  156. dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
  157. train_prog)
  158. exec_strategy.num_threads = 1
  159. exe.run(startup_prog)
  160. fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
  161. start_iter = 0
  162. if cfg.pretrain_weights:
  163. checkpoint.load_params(exe, train_prog, cfg.pretrain_weights)
  164. pruned_params = FLAGS.pruned_params
  165. assert FLAGS.pruned_params is not None, \
  166. "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
  167. pruned_params = FLAGS.pruned_params.strip().split(",")
  168. logger.info("pruned params: {}".format(pruned_params))
  169. pruned_ratios = [float(n) for n in FLAGS.pruned_ratios.strip().split(",")]
  170. logger.info("pruned ratios: {}".format(pruned_ratios))
  171. assert len(pruned_params) == len(pruned_ratios), \
  172. "The length of pruned params and pruned ratios should be equal."
  173. assert (pruned_ratios > [0] * len(pruned_ratios) and
  174. pruned_ratios < [1] * len(pruned_ratios)
  175. ), "The elements of pruned ratios should be in range (0, 1)."
  176. assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \
  177. "unsupported prune criterion {}".format(FLAGS.prune_criterion)
  178. pruner = Pruner(criterion=FLAGS.prune_criterion)
  179. if FLAGS.eval:
  180. base_flops = flops(eval_prog)
  181. eval_prog = pruner.prune(
  182. eval_prog,
  183. fluid.global_scope(),
  184. params=pruned_params,
  185. ratios=pruned_ratios,
  186. place=place,
  187. only_graph=True)[0]
  188. pruned_flops = flops(eval_prog)
  189. logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(
  190. float(base_flops - pruned_flops) / base_flops, base_flops,
  191. pruned_flops))
  192. compiled_eval_prog = fluid.CompiledProgram(eval_prog)
  193. train_prog = pruner.prune(
  194. train_prog,
  195. fluid.global_scope(),
  196. params=pruned_params,
  197. ratios=pruned_ratios,
  198. place=place,
  199. only_graph=False)[0]
  200. compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
  201. loss_name=loss.name,
  202. build_strategy=build_strategy,
  203. exec_strategy=exec_strategy)
  204. if FLAGS.resume_checkpoint:
  205. checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
  206. start_iter = checkpoint.global_step()
  207. train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
  208. devices_num, cfg)
  209. train_loader.set_sample_list_generator(train_reader, place)
  210. # whether output bbox is normalized in model output layer
  211. is_bbox_normalized = False
  212. if hasattr(model, 'is_bbox_normalized') and \
  213. callable(model.is_bbox_normalized):
  214. is_bbox_normalized = model.is_bbox_normalized()
  215. # if map_type not set, use default 11point, only use in VOC eval
  216. map_type = cfg.map_type if 'map_type' in cfg else '11point'
  217. train_stats = TrainingStats(cfg.log_iter, train_keys)
  218. train_loader.start()
  219. start_time = time.time()
  220. end_time = time.time()
  221. cfg_name = os.path.basename(FLAGS.config).split('.')[0]
  222. save_dir = os.path.join(cfg.save_dir, cfg_name)
  223. time_stat = deque(maxlen=cfg.log_iter)
  224. best_box_ap_list = [0.0, 0] #[map, iter]
  225. # use VisualDL to log data
  226. if FLAGS.use_vdl:
  227. from visualdl import LogWriter
  228. vdl_writer = LogWriter(FLAGS.vdl_log_dir)
  229. vdl_loss_step = 0
  230. vdl_mAP_step = 0
  231. if FLAGS.eval:
  232. resolution = None
  233. if 'Mask' in cfg.architecture:
  234. resolution = model.mask_head.resolution
  235. # evaluation
  236. results = eval_run(
  237. exe,
  238. compiled_eval_prog,
  239. eval_loader,
  240. eval_keys,
  241. eval_values,
  242. eval_cls,
  243. cfg,
  244. resolution=resolution)
  245. dataset = cfg['EvalReader']['dataset']
  246. box_ap_stats = eval_results(
  247. results,
  248. cfg.metric,
  249. cfg.num_classes,
  250. resolution,
  251. is_bbox_normalized,
  252. FLAGS.output_eval,
  253. map_type,
  254. dataset=dataset)
  255. for it in range(start_iter, cfg.max_iters):
  256. start_time = end_time
  257. end_time = time.time()
  258. time_stat.append(end_time - start_time)
  259. time_cost = np.mean(time_stat)
  260. eta_sec = (cfg.max_iters - it) * time_cost
  261. eta = str(datetime.timedelta(seconds=int(eta_sec)))
  262. outs = exe.run(compiled_train_prog, fetch_list=train_values)
  263. stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}
  264. # use VisualDL to log loss
  265. if FLAGS.use_vdl:
  266. if it % cfg.log_iter == 0:
  267. for loss_name, loss_value in stats.items():
  268. vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
  269. vdl_loss_step += 1
  270. train_stats.update(stats)
  271. logs = train_stats.log()
  272. if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
  273. strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
  274. it, np.mean(outs[-1]), logs, time_cost, eta)
  275. logger.info(strs)
  276. if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
  277. and (not FLAGS.dist or trainer_id == 0):
  278. save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
  279. checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
  280. if FLAGS.eval:
  281. # evaluation
  282. resolution = None
  283. if 'Mask' in cfg.architecture:
  284. resolution = model.mask_head.resolution
  285. results = eval_run(
  286. exe,
  287. compiled_eval_prog,
  288. eval_loader,
  289. eval_keys,
  290. eval_values,
  291. eval_cls,
  292. cfg=cfg,
  293. resolution=resolution)
  294. box_ap_stats = eval_results(
  295. results,
  296. cfg.metric,
  297. cfg.num_classes,
  298. resolution,
  299. is_bbox_normalized,
  300. FLAGS.output_eval,
  301. map_type,
  302. dataset=dataset)
  303. # use VisualDL to log mAP
  304. if FLAGS.use_vdl:
  305. vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
  306. vdl_mAP_step += 1
  307. if box_ap_stats[0] > best_box_ap_list[0]:
  308. best_box_ap_list[0] = box_ap_stats[0]
  309. best_box_ap_list[1] = it
  310. checkpoint.save(exe, train_prog,
  311. os.path.join(save_dir, "best_model"))
  312. logger.info("Best test box ap: {}, in iter: {}".format(
  313. best_box_ap_list[0], best_box_ap_list[1]))
  314. train_loader.reset()
  315. if __name__ == '__main__':
  316. enable_static_mode()
  317. parser = ArgsParser()
  318. parser.add_argument(
  319. "-r",
  320. "--resume_checkpoint",
  321. default=None,
  322. type=str,
  323. help="Checkpoint path for resuming training.")
  324. parser.add_argument(
  325. "--fp16",
  326. action='store_true',
  327. default=False,
  328. help="Enable mixed precision training.")
  329. parser.add_argument(
  330. "--loss_scale",
  331. default=8.,
  332. type=float,
  333. help="Mixed precision training loss scale.")
  334. parser.add_argument(
  335. "--eval",
  336. action='store_true',
  337. default=False,
  338. help="Whether to perform evaluation in train")
  339. parser.add_argument(
  340. "--output_eval",
  341. default=None,
  342. type=str,
  343. help="Evaluation directory, default is current directory.")
  344. parser.add_argument(
  345. "--use_vdl",
  346. type=bool,
  347. default=False,
  348. help="whether to record the data to VisualDL.")
  349. parser.add_argument(
  350. '--vdl_log_dir',
  351. type=str,
  352. default="vdl_log_dir/scalar",
  353. help='VisualDL logging directory for scalar.')
  354. parser.add_argument(
  355. "-p",
  356. "--pruned_params",
  357. default=None,
  358. type=str,
  359. help="The parameters to be pruned when calculating sensitivities.")
  360. parser.add_argument(
  361. "--pruned_ratios",
  362. default=None,
  363. type=str,
  364. help="The ratios pruned iteratively for each parameter when calculating sensitivities."
  365. )
  366. parser.add_argument(
  367. "-P",
  368. "--print_params",
  369. default=False,
  370. action='store_true',
  371. help="Whether to only print the parameters' names and shapes.")
  372. parser.add_argument(
  373. "--prune_criterion",
  374. default='l1_norm',
  375. type=str,
  376. help="criterion function type for channels sorting in pruning, can be set " \
  377. "as 'l1_norm' or 'geometry_median' currently, default 'l1_norm'")
  378. FLAGS = parser.parse_args()
  379. main()