distill.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

from collections import OrderedDict

from paddleslim.dist.single_distiller import merge, l2_loss

from paddle import fluid

try:
    from ppdet.core.workspace import load_config, merge_config, create
    from ppdet.data.reader import create_reader
    from ppdet.utils.eval_utils import parse_fetches, eval_results, eval_run
    from ppdet.utils.stats import TrainingStats
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
    import ppdet.utils.checkpoint as checkpoint
except ImportError as e:
    if sys.argv[0].find('static') >= 0:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run the static model under the PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall the dynamic "
                     "version of ppdet first.".format(e))
        sys.exit(-1)
    else:
        raise e


def l2_distill(pairs, weight):
    """
    Add L2 distillation losses composed of multiple pairs of feature maps;
    each pair is the input to the teacher's and the student's yolov3_loss
    respectively.
    """
    loss = []
    for pair in pairs:
        loss.append(l2_loss(pair[0], pair[1]))
    loss = fluid.layers.sum(loss)
    weighted_loss = loss * weight
    return weighted_loss
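# Note: each entry in `pairs` is expected to be [teacher_var_name, student_var_name],
# i.e. names of variables that already exist in the merged default main program.
# A hypothetical call, using the pair names defined in main() below:
#   distill_loss = l2_distill([['teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1']], 100)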


def split_distill(split_output_names, weight):
    """
    Add fine-grained distillation losses.
    Each loss is composed of distill_reg_loss, distill_cls_loss and
    distill_obj_loss.
    """
    student_var = []
    for name in split_output_names:
        student_var.append(fluid.default_main_program().global_block().var(
            name))
    s_x0, s_y0, s_w0, s_h0, s_obj0, s_cls0 = student_var[0:6]
    s_x1, s_y1, s_w1, s_h1, s_obj1, s_cls1 = student_var[6:12]
    s_x2, s_y2, s_w2, s_h2, s_obj2, s_cls2 = student_var[12:18]
    teacher_var = []
    for name in split_output_names:
        teacher_var.append(fluid.default_main_program().global_block().var(
            'teacher_' + name))
    t_x0, t_y0, t_w0, t_h0, t_obj0, t_cls0 = teacher_var[0:6]
    t_x1, t_y1, t_w1, t_h1, t_obj1, t_cls1 = teacher_var[6:12]
    t_x2, t_y2, t_w2, t_h2, t_obj2, t_cls2 = teacher_var[12:18]

    def obj_weighted_reg(sx, sy, sw, sh, tx, ty, tw, th, tobj):
        loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
            sx, fluid.layers.sigmoid(tx))
        loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
            sy, fluid.layers.sigmoid(ty))
        loss_w = fluid.layers.abs(sw - tw)
        loss_h = fluid.layers.abs(sh - th)
        loss = fluid.layers.sum([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = fluid.layers.reduce_mean(loss *
                                                 fluid.layers.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(scls, tcls, tobj):
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            scls, fluid.layers.sigmoid(tcls))
        weighted_loss = fluid.layers.reduce_mean(
            fluid.layers.elementwise_mul(
                loss, fluid.layers.sigmoid(tobj), axis=0))
        return weighted_loss

    def obj_loss(sobj, tobj):
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = fluid.layers.reduce_mean(
            fluid.layers.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    distill_reg_loss0 = obj_weighted_reg(s_x0, s_y0, s_w0, s_h0, t_x0, t_y0,
                                         t_w0, t_h0, t_obj0)
    distill_reg_loss1 = obj_weighted_reg(s_x1, s_y1, s_w1, s_h1, t_x1, t_y1,
                                         t_w1, t_h1, t_obj1)
    distill_reg_loss2 = obj_weighted_reg(s_x2, s_y2, s_w2, s_h2, t_x2, t_y2,
                                         t_w2, t_h2, t_obj2)
    distill_reg_loss = fluid.layers.sum(
        [distill_reg_loss0, distill_reg_loss1, distill_reg_loss2])

    distill_cls_loss0 = obj_weighted_cls(s_cls0, t_cls0, t_obj0)
    distill_cls_loss1 = obj_weighted_cls(s_cls1, t_cls1, t_obj1)
    distill_cls_loss2 = obj_weighted_cls(s_cls2, t_cls2, t_obj2)
    distill_cls_loss = fluid.layers.sum(
        [distill_cls_loss0, distill_cls_loss1, distill_cls_loss2])

    distill_obj_loss0 = obj_loss(s_obj0, t_obj0)
    distill_obj_loss1 = obj_loss(s_obj1, t_obj1)
    distill_obj_loss2 = obj_loss(s_obj2, t_obj2)
    distill_obj_loss = fluid.layers.sum(
        [distill_obj_loss0, distill_obj_loss1, distill_obj_loss2])

    loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss) * weight
    return loss
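# Note: `split_output_names` is expected to hold 18 variable names, six per YOLOv3
# output scale in the order (x, y, w, h, obj, cls); the matching teacher variables
# are looked up under the 'teacher_' prefix added when the programs are merged.
# A hypothetical call, assuming the merged program contains all of those variables:
#   distill_loss = split_distill(yolo_output_names, 1000)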


def main():
    env = os.environ

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    check_version()

    main_arch = cfg.architecture

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # build program
    model = create(main_arch)
    inputs_def = cfg['TrainReader']['inputs_def']
    train_feed_vars, train_loader = model.build_inputs(**inputs_def)
    train_fetches = model.train(train_feed_vars)
    loss = train_fetches['loss']

    start_iter = 0
    train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
                                 devices_num, cfg)
    # in iterable mode, use set_sample_list_generator(train_reader, place)
    train_loader.set_sample_list_generator(train_reader)

    # get all student variables
    student_vars = []
    for v in fluid.default_main_program().list_vars():
        try:
            student_vars.append((v.name, v.shape))
        except:
            pass
    # uncomment the following lines to print all student variables
    # print("="*50 + "student_model_vars" + "="*50)
    # print(student_vars)

    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, fluid.default_startup_program()):
        with fluid.unique_name.guard():
            model = create(main_arch)
            inputs_def = cfg['EvalReader']['inputs_def']
            test_feed_vars, eval_loader = model.build_inputs(**inputs_def)
            fetches = model.eval(test_feed_vars)
    eval_prog = eval_prog.clone(True)

    eval_reader = create_reader(cfg.EvalReader)
    # in iterable mode, use set_sample_list_generator(eval_reader, place)
    eval_loader.set_sample_list_generator(eval_reader)

    # parse eval fetches
    extra_keys = []
    if cfg.metric == 'COCO':
        extra_keys = ['im_info', 'im_id', 'im_shape']
    if cfg.metric == 'VOC':
        extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
    eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                     extra_keys)

    teacher_cfg = load_config(FLAGS.teacher_config)
    merge_config(FLAGS.opt)
    teacher_arch = teacher_cfg.architecture
    teacher_program = fluid.Program()
    teacher_startup_program = fluid.Program()

    with fluid.program_guard(teacher_program, teacher_startup_program):
        with fluid.unique_name.guard():
            teacher_feed_vars = OrderedDict()
            for name, var in train_feed_vars.items():
                teacher_feed_vars[name] = teacher_program.global_block(
                )._clone_variable(
                    var, force_persistable=False)
            model = create(teacher_arch)
            train_fetches = model.train(teacher_feed_vars)
            teacher_loss = train_fetches['loss']

    # get all teacher variables
    teacher_vars = []
    for v in teacher_program.list_vars():
        try:
            teacher_vars.append((v.name, v.shape))
        except:
            pass
    # uncomment the following lines to print all teacher variables
    # print("="*50 + "teacher_model_vars" + "="*50)
    # print(teacher_vars)

    exe.run(teacher_startup_program)
    assert FLAGS.teacher_pretrained, "teacher_pretrained should be set"
    checkpoint.load_params(exe, teacher_program, FLAGS.teacher_pretrained)
    teacher_program = teacher_program.clone(for_test=True)

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    data_name_map = {
        'target0': 'target0',
        'target1': 'target1',
        'target2': 'target2',
        'image': 'image',
        'gt_bbox': 'gt_bbox',
        'gt_class': 'gt_class',
        'gt_score': 'gt_score'
    }
    merge(teacher_program, fluid.default_main_program(), data_name_map, place)
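    # merge() (from paddleslim) copies the frozen teacher program into the student's
    # default main program; data_name_map pairs teacher input names with student input
    # names (identical here) so both networks read the same feed data, and teacher
    # variables are renamed with the 'teacher_' prefix relied on by split_distill()
    # and by the fetch list in the training loop below.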

    yolo_output_names = [
        'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',
        'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',
        'strided_slice_4.tmp_0', 'transpose_0.tmp_0', 'strided_slice_5.tmp_0',
        'strided_slice_6.tmp_0', 'strided_slice_7.tmp_0',
        'strided_slice_8.tmp_0', 'strided_slice_9.tmp_0', 'transpose_2.tmp_0',
        'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',
        'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',
        'strided_slice_14.tmp_0', 'transpose_4.tmp_0'
    ]
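    # The 18 names above are the split outputs of the student's YOLOv3 head, grouped
    # six per scale in the (x, y, w, h, obj, cls) order that split_distill() unpacks.
    # They are graph-specific intermediate names and may change if the network
    # definition changes.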

    distill_pairs = [['teacher_conv2d_6.tmp_1', 'conv2d_20.tmp_1'],
                     ['teacher_conv2d_14.tmp_1', 'conv2d_28.tmp_1'],
                     ['teacher_conv2d_22.tmp_1', 'conv2d_36.tmp_1']]

    distill_loss = l2_distill(
        distill_pairs, 100) if not cfg.use_fine_grained_loss else split_distill(
            yolo_output_names, 1000)
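    # Plain L2 feature distillation (weight 100) is the default; setting
    # use_fine_grained_loss in the config switches to the fine-grained split
    # distillation above (weight 1000).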

    loss = distill_loss + loss
    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')
    lr = lr_builder()
    opt = optim_builder(lr)
    opt.minimize(loss)

    exe.run(fluid.default_startup_program())

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'

    ignore_params = cfg.finetune_exclude_pretrained_params \
        if 'finetune_exclude_pretrained_params' in cfg else []

    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe,
                                   fluid.default_main_program(),
                                   FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe,
                                   fluid.default_main_program(),
                                   cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe,
            fluid.default_main_program(),
            cfg.pretrain_weights,
            ignore_params=ignore_params)

    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = False
    build_strategy.fuse_all_optimizer_ops = False
    # only enable sync_bn on multiple GPU devices
    sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
    build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
        and cfg.use_gpu

    exec_strategy = fluid.ExecutionStrategy()
    # iteration interval at which CompiledProgram tries to drop local execution scopes.
    # Set it to 1 to save memory, so that unused variables in the local execution
    # scopes can be deleted after each iteration.
    exec_strategy.num_iteration_per_drop_scope = 1

    parallel_main = fluid.CompiledProgram(fluid.default_main_program(
    )).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
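    # with_data_parallel() replicates the merged training program across the visible
    # devices; loss_name identifies the loss variable that drives the data-parallel
    # build.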
    compiled_eval_prog = fluid.CompiledProgram(eval_prog)

    # whether the output bbox is normalized in the model's output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()
    map_type = cfg.map_type if 'map_type' in cfg else '11point'
    best_box_ap_list = [0.0, 0]  # [mAP, iter]
    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)

    train_loader.start()
    for step_id in range(start_iter, cfg.max_iters):
        teacher_loss_np, distill_loss_np, loss_np, lr_np = exe.run(
            parallel_main,
            fetch_list=[
                'teacher_' + teacher_loss.name, distill_loss.name, loss.name,
                lr.name
            ])
        if step_id % cfg.log_iter == 0:
            logger.info(
                "step {} lr {:.6f}, loss {:.6f}, distill_loss {:.6f}, teacher_loss {:.6f}".
                format(step_id, lr_np[0], loss_np[0], distill_loss_np[0],
                       teacher_loss_np[0]))
        if (step_id % cfg.snapshot_iter == 0 and step_id != 0) or step_id == cfg.max_iters - 1:
            save_name = str(
                step_id) if step_id != cfg.max_iters - 1 else "model_final"
            checkpoint.save(exe,
                            fluid.default_main_program(),
                            os.path.join(save_dir, save_name))
            if FLAGS.save_inference:
                feeded_var_names = ['image', 'im_size']
                targets = list(fetches.values())
                fluid.io.save_inference_model(save_dir + '/infer',
                                              feeded_var_names, targets, exe,
                                              eval_prog)
            # eval
            results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys,
                               eval_values, eval_cls, cfg)
            resolution = None
            box_ap_stats = eval_results(results, cfg.metric, cfg.num_classes,
                                        resolution, is_bbox_normalized,
                                        FLAGS.output_eval, map_type,
                                        cfg['EvalReader']['dataset'])

            if box_ap_stats[0] > best_box_ap_list[0]:
                best_box_ap_list[0] = box_ap_stats[0]
                best_box_ap_list[1] = step_id
                checkpoint.save(exe,
                                fluid.default_main_program(),
                                os.path.join(save_dir, "best_model"))
                if FLAGS.save_inference:
                    feeded_var_names = ['image', 'im_size']
                    targets = list(fetches.values())
                    fluid.io.save_inference_model(save_dir + '/infer',
                                                  feeded_var_names, targets,
                                                  exe, eval_prog)
            logger.info("Best test box ap: {}, in step: {}".format(
                best_box_ap_list[0], best_box_ap_list[1]))
    train_loader.reset()


if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "-t",
        "--teacher_config",
        default=None,
        type=str,
        help="Config file of the teacher architecture.")
    parser.add_argument(
        "--teacher_pretrained",
        default=None,
        type=str,
        help="Path to the pretrained weights of the teacher model.")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory; defaults to the current directory.")
    parser.add_argument(
        "--save_inference",
        default=False,
        type=bool,
        help="Whether to save an inference model.")
    FLAGS = parser.parse_args()
    main()
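# A hypothetical launch, with placeholder config paths and teacher weights
# (-c/--config and -o/--opt are provided by ppdet's ArgsParser):
#   python slim/distillation/distill.py \
#       -c configs/yolov3_mobilenet_v1_voc.yml \
#       -t configs/yolov3_r34_voc.yml \
#       --teacher_pretrained <path/to/teacher_weights>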