# mixed_precision.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function

import six

from paddle.fluid.framework import Parameter
from paddle.fluid import layers
from paddle.fluid import core
from paddle.fluid import unique_name
import paddle.fluid.layer_helper_base as lhb
import paddle.fluid.optimizer as optim

__all__ = [
    'mixed_precision_global_state', 'mixed_precision_context',
    'StaticLossScale', 'DynamicLossScale'
]
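
# Module-level handle to the currently active mixed_precision_context (None
# when mixed precision training is disabled). The patched layer helper and
# optimizer below look it up through mixed_precision_global_state().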
_mixed_precision_global_state = None


def mixed_precision_global_state():
    return _mixed_precision_global_state
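

# Base interface for loss scale managers. Subclasses create the persistable
# `scale` variable; DynamicLossScale additionally overrides increment() and
# decrement() to adjust it during training.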
class LossScale(object):
    def __init__(self):
        super(LossScale, self).__init__()

    def get_loss_scale_var(self):
        return self.scale

    def increment(self):
        raise NotImplementedError()

    def decrement(self):
        raise NotImplementedError()

class StaticLossScale(LossScale):
    """
    Static (fixed) loss scale manager.

    Args:
        init_loss_scale (float): initial loss scale value.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import (mixed_precision_context,
                                            StaticLossScale)

            with mixed_precision_context(StaticLossScale(8.), True) as ctx:
                # ...
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
    """

    def __init__(self, init_loss_scale=1.):
        super(StaticLossScale, self).__init__()
        self.scale = layers.create_global_var(
            name=unique_name.generate("loss_scale"),
            shape=[1],
            value=init_loss_scale,
            dtype='float32',
            persistable=True)

class DynamicLossScale(LossScale):
    """
    Dynamic loss scale manager. It works as follows: if the gradients are
    finite for `increment_every` consecutive steps, the loss scale is
    increased by `factor`; otherwise the loss scale is decreased by `factor`.

    Args:
        init_loss_scale (float): initial loss scale value.
        increment_every (int): minimum number of 'good' steps before the
            loss scale is increased.
        factor (float): multiplier used to increase/decrease the loss scale.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import (mixed_precision_context,
                                            DynamicLossScale)

            loss_scale = DynamicLossScale(8., 1000, 4.)
            with mixed_precision_context(loss_scale, True) as ctx:
                # ...
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
    """

    def __init__(self, init_loss_scale=2**15, increment_every=2000, factor=2.):
        super(DynamicLossScale, self).__init__()
        self.scale = layers.create_global_var(
            name=unique_name.generate("loss_scale"),
            shape=[1],
            value=init_loss_scale,
            dtype='float32',
            persistable=True)
        self.good_steps = layers.create_global_var(
            name=unique_name.generate("good_steps"),
            shape=[1],
            value=0,
            dtype='int32',
            persistable=True)
        self.increment_every = layers.fill_constant(
            shape=[1], dtype='int32', value=increment_every)
        self.factor = factor
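
    # Called when all gradients of a step are finite. The update is built
    # with layers.cond so it runs inside the static graph: after
    # `increment_every` consecutive good steps the scale is multiplied by
    # `factor`, but only if the enlarged scale is itself still finite.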
    def increment(self):
        enough_steps = layers.less_than(self.increment_every,
                                        self.good_steps + 1)

        def increment_step():
            layers.increment(self.good_steps)

        def maybe_update():
            new_scale = self.scale * self.factor
            scale_valid = layers.isfinite(new_scale)

            def update_scale_and_step():
                layers.assign(new_scale, self.scale)
                layers.assign(
                    layers.zeros_like(self.good_steps), self.good_steps)

            layers.cond(scale_valid, update_scale_and_step)

        layers.cond(enough_steps, maybe_update, increment_step)
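
    # Called when a non-finite gradient is detected: divide the loss scale
    # by `factor` (clamped at 1.0) and reset the good-step counter.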
    def decrement(self):
        new_scale = self.scale / self.factor
        one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
        layers.assign(layers.elementwise_max(new_scale, one), self.scale)
        layers.assign(layers.zeros_like(self.good_steps), self.good_steps)

class mixed_precision_context(object):
    """
    Context manager for mixed precision training.

    Args:
        loss_scale (float, str or obj): loss scale settings, can be:
            1. a number: use a fixed loss scale.
            2. 'dynamic': use a default `DynamicLossScale`.
            3. a `DynamicLossScale` or `StaticLossScale` instance.
        enabled (bool): enable mixed precision training.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import mixed_precision_context

            with mixed_precision_context('dynamic', True) as ctx:
                # cast inputs to float16
                inputs = fluid.layers.cast(inputs, "float16")
                # build model here
                logits = model(inputs)
                # use float32 for softmax
                logits = fluid.layers.cast(logits, "float32")
                softmax = fluid.layers.softmax(logits)
                loss = fluid.layers.cross_entropy(input=softmax, label=label)
                avg_loss = fluid.layers.mean(loss)
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
                avg_loss *= loss_scale
                optimizer = fluid.optimizer.Momentum(...)
                optimizer.minimize(avg_loss)
    """

    def __init__(self, loss_scale=1., enabled=True):
        super(mixed_precision_context, self).__init__()
        self.enabled = enabled
        if not enabled:
            return
        monkey_patch()
        if isinstance(loss_scale, six.integer_types + (float, )):
            self.loss_scale = StaticLossScale(loss_scale)
        elif loss_scale == 'dynamic':
            self.loss_scale = DynamicLossScale()
        else:
            assert isinstance(loss_scale, LossScale), \
                "Invalid loss scale argument"
            self.loss_scale = loss_scale

    @property
    def dynamic_scaling(self):
        return isinstance(self.loss_scale, DynamicLossScale)
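
    # Delegate loss-scale queries and updates to the wrapped LossScale object.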
    def __getattr__(self, attr):
        if attr in ['get_loss_scale_var', 'increment', 'decrement']:
            return getattr(self.loss_scale, attr)

    def __enter__(self):
        if not self.enabled:
            return
        global _mixed_precision_global_state
        _mixed_precision_global_state = self
        return mixed_precision_global_state()

    def __exit__(self, *args):
        if not self.enabled:
            return
        global _mixed_precision_global_state
        _mixed_precision_global_state = None
        return mixed_precision_global_state()
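

# Replacement for LayerHelperBase.create_parameter, installed by
# monkey_patch() below. Inside an active mixed precision context, a request
# for an FP16 parameter creates the master parameter in FP32 and appends a
# cast op that produces a non-persistable FP16 copy for the layer to use.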
def create_parameter(self,
                     attr,
                     shape,
                     dtype,
                     is_bias=False,
                     default_initializer=None):
    mp_state = mixed_precision_global_state()
    is_half = (isinstance(dtype, str) and dtype == 'float16') \
        or (isinstance(dtype, core.VarDesc.VarType)
            and dtype == core.VarDesc.VarType.FP16)

    if is_half and mp_state is not None:
        dtype = 'float32'

    param = self._create_parameter(attr, shape, dtype, is_bias,
                                   default_initializer)
    if not is_half or mp_state is None:
        return param

    param16 = self.main_program.current_block().create_var(
        name=param.name + '.fp16',
        dtype='float16',
        type=param.type,
        persistable=False)
    self.append_op(
        type='cast',
        inputs={'X': [param]},
        outputs={'Out': [param16]},
        attrs={'in_dtype': param.dtype,
               'out_dtype': param16.dtype})
    return param16
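

# Backward callback (registered by the patched `backward` below). After each
# gradient op is appended, divide the parameter gradients it produced by the
# loss scale in place, undoing the loss scaling before the optimizer update.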
def scale_gradient(block, context):
    state = mixed_precision_global_state()
    if state is None:
        return
    scale = state.get_loss_scale_var()
    op_desc = block.desc.op(block.desc.op_size() - 1)
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    bwd_role = core.op_proto_and_checker_maker.OpRole.Backward
    for name in [n for n in op_desc.output_arg_names() if n in context]:
        fwd_var = block._var_recursive(context[name])
        if not isinstance(fwd_var, Parameter):
            continue  # TODO verify all use cases
        scale_op_desc = block.desc.append_op()
        scale_op_desc.set_type("elementwise_div")
        scale_op_desc.set_input("X", [name])
        scale_op_desc.set_input("Y", [scale.name])
        scale_op_desc.set_output("Out", [name])
        scale_op_desc._set_attr("axis", -1)
        scale_op_desc._set_attr(op_role_attr_name, bwd_role)
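

# With dynamic scaling enabled, check whether every gradient is finite and
# let the loss scale manager react: increment on a good step, decrement on
# overflow. Returns a boolean graph variable, or None when dynamic scaling
# is off.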
def update_loss_scale(grads):
    state = mixed_precision_global_state()
    if state is None or not state.dynamic_scaling:
        return
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
    grad_valid = layers.isfinite(per_grad_check)
    layers.cond(grad_valid, lambda: state.increment(),
                lambda: state.decrement())
    return grad_valid
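

# Replacement for Optimizer.backward, installed by monkey_patch() below. It
# injects scale_gradient as a backward callback and, under dynamic scaling,
# zeroes every gradient on steps where an overflow was detected so the
# corrupted values never reach the parameters.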
def backward(self, loss, **kwargs):
    state = mixed_precision_global_state()
    callbacks = 'callbacks' in kwargs and kwargs['callbacks'] or None
    if callbacks is None:
        from paddle.fluid.clip import error_clip_callback
        callbacks = [error_clip_callback]  # XXX what if gradient is zero?
    if state is not None:
        kwargs['callbacks'] = [scale_gradient] + callbacks
    else:
        kwargs['callbacks'] = callbacks
    param_grads = self._backward(loss, **kwargs)

    def zero_grad():
        for _, g in param_grads:
            layers.assign(layers.zeros_like(g), g)

    if state is not None:
        grad_valid = update_loss_scale(v for k, v in param_grads)
        if state.dynamic_scaling:
            layers.cond(grad_valid, None, zero_grad)

    return param_grads
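

# One-time monkey patching of fluid internals: the original implementations
# are kept around as `_create_parameter` and `_backward` so the wrappers
# above can fall back to them.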
mixed_precision_patched = False


# XXX this is a temporary measure, until thoroughly evaluated
def monkey_patch():
    global mixed_precision_patched
    if mixed_precision_patched:
        return
    create_parameter_orig = lhb.LayerHelperBase.create_parameter
    lhb.LayerHelperBase.create_parameter = create_parameter
    lhb.LayerHelperBase._create_parameter = create_parameter_orig
    backward_orig = optim.Optimizer.backward
    optim.Optimizer.backward = backward
    optim.Optimizer._backward = backward_orig
    mixed_precision_patched = True
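

# Illustrative sketch of monitoring the loss scale during training (not part
# of the original module; `exe`, `train_prog`, `feed` and `avg_loss` are
# placeholders for a configured Executor, the built training program, a feed
# dict, and the scaled loss from the mixed_precision_context example above):
#
#     scale_var = ctx.get_loss_scale_var()
#     loss_np, scale_np = exe.run(
#         train_prog, feed=feed, fetch_list=[avg_loss, scale_var])
#     print('loss = {}, loss scale = {}'.format(loss_np[0], scale_np[0]))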