optimizer.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import math
import copy
import weakref

import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = ['LearningRate', 'OptimizerBuilder']


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay.

    Args:
        max_epochs (int): max epochs for the training process.
            If cosine decay is combined with warmup, it is recommended
            that `max_epochs` be much larger than the number of warmup epochs.
        use_warmup (bool): whether to use warmup. Default: True.
        min_lr_ratio (float): minimum learning rate ratio. Default: 0.
        last_plateau_epochs (int): use the minimum learning rate in
            the last few epochs. Default: 0.
    """

    def __init__(self,
                 max_epochs=1000,
                 use_warmup=True,
                 min_lr_ratio=0.,
                 last_plateau_epochs=0):
        self.max_epochs = max_epochs
        self.use_warmup = use_warmup
        self.min_lr_ratio = min_lr_ratio
        self.last_plateau_epochs = last_plateau_epochs

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        assert base_lr is not None, "either base LR or values should be provided"
        max_iters = self.max_epochs * int(step_per_epoch)
        last_plateau_iters = self.last_plateau_epochs * int(step_per_epoch)
        min_lr = base_lr * self.min_lr_ratio

        if boundary is not None and value is not None and self.use_warmup:
            # warmup is used: append cosine-decayed values after the warmup
            # segment, holding the minimum LR during the final plateau epochs
            warmup_iters = len(boundary)
            for i in range(int(boundary[-1]), max_iters):
                boundary.append(i)
                if i < max_iters - last_plateau_iters:
                    decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
                        (i - warmup_iters) * math.pi /
                        (max_iters - warmup_iters - last_plateau_iters)) + 1)
                    value.append(decayed_lr)
                else:
                    value.append(min_lr)
            return optimizer.lr.PiecewiseDecay(boundary, value)
        elif last_plateau_iters > 0:
            # no warmup, but `last_plateau_epochs` > 0
            boundary = []
            value = []
            for i in range(max_iters):
                if i < max_iters - last_plateau_iters:
                    decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
                        i * math.pi / (max_iters - last_plateau_iters)) + 1)
                    value.append(decayed_lr)
                else:
                    value.append(min_lr)
                if i > 0:
                    boundary.append(i)
            return optimizer.lr.PiecewiseDecay(boundary, value)

        return optimizer.lr.CosineAnnealingDecay(
            base_lr, T_max=max_iters, eta_min=min_lr)
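

# The sketch below is not part of the original file; it is a hedged usage
# example of CosineDecay on its own (i.e. `use_warmup=False`, so no warmup
# boundaries/values are passed in). The numbers are illustrative only.
def _example_cosine_decay():
    sched = CosineDecay(max_epochs=300, use_warmup=False, min_lr_ratio=0.05)
    # With no warmup and no plateau, __call__ falls through to a plain
    # CosineAnnealingDecay over 300 * 100 iterations, annealing the LR from
    # 0.01 down to 0.01 * 0.05.
    return sched(base_lr=0.01, step_per_epoch=100)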


@serializable
class PiecewiseDecay(object):
    """
    Multi-step learning rate decay.

    Args:
        gamma (float | list): decay factor(s)
        milestones (list): epochs at which to decay the learning rate
    """

    def __init__(self,
                 gamma=[0.1, 0.01],
                 milestones=[8, 11],
                 values=None,
                 use_warmup=True):
        super(PiecewiseDecay, self).__init__()
        if type(gamma) is not list:
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None and self.use_warmup:
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])
        else:
            # do not use LinearWarmup
            boundary = [int(step_per_epoch) * i for i in self.milestones]
            value = [base_lr]  # the LR is base_lr during steps [0, boundary[0])

        # self.values is set directly in the config
        if self.values is not None:
            assert len(self.milestones) + 1 == len(self.values)
            return optimizer.lr.PiecewiseDecay(boundary, self.values)

        # otherwise the values are computed from self.gamma
        value = value if value is not None else [base_lr]
        for i in self.gamma:
            value.append(base_lr * i)
        return optimizer.lr.PiecewiseDecay(boundary, value)
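

# Hedged usage sketch (not part of the original file): a pure step schedule
# without warmup. Boundaries come from `milestones * step_per_epoch` and the
# values from `base_lr * gamma[i]`. Numbers are illustrative.
def _example_piecewise_decay():
    sched = PiecewiseDecay(gamma=[0.1, 0.01], milestones=[8, 11],
                           use_warmup=False)
    # Boundaries become [4000, 5500]; the LR is 0.02 before iter 4000,
    # 0.002 until iter 5500, and 0.0002 afterwards.
    return sched(base_lr=0.02, step_per_epoch=500)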


@serializable
class LinearWarmup(object):
    """
    Warm up the learning rate linearly.

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
        epochs (int|None): use epochs as warm up steps; `epochs` takes
            priority over `steps`. Default: None.
    """

    def __init__(self, steps=500, start_factor=1. / 3, epochs=None):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor
        self.epochs = epochs

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        warmup_steps = self.epochs * step_per_epoch \
            if self.epochs is not None else self.steps
        for i in range(warmup_steps + 1):
            if warmup_steps > 0:
                alpha = i / warmup_steps
                factor = self.start_factor * (1 - alpha) + alpha
                lr = base_lr * factor
                value.append(lr)
                if i > 0:
                    boundary.append(i)
        return boundary, value
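

# Hedged sketch (not part of the original file): the warmup callables return
# raw (boundary, value) lists rather than a scheduler object; a decay
# scheduler such as PiecewiseDecay then extends them. Numbers are illustrative.
def _example_linear_warmup():
    warmup = LinearWarmup(steps=4, start_factor=0.25)
    boundary, value = warmup(base_lr=0.1, step_per_epoch=1000)
    # value == [0.025, 0.04375, 0.0625, 0.08125, 0.1]
    # boundary == [1, 2, 3, 4]
    return boundary, value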


@serializable
class BurninWarmup(object):
    """
    Warm up the learning rate in burn-in mode.

    Args:
        steps (int): warm up steps
    """

    def __init__(self, steps=1000):
        super(BurninWarmup, self).__init__()
        self.steps = steps

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        burnin = min(self.steps, step_per_epoch)
        for i in range(burnin + 1):
            factor = (i * 1.0 / burnin)**4
            lr = base_lr * factor
            value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value


@serializable
class ExpWarmup(object):
    """
    Warm up the learning rate in exponential mode.

    Args:
        steps (int): warm up steps.
        epochs (int|None): use epochs as warm up steps; `epochs` takes
            priority over `steps`. Default: None.
    """

    def __init__(self, steps=5, epochs=None):
        super(ExpWarmup, self).__init__()
        self.steps = steps
        self.epochs = epochs

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps
        for i in range(warmup_steps + 1):
            factor = (i / float(warmup_steps))**2
            value.append(base_lr * factor)
            if i > 0:
                boundary.append(i)
        return boundary, value
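

# Hedged sketch (not part of the original file): ExpWarmup ramps the LR with a
# quadratic factor (i / warmup_steps) ** 2 rather than linearly. Illustrative
# numbers only.
def _example_exp_warmup():
    warmup = ExpWarmup(steps=4)
    boundary, value = warmup(base_lr=0.1, step_per_epoch=1000)
    # value == [0.0, 0.00625, 0.025, 0.05625, 0.1]
    return boundary, value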


@register
class LearningRate(object):
    """
    Learning rate configuration.

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = []

        schedulers = copy.deepcopy(schedulers)
        for sched in schedulers:
            if isinstance(sched, dict):
                # support instantiating a scheduler from a dict config
                module = sys.modules[__name__]
                type = sched.pop("name")
                scheduler = getattr(module, type)(**sched)
                self.schedulers.append(scheduler)
            else:
                self.schedulers.append(sched)

    def __call__(self, step_per_epoch):
        assert len(self.schedulers) >= 1
        if not self.schedulers[0].use_warmup:
            return self.schedulers[0](base_lr=self.base_lr,
                                      step_per_epoch=step_per_epoch)

        # TODO: split warmup & decay
        # warmup
        boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
        # decay
        decay_lr = self.schedulers[0](self.base_lr, boundary, value,
                                      step_per_epoch)
        return decay_lr
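

# Hedged sketch (not part of the original file): how LearningRate stitches the
# default LinearWarmup and PiecewiseDecay together. `step_per_epoch` would
# normally be the length of the training dataloader; 1000 is illustrative.
def _example_learning_rate():
    lr_cfg = LearningRate(base_lr=0.01,
                          schedulers=[PiecewiseDecay(), LinearWarmup()])
    # Returns a paddle.optimizer.lr.PiecewiseDecay whose first 500 values ramp
    # up linearly and whose later boundaries fall at epochs 8 and 11.
    return lr_cfg(step_per_epoch=1000)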


@register
class OptimizerBuilder():
    """
    Build optimizer handles.

    Args:
        regularizer (object): a `Regularizer` instance
        optimizer (object): an `Optimizer` instance
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, model=None):
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        else:
            grad_clip = None

        if self.regularizer and self.regularizer != 'None':
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']
        if optim_type != 'AdamW':
            optim_args['weight_decay'] = regularization
        op = getattr(optimizer, optim_type)

        if 'param_groups' in optim_args:
            assert isinstance(optim_args['param_groups'],
                              list), 'param_groups must be a list'
            param_groups = optim_args.pop('param_groups')

            params, visited = [], []
            for group in param_groups:
                assert isinstance(group, dict) and 'params' in group and \
                    isinstance(group['params'], list), \
                    'each group must be a dict with a `params` list of name patterns'
                _params = {
                    n: p
                    for n, p in model.named_parameters()
                    if any([k in n for k in group['params']])
                }
                _group = group.copy()
                _group.update({'params': list(_params.values())})
                params.append(_group)
                visited.extend(list(_params.keys()))

            ext_params = [
                p for n, p in model.named_parameters() if n not in visited
            ]
            if len(ext_params) < len(model.parameters()):
                params.append({'params': ext_params})
            elif len(ext_params) > len(model.parameters()):
                raise RuntimeError
        else:
            params = model.parameters()

        return op(learning_rate=learning_rate,
                  parameters=params,
                  grad_clip=grad_clip,
                  **optim_args)
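

# Hedged sketch (not part of the original file): wiring OptimizerBuilder with
# a model and a learning-rate scheduler. `model` is any paddle.nn.Layer and
# `lr_sched` could be the object returned by LearningRate(...)(step_per_epoch).
def _example_optimizer_builder(model, lr_sched):
    builder = OptimizerBuilder(
        clip_grad_by_norm=35.,
        regularizer={'type': 'L2', 'factor': 1e-4},
        optimizer={'type': 'Momentum', 'momentum': 0.9})
    # Produces paddle.optimizer.Momentum with L2Decay weight decay and
    # global-norm gradient clipping over model.parameters().
    return builder(lr_sched, model)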


class ModelEMA(object):
    """
    Exponential Moving Average (EMA) of model weights for deep neural networks.

    Args:
        model (nn.Layer): the detector model.
        decay (float): decay used for updating the EMA parameters.
            The EMA parameters are updated with the formula:
            `ema_param = decay * ema_param + (1 - decay) * cur_param`.
            Default: 0.9998.
        ema_decay_type (str): one of ['threshold', 'normal', 'exponential'];
            'threshold' by default.
        cycle_epoch (int): interval (in epochs) at which to reset the EMA
            parameters and step counter. Default: -1, which means never reset.
            Resetting acts as a mild regularizer; the value is chosen
            empirically and is mainly useful when the total number of
            training epochs is large.
    """

    def __init__(self,
                 model,
                 decay=0.9998,
                 ema_decay_type='threshold',
                 cycle_epoch=-1):
        self.step = 0
        self.epoch = 0
        self.decay = decay
        self.state_dict = dict()
        for k, v in model.state_dict().items():
            self.state_dict[k] = paddle.zeros_like(v)
        self.ema_decay_type = ema_decay_type
        self.cycle_epoch = cycle_epoch

        # keep weak references so the model can still be updated without
        # being passed explicitly to update()
        self._model_state = {
            k: weakref.ref(p)
            for k, p in model.state_dict().items()
        }

    def reset(self):
        self.step = 0
        self.epoch = 0
        for k, v in self.state_dict.items():
            self.state_dict[k] = paddle.zeros_like(v)

    def resume(self, state_dict, step=0):
        for k, v in state_dict.items():
            if k in self.state_dict:
                self.state_dict[k] = v
        self.step = step

    def update(self, model=None):
        if self.ema_decay_type == 'threshold':
            decay = min(self.decay, (1 + self.step) / (10 + self.step))
        elif self.ema_decay_type == 'exponential':
            decay = self.decay * (1 - math.exp(-(self.step + 1) / 2000))
        else:
            decay = self.decay
        self._decay = decay

        if model is not None:
            model_dict = model.state_dict()
        else:
            model_dict = {k: p() for k, p in self._model_state.items()}
            assert all(
                [v is not None for _, v in model_dict.items()]
            ), 'model weights were garbage-collected; pass `model` explicitly'

        for k, v in self.state_dict.items():
            v = decay * v + (1 - decay) * model_dict[k]
            v.stop_gradient = True
            self.state_dict[k] = v
        self.step += 1

    def apply(self):
        if self.step == 0:
            return self.state_dict
        state_dict = dict()
        for k, v in self.state_dict.items():
            if self.ema_decay_type != 'exponential':
                # bias correction for the zero-initialized EMA buffers
                v = v / (1 - self._decay**self.step)
            v.stop_gradient = True
            state_dict[k] = v
        self.epoch += 1
        if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
            self.reset()
        return state_dict
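

# Hedged sketch (not part of the original file): typical train-loop wiring for
# ModelEMA. `model` is any paddle.nn.Layer; `update()` is called once per
# optimizer step and `apply()` once per epoch to read out averaged weights.
def _example_model_ema(model):
    ema = ModelEMA(model, decay=0.9998, ema_decay_type='threshold')
    for _ in range(10):  # stand-in for the per-iteration training loop
        # ... forward / backward / optimizer.step() would run here ...
        ema.update(model)
    # Bias-corrected EMA weights, e.g. to load into an eval copy of the model.
    return ema.apply()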