callbacks.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import datetime
import six
import copy
import json

import paddle
import paddle.distributed as dist

from ppdet.utils.checkpoint import save_model
from ppdet.metrics import get_infer_results
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = [
    'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
    'VisualDLWriter', 'SniperProposalsGenerator'
]


class Callback(object):
    def __init__(self, model):
        self.model = model

    def on_step_begin(self, status):
        pass

    def on_step_end(self, status):
        pass

    def on_epoch_begin(self, status):
        pass

    def on_epoch_end(self, status):
        pass

    def on_train_begin(self, status):
        pass

    def on_train_end(self, status):
        pass


class ComposeCallback(object):
    def __init__(self, callbacks):
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(
                c, Callback), "callback should be subclass of Callback"
        self._callbacks = callbacks

    def on_step_begin(self, status):
        for c in self._callbacks:
            c.on_step_begin(status)

    def on_step_end(self, status):
        for c in self._callbacks:
            c.on_step_end(status)

    def on_epoch_begin(self, status):
        for c in self._callbacks:
            c.on_epoch_begin(status)

    def on_epoch_end(self, status):
        for c in self._callbacks:
            c.on_epoch_end(status)

    def on_train_begin(self, status):
        for c in self._callbacks:
            c.on_train_begin(status)

    def on_train_end(self, status):
        for c in self._callbacks:
            c.on_train_end(status)
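
# Usage sketch (not part of the original file; `trainer` is a hypothetical
# object exposing the attributes the callbacks below read, e.g. cfg, model,
# optimizer): ComposeCallback simply fans each hook out to every callback.
#
#     callbacks = ComposeCallback([LogPrinter(trainer), Checkpointer(trainer)])
#     callbacks.on_train_begin(status)
#     callbacks.on_step_end(status)   # each wrapped callback sees the same dict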


class LogPrinter(Callback):
    def __init__(self, model):
        super(LogPrinter, self).__init__(model)

    def on_step_end(self, status):
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            mode = status['mode']
            if mode == 'train':
                epoch_id = status['epoch_id']
                step_id = status['step_id']
                steps_per_epoch = status['steps_per_epoch']
                training_staus = status['training_staus']
                batch_time = status['batch_time']
                data_time = status['data_time']

                epoches = self.model.cfg.epoch
                batch_size = self.model.cfg['{}Reader'.format(
                    mode.capitalize())]['batch_size']

                logs = training_staus.log()
                space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
                if step_id % self.model.cfg.log_iter == 0:
                    eta_steps = (epoches - epoch_id) * steps_per_epoch - step_id
                    eta_sec = eta_steps * batch_time.global_avg
                    eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                    ips = float(batch_size) / batch_time.avg
                    fmt = ' '.join([
                        'Epoch: [{}]',
                        '[{' + space_fmt + '}/{}]',
                        'learning_rate: {lr:.6f}',
                        '{meters}',
                        'eta: {eta}',
                        'batch_cost: {btime}',
                        'data_cost: {dtime}',
                        'ips: {ips:.4f} images/s',
                    ])
                    fmt = fmt.format(
                        epoch_id,
                        step_id,
                        steps_per_epoch,
                        lr=status['learning_rate'],
                        meters=logs,
                        eta=eta_str,
                        btime=str(batch_time),
                        dtime=str(data_time),
                        ips=ips)
                    logger.info(fmt)
            if mode == 'eval':
                step_id = status['step_id']
                if step_id % 100 == 0:
                    logger.info("Eval iter: {}".format(step_id))

    def on_epoch_end(self, status):
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            mode = status['mode']
            if mode == 'eval':
                sample_num = status['sample_num']
                cost_time = status['cost_time']
                logger.info('Total sample number: {}, average FPS: {}'.format(
                    sample_num, sample_num / cost_time))
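
# Sketch of the training-time `status` dict consumed by LogPrinter above
# (only the keys read in this file; the real trainer may pass more):
#
#     status = {
#         'mode': 'train', 'epoch_id': 0, 'step_id': 0, 'steps_per_epoch': 500,
#         'learning_rate': 0.01,
#         'training_staus': <stats object with .log() / .get()>,
#         'batch_time': <timer with .avg, .global_avg and __str__>,
#         'data_time': <timer with .avg, .global_avg and __str__>,
#     }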


class Checkpointer(Callback):
    def __init__(self, model):
        super(Checkpointer, self).__init__(model)
        cfg = self.model.cfg
        self.best_ap = 0.
        self.save_dir = os.path.join(self.model.cfg.save_dir,
                                     self.model.cfg.filename)
        if hasattr(self.model.model, 'student_model'):
            self.weight = self.model.model.student_model
        else:
            self.weight = self.model.model

    def on_epoch_end(self, status):
        # Checkpointing is only performed during training
        mode = status['mode']
        epoch_id = status['epoch_id']
        weight = None
        save_name = None
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'train':
                end_epoch = self.model.cfg.epoch
                if (epoch_id + 1) % self.model.cfg.snapshot_epoch == 0 \
                        or epoch_id == end_epoch - 1:
                    save_name = str(
                        epoch_id) if epoch_id != end_epoch - 1 else "model_final"
                    weight = self.weight.state_dict()
            elif mode == 'eval':
                if 'save_best_model' in status and status['save_best_model']:
                    for metric in self.model._metrics:
                        map_res = metric.get_results()
                        if 'bbox' in map_res:
                            key = 'bbox'
                        elif 'keypoint' in map_res:
                            key = 'keypoint'
                        else:
                            key = 'mask'
                        if key not in map_res:
                            logger.warning("Evaluation results empty, this may be due to " \
                                           "training iterations being too few or not " \
                                           "loading the correct weights.")
                            return
                        if map_res[key][0] > self.best_ap:
                            self.best_ap = map_res[key][0]
                            save_name = 'best_model'
                            weight = self.weight.state_dict()
                        logger.info("Best test {} ap is {:0.3f}.".format(
                            key, self.best_ap))
            if weight:
                if self.model.use_ema:
                    # save model and ema_model
                    save_model(
                        status['weight'],
                        self.model.optimizer,
                        self.save_dir,
                        save_name,
                        epoch_id + 1,
                        ema_model=weight)
                else:
                    save_model(weight, self.model.optimizer, self.save_dir,
                               save_name, epoch_id + 1)
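
# Reload sketch (assumption: save_model() writes `<save_name>.pdparams` under
# self.save_dir; verify against ppdet.utils.checkpoint.save_model before use):
#
#     state_dict = paddle.load(os.path.join(save_dir, 'best_model.pdparams'))
#     model.set_state_dict(state_dict)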


class WiferFaceEval(Callback):
    def __init__(self, model):
        super(WiferFaceEval, self).__init__(model)

    def on_epoch_begin(self, status):
        assert self.model.mode == 'eval', \
            "WiferFaceEval can only be set during evaluation"
        for metric in self.model._metrics:
            metric.update(self.model.model)
        sys.exit()


class VisualDLWriter(Callback):
    """
    Use VisualDL to log data or images
    """

    def __init__(self, model):
        super(VisualDLWriter, self).__init__(model)

        assert six.PY3, "VisualDL requires Python >= 3.5"
        try:
            from visualdl import LogWriter
        except Exception as e:
            logger.error('visualdl not found, please install visualdl. '
                         'For example: `pip install visualdl`.')
            raise e
        self.vdl_writer = LogWriter(
            model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
        self.vdl_loss_step = 0
        self.vdl_mAP_step = 0
        self.vdl_image_step = 0
        self.vdl_image_frame = 0

    def on_step_end(self, status):
        mode = status['mode']
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'train':
                training_staus = status['training_staus']
                for loss_name, loss_value in training_staus.get().items():
                    self.vdl_writer.add_scalar(loss_name, loss_value,
                                               self.vdl_loss_step)
                self.vdl_loss_step += 1
            elif mode == 'test':
                ori_image = status['original_image']
                result_image = status['result_image']
                self.vdl_writer.add_image(
                    "original/frame_{}".format(self.vdl_image_frame),
                    ori_image, self.vdl_image_step)
                self.vdl_writer.add_image(
                    "result/frame_{}".format(self.vdl_image_frame),
                    result_image, self.vdl_image_step)

                self.vdl_image_step += 1
                # each frame can display ten pictures at most.
                if self.vdl_image_step % 10 == 0:
                    self.vdl_image_step = 0
                    self.vdl_image_frame += 1

    def on_epoch_end(self, status):
        mode = status['mode']
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'eval':
                for metric in self.model._metrics:
                    for key, map_value in metric.get_results().items():
                        self.vdl_writer.add_scalar("{}-mAP".format(key),
                                                   map_value[0],
                                                   self.vdl_mAP_step)
                self.vdl_mAP_step += 1
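
# Usage note: everything logged above goes to the directory handed to
# LogWriter (default 'vdl_log_dir/scalar'); it can then be browsed with the
# VisualDL board, e.g.:
#
#     visualdl --logdir vdl_log_dir/scalar --port 8040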


class SniperProposalsGenerator(Callback):
    def __init__(self, model):
        super(SniperProposalsGenerator, self).__init__(model)
        ori_dataset = self.model.dataset
        self.dataset = self._create_new_dataset(ori_dataset)
        self.loader = self.model.loader
        self.cfg = self.model.cfg
        self.infer_model = self.model.model

    def _create_new_dataset(self, ori_dataset):
        dataset = copy.deepcopy(ori_dataset)
        # init anno_cropper
        dataset.init_anno_cropper()
        # generate infer roidbs
        ori_roidbs = dataset.get_ori_roidbs()
        roidbs = dataset.anno_cropper.crop_infer_anno_records(ori_roidbs)
        # set new roidbs
        dataset.set_roidbs(roidbs)
        return dataset

    def _eval_with_loader(self, loader):
        results = []
        with paddle.no_grad():
            self.infer_model.eval()
            for step_id, data in enumerate(loader):
                outs = self.infer_model(data)
                for key in ['im_shape', 'scale_factor', 'im_id']:
                    outs[key] = data[key]
                for key, value in outs.items():
                    if hasattr(value, 'numpy'):
                        outs[key] = value.numpy()
                results.append(outs)
        return results

    def on_train_end(self, status):
        self.loader.dataset = self.dataset
        results = self._eval_with_loader(self.loader)
        results = self.dataset.anno_cropper.aggregate_chips_detections(results)
        # sniper
        proposals = []
        clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()}
        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                bbox_num = outs['bbox_num']
                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                    if 'bbox' in batch_res else None
                if bbox_res:
                    proposals += bbox_res
                # advance the slice offset so the next image reads its own boxes
                start = end
  298. logger.info("save proposals in {}".format(self.cfg.proposals_path))
  299. with open(self.cfg.proposals_path, 'w') as f:
  300. json.dump(proposals, f)
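
# Usage sketch (not part of the original file): the proposals dumped by
# SniperProposalsGenerator.on_train_end are plain JSON and can be read back
# for inspection, assuming cfg.proposals_path points at the generated file:
#
#     with open(cfg.proposals_path) as f:
#         proposals = json.load(f)
#     print(len(proposals))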