# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function

import six

from paddle.fluid.framework import Parameter
from paddle.fluid import layers
from paddle.fluid import core
from paddle.fluid import unique_name
import paddle.fluid.layer_helper_base as lhb
import paddle.fluid.optimizer as optim

__all__ = [
    'mixed_precision_global_state', 'mixed_precision_context',
    'StaticLossScale', 'DynamicLossScale'
]

_mixed_precision_global_state = None


def mixed_precision_global_state():
    return _mixed_precision_global_state
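

# Base interface for loss scale managers: it exposes the loss scale variable
# and the increment/decrement hooks used by the dynamic loss scaling logic
# below.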
class LossScale(object):
    def __init__(self):
        super(LossScale, self).__init__()

    def get_loss_scale_var(self):
        return self.scale

    def increment(self):
        raise NotImplementedError()

    def decrement(self):
        raise NotImplementedError()


class StaticLossScale(LossScale):
    """
    Static (fixed) loss scale manager.

    Args:
        init_loss_scale (float): initial loss scale value.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import (mixed_precision_context,
                                            StaticLossScale)

            with mixed_precision_context(StaticLossScale(8.), True) as ctx:
                # ...
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
    """

    def __init__(self, init_loss_scale=1.):
        super(StaticLossScale, self).__init__()
        self.scale = layers.create_global_var(
            name=unique_name.generate("loss_scale"),
            shape=[1],
            value=init_loss_scale,
            dtype='float32',
            persistable=True)


class DynamicLossScale(LossScale):
    """
    Dynamic loss scale manager. It works as follows:
    if the gradients are valid (finite) for `increment_every` consecutive
    steps, the loss scale is increased by `factor`; whenever non-finite
    gradients are detected, the loss scale is decreased by `factor`.

    Args:
        init_loss_scale (float): initial loss scale value.
        increment_every (int): minimum number of 'good' steps before the loss
            scale is increased.
        factor (float): increase/decrease the loss scale by this much.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import (mixed_precision_context,
                                            DynamicLossScale)

            loss_scale = DynamicLossScale(8., 1000, 4.)
            with mixed_precision_context(loss_scale, True) as ctx:
                # ...
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
    """

    def __init__(self, init_loss_scale=2**15, increment_every=2000, factor=2.):
        super(DynamicLossScale, self).__init__()
        self.scale = layers.create_global_var(
            name=unique_name.generate("loss_scale"),
            shape=[1],
            value=init_loss_scale,
            dtype='float32',
            persistable=True)
        self.good_steps = layers.create_global_var(
            name=unique_name.generate("good_steps"),
            shape=[1],
            value=0,
            dtype='int32',
            persistable=True)
        self.increment_every = layers.fill_constant(
            shape=[1], dtype='int32', value=increment_every)
        self.factor = factor
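
    # Called after a step with finite gradients: increment the 'good step'
    # counter until `increment_every` good steps have accumulated, then
    # multiply the loss scale by `factor` (only if the enlarged scale is
    # still finite) and reset the counter.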
    def increment(self):
        enough_steps = layers.less_than(self.increment_every,
                                        self.good_steps + 1)

        def increment_step():
            layers.increment(self.good_steps)

        def maybe_update():
            new_scale = self.scale * self.factor
            scale_valid = layers.isfinite(new_scale)

            def update_scale_and_step():
                layers.assign(new_scale, self.scale)
                layers.assign(
                    layers.zeros_like(self.good_steps), self.good_steps)

            layers.cond(scale_valid, update_scale_and_step)

        layers.cond(enough_steps, maybe_update, increment_step)
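
    # Called when non-finite gradients are detected: divide the loss scale by
    # `factor` (never dropping below 1) and reset the 'good step' counter.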
    def decrement(self):
        new_scale = self.scale / self.factor
        one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
        layers.assign(layers.elementwise_max(new_scale, one), self.scale)
        layers.assign(layers.zeros_like(self.good_steps), self.good_steps)


class mixed_precision_context(object):
    """
    Context manager for mixed precision training.

    Args:
        loss_scale (float, str or obj): loss scale settings, can be:
            1. a number: use a fixed loss scale.
            2. 'dynamic': use a default `DynamicLossScale`.
            3. a `DynamicLossScale` or `StaticLossScale` instance.
        enabled (bool): enable mixed precision training.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from ppdet.experimental import mixed_precision_context

            with mixed_precision_context('dynamic', True) as ctx:
                # cast inputs to float16
                inputs = fluid.layers.cast(inputs, "float16")
                # build model here
                logits = model(inputs)
                # use float32 for softmax
                logits = fluid.layers.cast(logits, "float32")
                softmax = fluid.layers.softmax(logits)
                loss = fluid.layers.cross_entropy(input=softmax, label=label)
                avg_loss = fluid.layers.mean(loss)
                # scale loss
                loss_scale = ctx.get_loss_scale_var()
                avg_loss *= loss_scale
                optimizer = fluid.optimizer.Momentum(...)
                optimizer.minimize(avg_loss)
    """

    def __init__(self, loss_scale=1., enabled=True):
        super(mixed_precision_context, self).__init__()
        self.enabled = enabled
        if not enabled:
            return
        monkey_patch()
        if isinstance(loss_scale, six.integer_types + (float, )):
            self.loss_scale = StaticLossScale(loss_scale)
        elif loss_scale == 'dynamic':
            self.loss_scale = DynamicLossScale()
        else:
            assert isinstance(loss_scale, LossScale), \
                "Invalid loss scale argument"
            self.loss_scale = loss_scale

    @property
    def dynamic_scaling(self):
        return isinstance(self.loss_scale, DynamicLossScale)

    def __getattr__(self, attr):
        if attr in ['get_loss_scale_var', 'increment', 'decrement']:
            return getattr(self.loss_scale, attr)

    def __enter__(self):
        if not self.enabled:
            return
        global _mixed_precision_global_state
        _mixed_precision_global_state = self
        return mixed_precision_global_state()

    def __exit__(self, *args):
        if not self.enabled:
            return
        global _mixed_precision_global_state
        _mixed_precision_global_state = None
        return mixed_precision_global_state()
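

# Replacement for LayerHelperBase.create_parameter (installed by monkey_patch
# below): when mixed precision is active and a float16 parameter is requested,
# create a float32 master parameter instead and return a non-persistable
# float16 copy produced by a cast op.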
def create_parameter(self,
                     attr,
                     shape,
                     dtype,
                     is_bias=False,
                     default_initializer=None):
    mp_state = mixed_precision_global_state()
    is_half = (isinstance(dtype, str) and dtype == 'float16') \
        or (isinstance(dtype, core.VarDesc.VarType)
            and dtype == core.VarDesc.VarType.FP16)

    if is_half and mp_state is not None:
        dtype = 'float32'

    param = self._create_parameter(attr, shape, dtype, is_bias,
                                   default_initializer)
    if not is_half or mp_state is None:
        return param

    param16 = self.main_program.current_block().create_var(
        name=param.name + '.fp16',
        dtype='float16',
        type=param.type,
        persistable=False)
    self.append_op(
        type='cast',
        inputs={'X': [param]},
        outputs={'Out': [param16]},
        attrs={'in_dtype': param.dtype,
               'out_dtype': param16.dtype})
    return param16
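

# Backward callback: right after a backward op is appended, add an
# elementwise_div op that divides each parameter gradient by the loss scale
# variable, undoing the loss scaling before the optimizer update.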
def scale_gradient(block, context):
    state = mixed_precision_global_state()
    if state is None:
        return
    scale = state.get_loss_scale_var()
    op_desc = block.desc.op(block.desc.op_size() - 1)
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    bwd_role = core.op_proto_and_checker_maker.OpRole.Backward
    for name in [n for n in op_desc.output_arg_names() if n in context]:
        fwd_var = block._var_recursive(context[name])
        if not isinstance(fwd_var, Parameter):
            continue  # TODO verify all use cases
        scale_op_desc = block.desc.append_op()
        scale_op_desc.set_type("elementwise_div")
        scale_op_desc.set_input("X", [name])
        scale_op_desc.set_input("Y", [scale.name])
        scale_op_desc.set_output("Out", [name])
        scale_op_desc._set_attr("axis", -1)
        scale_op_desc._set_attr(op_role_attr_name, bwd_role)
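

# Check whether all gradients are finite and drive the dynamic loss scale
# accordingly: increment() when they are, decrement() when they are not.
# Returns the boolean validity variable, or None when dynamic scaling is off.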
def update_loss_scale(grads):
    state = mixed_precision_global_state()
    if state is None or not state.dynamic_scaling:
        return
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
    grad_valid = layers.isfinite(per_grad_check)
    layers.cond(grad_valid, lambda: state.increment(),
                lambda: state.decrement())
    return grad_valid
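

# Replacement for Optimizer.backward (installed by monkey_patch below): it
# injects the scale_gradient callback, updates the dynamic loss scale after
# the backward pass, and zeroes all gradients when they are not finite.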
def backward(self, loss, **kwargs):
    state = mixed_precision_global_state()
    callbacks = 'callbacks' in kwargs and kwargs['callbacks'] or None
    if callbacks is None:
        from paddle.fluid.clip import error_clip_callback
        callbacks = [error_clip_callback]  # XXX what if gradient is zero?
    if state is not None:
        kwargs['callbacks'] = [scale_gradient] + callbacks
    else:
        kwargs['callbacks'] = callbacks
    param_grads = self._backward(loss, **kwargs)

    def zero_grad():
        for _, g in param_grads:
            layers.assign(layers.zeros_like(g), g)

    if state is not None:
        grad_valid = update_loss_scale(v for k, v in param_grads)
        if state.dynamic_scaling:
            layers.cond(grad_valid, None, zero_grad)

    return param_grads


mixed_precision_patched = False


# XXX this is a temporary measure, until thoroughly evaluated
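# Swaps LayerHelperBase.create_parameter and Optimizer.backward for the
# mixed-precision-aware versions above, keeping the originals reachable as
# _create_parameter and _backward. The patch is applied at most once.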
def monkey_patch():
    global mixed_precision_patched
    if mixed_precision_patched:
        return
    create_parameter_orig = lhb.LayerHelperBase.create_parameter
    lhb.LayerHelperBase.create_parameter = create_parameter
    lhb.LayerHelperBase._create_parameter = create_parameter_orig
    backward_orig = optim.Optimizer.backward
    optim.Optimizer.backward = backward
    optim.Optimizer._backward = backward_orig
    mixed_precision_patched = True