
optimizer.py 8.2 KB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import logging

from paddle import fluid
import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.layers.ops import cos

from ppdet.core.workspace import register, serializable

__all__ = ['LearningRate', 'OptimizerBuilder']

logger = logging.getLogger(__name__)


@serializable
class PiecewiseDecay(object):
    """
    Multi step learning rate decay

    Args:
        gamma (float | list): decay factor
        milestones (list): steps at which to decay learning rate
    """

    def __init__(self, gamma=[0.1, 0.01], milestones=[60000, 80000],
                 values=None):
        super(PiecewiseDecay, self).__init__()
        if type(gamma) is not list:
            # expand a scalar gamma into one decay factor per milestone
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values

    def __call__(self, base_lr=None, learning_rate=None):
        if self.values is not None:
            return fluid.layers.piecewise_decay(self.milestones, self.values)
        assert base_lr is not None, "either base LR or values should be provided"
        values = [base_lr]
        for g in self.gamma:
            new_lr = base_lr * g
            values.append(new_lr)
        return fluid.layers.piecewise_decay(self.milestones, values)
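

# Illustrative sketch (not part of the original module): how a scalar `gamma`
# expands into the value list handed to fluid.layers.piecewise_decay above.
# The helper name and the numbers used here are assumptions for illustration.
def _example_piecewise_values(base_lr=0.01, gamma=0.1,
                              milestones=[60000, 80000]):
    # scalar gamma -> factors [0.1, 0.01]; values -> [0.01, 0.001, 0.0001]
    gammas = [gamma / 10**i for i in range(len(milestones))]
    return [base_lr] + [base_lr * g for g in gammas]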


@serializable
class PolynomialDecay(object):
    """
    Applies polynomial decay to the initial learning rate.

    Args:
        max_iter (int): The learning rate decay steps.
        end_lr (float): End learning rate.
        power (float): Polynomial attenuation coefficient.
    """

    def __init__(self, max_iter=180000, end_lr=0.0001, power=1.0):
        super(PolynomialDecay, self).__init__()
        self.max_iter = max_iter
        self.end_lr = end_lr
        self.power = power

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.polynomial_decay(base_lr, self.max_iter, self.end_lr,
                                           self.power)
        return lr


@serializable
class ExponentialDecay(object):
    """
    Applies exponential decay to the learning rate.

    Args:
        max_iter (int): The learning rate decay steps.
        decay_rate (float): The learning rate decay rate.
    """

    def __init__(self, max_iter, decay_rate):
        super(ExponentialDecay, self).__init__()
        self.max_iter = max_iter
        self.decay_rate = decay_rate

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.exponential_decay(base_lr, self.max_iter,
                                            self.decay_rate)
        return lr
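

# Illustrative sketch (not part of the original module): the closed-form value
# fluid.layers.exponential_decay computes in its default (non-staircase) mode,
# written as plain Python. The helper name is an assumption.
def _example_exponential_value(base_lr, global_step, max_iter, decay_rate):
    # base_lr * decay_rate ** (global_step / max_iter)
    return base_lr * decay_rate**(float(global_step) / max_iter)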


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay

    Args:
        max_iters (int): max iterations for the training process.
            If you combine cosine decay with warmup, it is recommended that
            max_iters is much larger than the number of warmup iterations.
    """

    def __init__(self, max_iters=180000):
        self.max_iters = max_iters

    def __call__(self, base_lr=None, learning_rate=None):
        assert base_lr is not None, "either base LR or values should be provided"
        lr = fluid.layers.cosine_decay(base_lr, 1, self.max_iters)
        return lr


@serializable
class CosineDecayWithSkip(object):
    """
    Cosine decay, with explicit support for warm up

    Args:
        total_steps (int): total steps over which to apply the decay
        skip_steps (int): skip some steps at the beginning, e.g., warm up
    """

    def __init__(self, total_steps, skip_steps=None):
        super(CosineDecayWithSkip, self).__init__()
        assert (not skip_steps or skip_steps > 0), \
            "skip steps must be greater than zero"
        assert total_steps > 0, "total step must be greater than zero"
        assert (not skip_steps or skip_steps < total_steps), \
            "skip steps must be smaller than total steps"
        self.total_steps = total_steps
        self.skip_steps = skip_steps

    def __call__(self, base_lr=None, learning_rate=None):
        steps = _decay_step_counter()
        total = self.total_steps
        if self.skip_steps is not None:
            total -= self.skip_steps

        lr = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=base_lr,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        def decay():
            cos_lr = base_lr * .5 * (cos(steps * (math.pi / total)) + 1)
            fluid.layers.tensor.assign(input=cos_lr, output=lr)

        if self.skip_steps is None:
            decay()
        else:
            skipped = steps >= self.skip_steps
            fluid.layers.cond(skipped, decay)
        return lr
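

# Illustrative sketch (not part of the original module): the cosine expression
# the decay() closure above assigns to the learning-rate variable, as plain
# Python. `step` mirrors the global step counter and `total` the step budget
# left after subtracting skip_steps; the helper name is an assumption.
def _example_cosine_value(base_lr, step, total):
    # half-cosine annealing from base_lr towards 0 over `total` steps
    return base_lr * 0.5 * (math.cos(step * (math.pi / total)) + 1)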


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr, learning_rate):
        start_lr = base_lr * self.start_factor
        return fluid.layers.linear_lr_warmup(
            learning_rate=learning_rate,
            warmup_steps=self.steps,
            start_lr=start_lr,
            end_lr=base_lr)
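

# Illustrative sketch (not part of the original module): what
# fluid.layers.linear_lr_warmup produces during and after the warmup window,
# as plain Python. The helper name is an assumption.
def _example_warmup_value(base_lr, step, steps=500, start_factor=1. / 3):
    start_lr = base_lr * start_factor
    if step >= steps:
        # past the warmup window the wrapped schedule takes over;
        # base_lr stands in for it here
        return base_lr
    # linear ramp from start_lr to base_lr over `steps` steps
    return start_lr + (base_lr - start_lr) * step / float(steps)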


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = schedulers

    def __call__(self):
        lr = None
        for sched in self.schedulers:
            # each scheduler receives the variable built by the previous one
            lr = sched(self.base_lr, lr)
        return lr
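

# Illustrative usage sketch (not part of the original module): schedulers are
# applied in order, so the LinearWarmup entry wraps the variable produced by
# the PiecewiseDecay before it. The config values below are assumptions.
def _example_learning_rate():
    lr_builder = LearningRate(
        base_lr=0.01,
        schedulers=[PiecewiseDecay(milestones=[60000, 80000]),
                    LinearWarmup(steps=500, start_factor=1. / 3)])
    # must be called while a fluid program is the default (static graph mode)
    return lr_builder()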


@register
class OptimizerBuilder():
    """
    Build optimizer handles

    Args:
        regularizer (object): a `Regularizer` instance
        optimizer (object): an `Optimizer` instance
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate):
        if self.clip_grad_by_norm is not None:
            fluid.clip.set_gradient_clip(
                clip=fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=self.clip_grad_by_norm))
        if self.regularizer:
            # e.g. {'type': 'L2', 'factor': 1e-4} -> regularizer.L2Decay(1e-4)
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None
        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        op = getattr(optimizer, optim_type)
        return op(learning_rate=learning_rate,
                  regularization=regularization,
                  **optim_args)
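

# Illustrative usage sketch (not part of the original module): composing
# LearningRate and OptimizerBuilder inside a static-graph training program.
# `loss` is a placeholder for the model's loss variable; values are assumptions.
def _example_build_optimizer(loss):
    lr = LearningRate(base_lr=0.01)()
    optim = OptimizerBuilder(
        regularizer={'type': 'L2', 'factor': 1e-4},
        optimizer={'type': 'Momentum', 'momentum': 0.9})(lr)
    optim.minimize(loss)
    return optim, lr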