checkpoint.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import errno
import os
import shutil
import tempfile
import time
import numpy as np
import re

import paddle.fluid as fluid

from .download import get_weights_path

import logging
logger = logging.getLogger(__name__)

__all__ = [
    'load_checkpoint',
    'load_and_fusebn',
    'load_params',
    'save',
]


def is_url(path):
    """
    Whether the given path is a URL.

    Args:
        path (string): path to check.
    """
    return path.startswith('http://') or path.startswith('https://')
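
# A quick behavioral sketch (values are illustrative, not part of this module):
#   is_url('https://example.com/model.pdparams')  -> True
#   is_url('output/yolov3/model_final')           -> False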


def _get_weight_path(path):
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
        if num_trainers <= 1:
            path = get_weights_path(path)
        else:
            from ppdet.utils.download import map_path, WEIGHTS_HOME
            weight_path = map_path(path, WEIGHTS_HOME)
            lock_path = weight_path + '.lock'
            if not os.path.exists(weight_path):
                try:
                    os.makedirs(os.path.dirname(weight_path))
                except OSError as e:
                    if e.errno != errno.EEXIST:
                        raise
                with open(lock_path, 'w'):  # touch
                    os.utime(lock_path, None)
                if trainer_id == 0:
                    # trainer 0 downloads the weights, then releases the lock
                    get_weights_path(path)
                    os.remove(lock_path)
                else:
                    # other trainers poll until the lock file disappears
                    while os.path.exists(lock_path):
                        time.sleep(1)
            path = weight_path
    else:
        path = get_weights_path(path)
    return path
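
# Coordination sketch for distributed runs (values are illustrative): with
#   PADDLE_TRAINERS_NUM=2, PADDLE_TRAINER_ID=0  -> downloads, then removes '.lock'
#   PADDLE_TRAINERS_NUM=2, PADDLE_TRAINER_ID=1  -> sleeps until '.lock' is gone
# so the weights are fetched at most once per node.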


def _load_state(path):
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state: copy only the
        # '.pdparams' file into a temporary directory and load from there,
        # so load_program_state never sees the '.pdopt' file
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)
    return state


def _strip_postfix(path):
    path, ext = os.path.splitext(path)
    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
        "Unknown postfix {} from weights".format(ext)
    return path
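
# e.g. (illustrative): _strip_postfix('output/model_final.pdparams')
# returns 'output/model_final'; any other extension trips the assert.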


def load_params(exe, prog, path, ignore_params=[]):
    """
    Load model from the given path.

    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): URL string or local model path.
        ignore_params (list): variables to ignore when loading for finetuning.
            It can be specified by finetune_exclude_pretrained_params;
            for usage refer to docs/advanced_tutorials/TRANSFER_LEARNING.md
    """
    if is_url(path):
        path = _get_weight_path(path)

    path = _strip_postfix(path)
    if not (os.path.isdir(path) or os.path.isfile(path) or
            os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))

    logger.debug('Loading parameters from {}...'.format(path))

    ignore_set = set()
    state = _load_state(path)

    # ignore parameters whose shape differs between the model and the
    # pretrained weights
    all_var_shape = {}
    for block in prog.blocks:
        for param in block.all_parameters():
            all_var_shape[param.name] = param.shape
    ignore_set.update([
        name for name, shape in all_var_shape.items()
        if name in state and shape != state[name].shape
    ])

    if ignore_params:
        all_var_names = [var.name for var in prog.list_vars()]
        ignore_list = filter(
            lambda var: any([re.match(name, var) for name in ignore_params]),
            all_var_names)
        ignore_set.update(list(ignore_list))

    if len(ignore_set) > 0:
        for k in ignore_set:
            if k in state:
                logger.warning('variable {} is ignored'.format(k))
                del state[k]
    fluid.io.set_program_state(prog, state)
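
# A minimal usage sketch (paths and patterns are illustrative):
#   exe = fluid.Executor(fluid.CPUPlace())
#   load_params(exe, train_prog, 'output/pretrain/model_final',
#               ignore_params=['fc.*'])  # patterns are matched with re.match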


def load_checkpoint(exe, prog, path):
    """
    Load model from the given path.

    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): URL string or local model path.
    """
    if is_url(path):
        path = _get_weight_path(path)

    path = _strip_postfix(path)
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
    fluid.load(prog, path, executor=exe)


def global_step(scope=None):
    """
    Load global step in scope.

    Args:
        scope (fluid.Scope): load global step from which scope. If None,
            from default global_scope().

    Returns:
        global step: int.
    """
    if scope is None:
        scope = fluid.global_scope()
    # '@LR_DECAY_COUNTER@' is the step counter maintained by the learning
    # rate scheduler; if it is absent, the step defaults to 0
    v = scope.find_var('@LR_DECAY_COUNTER@')
    step = np.array(v.get_tensor())[0] if v else 0
    return step


def save(exe, prog, path):
    """
    Save model to the given path.

    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): save weight from which Program object.
        path (string): the path to save model.
    """
    if os.path.isdir(path):
        shutil.rmtree(path)
    logger.info('Save model to {}.'.format(path))
    fluid.save(prog, path)
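
# A save/restore round trip sketch (paths are illustrative):
#   save(exe, train_prog, 'output/model_final')
#   load_checkpoint(exe, train_prog, 'output/model_final')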


def load_and_fusebn(exe, prog, path):
    """
    Fuse batch norm params into the scale and bias of affine_channel.

    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): URL string or local model path.
    """
    logger.debug('Loading model from {} and fusing batch norm '
                 'if present...'.format(path))

    if is_url(path):
        path = _get_weight_path(path)

    if not os.path.exists(path):
        raise ValueError("Model path {} does not exist.".format(path))

    # Since the program uses affine-channel, there is no running mean and
    # variance in the program; append them here.
    # NOTE: the params of a batch norm layer are expected to be named:
    #   x_scale
    #   x_offset
    #   x_mean
    #   x_variance
    # where x is any prefix.
    mean_variances = set()
    bn_vars = []
    state = _load_state(path)

    def check_mean_and_bias(prefix):
        m = prefix + 'mean'
        v = prefix + 'variance'
        return v in state and m in state

    has_mean_bias = True

    with fluid.program_guard(prog, fluid.Program()):
        for block in prog.blocks:
            ops = list(block.ops)
            if not has_mean_bias:
                break
            for op in ops:
                if op.type == 'affine_channel':
                    # strip the trailing 'scale' to recover the prefix
                    scale_name = op.input('Scale')[0]  # x_scale
                    bias_name = op.input('Bias')[0]  # x_offset
                    prefix = scale_name[:-5]
                    mean_name = prefix + 'mean'
                    variance_name = prefix + 'variance'
                    if not check_mean_and_bias(prefix):
                        has_mean_bias = False
                        break

                    bias = block.var(bias_name)

                    mean_vb = block.create_var(
                        name=mean_name,
                        type=bias.type,
                        shape=bias.shape,
                        dtype=bias.dtype)
                    variance_vb = block.create_var(
                        name=variance_name,
                        type=bias.type,
                        shape=bias.shape,
                        dtype=bias.dtype)

                    mean_variances.add(mean_vb)
                    mean_variances.add(variance_vb)

                    bn_vars.append(
                        [scale_name, bias_name, mean_name, variance_name])

    if not has_mean_bias:
        fluid.io.set_program_state(prog, state)
        logger.warning(
            "There are no batch norm parameters in model {}. "
            "Skipping batch norm fusion; parameters loaded.".format(path))
        return

    fluid.load(prog, path, exe)

    eps = 1e-5
    for names in bn_vars:
        scale_name, bias_name, mean_name, var_name = names

        scale = fluid.global_scope().find_var(scale_name).get_tensor()
        bias = fluid.global_scope().find_var(bias_name).get_tensor()
        mean = fluid.global_scope().find_var(mean_name).get_tensor()
        var = fluid.global_scope().find_var(var_name).get_tensor()

        scale_arr = np.array(scale)
        bias_arr = np.array(bias)
        mean_arr = np.array(mean)
        var_arr = np.array(var)

        bn_std = np.sqrt(np.add(var_arr, eps))
        new_scale = np.float32(np.divide(scale_arr, bn_std))
        new_bias = bias_arr - mean_arr * new_scale

        # fold the batch norm statistics into affine_channel's scale and bias
        scale.set(new_scale, exe.place)
        bias.set(new_bias, exe.place)
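
# Sanity check of the fusion arithmetic (numbers are illustrative):
#   scale=2.0, offset=1.0, mean=0.5, variance=0.24, eps=1e-5
#   bn_std    = sqrt(0.24 + 1e-5)   ~= 0.4899
#   new_scale = 2.0 / 0.4899        ~= 4.0824
#   new_bias  = 1.0 - 0.5 * 4.0824  ~= -1.0412
# so affine_channel(x) = new_scale * x + new_bias reproduces
# scale * (x - mean) / sqrt(variance + eps) + offset.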