lr_scheduler.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import math
  5. from functools import partial
  6. class LRScheduler:
  7. def __init__(self, name, lr, iters_per_epoch, total_epochs, **kwargs):
  8. """
  9. Supported lr schedulers: [cos, warmcos, multistep]
  10. Args:
  11. lr (float): learning rate.
  12. iters_per_peoch (int): number of iterations in one epoch.
  13. total_epochs (int): number of epochs in training.
  14. kwargs (dict):
  15. - cos: None
  16. - warmcos: [warmup_epochs, warmup_lr_start (default 1e-6)]
  17. - multistep: [milestones (epochs), gamma (default 0.1)]
  18. """
  19. self.lr = lr
  20. self.iters_per_epoch = iters_per_epoch
  21. self.total_epochs = total_epochs
  22. self.total_iters = iters_per_epoch * total_epochs
  23. self.__dict__.update(kwargs)
  24. self.lr_func = self._get_lr_func(name)
  25. def update_lr(self, iters):
  26. return self.lr_func(iters)
  27. def _get_lr_func(self, name):
  28. if name == "cos": # cosine lr schedule
  29. lr_func = partial(cos_lr, self.lr, self.total_iters)
  30. elif name == "warmcos":
  31. warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
  32. warmup_lr_start = getattr(self, "warmup_lr_start", 1e-6)
  33. lr_func = partial(
  34. warm_cos_lr,
  35. self.lr,
  36. self.total_iters,
  37. warmup_total_iters,
  38. warmup_lr_start,
  39. )
  40. elif name == "yoloxwarmcos":
  41. warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
  42. no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
  43. warmup_lr_start = getattr(self, "warmup_lr_start", 0)
  44. min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
  45. lr_func = partial(
  46. yolox_warm_cos_lr,
  47. self.lr,
  48. min_lr_ratio,
  49. self.total_iters,
  50. warmup_total_iters,
  51. warmup_lr_start,
  52. no_aug_iters,
  53. )
  54. elif name == "yoloxsemiwarmcos":
  55. warmup_lr_start = getattr(self, "warmup_lr_start", 0)
  56. min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
  57. warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
  58. no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
  59. normal_iters = self.iters_per_epoch * self.semi_epoch
  60. semi_iters = self.iters_per_epoch_semi * (
  61. self.total_epochs - self.semi_epoch - self.no_aug_epochs
  62. )
  63. lr_func = partial(
  64. yolox_semi_warm_cos_lr,
  65. self.lr,
  66. min_lr_ratio,
  67. warmup_lr_start,
  68. self.total_iters,
  69. normal_iters,
  70. no_aug_iters,
  71. warmup_total_iters,
  72. semi_iters,
  73. self.iters_per_epoch,
  74. self.iters_per_epoch_semi,
  75. )
  76. elif name == "multistep": # stepwise lr schedule
  77. milestones = [
  78. int(self.total_iters * milestone / self.total_epochs)
  79. for milestone in self.milestones
  80. ]
  81. gamma = getattr(self, "gamma", 0.1)
  82. lr_func = partial(multistep_lr, self.lr, milestones, gamma)
  83. else:
  84. raise ValueError("Scheduler version {} not supported.".format(name))
  85. return lr_func
  86. def cos_lr(lr, total_iters, iters):
  87. """Cosine learning rate"""
  88. lr *= 0.5 * (1.0 + math.cos(math.pi * iters / total_iters))
  89. return lr
  90. def warm_cos_lr(lr, total_iters, warmup_total_iters, warmup_lr_start, iters):
  91. """Cosine learning rate with warm up."""
  92. if iters <= warmup_total_iters:
  93. lr = (lr - warmup_lr_start) * iters / float(
  94. warmup_total_iters
  95. ) + warmup_lr_start
  96. else:
  97. lr *= 0.5 * (
  98. 1.0
  99. + math.cos(
  100. math.pi
  101. * (iters - warmup_total_iters)
  102. / (total_iters - warmup_total_iters)
  103. )
  104. )
  105. return lr
  106. def yolox_warm_cos_lr(
  107. lr,
  108. min_lr_ratio,
  109. total_iters,
  110. warmup_total_iters,
  111. warmup_lr_start,
  112. no_aug_iter,
  113. iters,
  114. ):
  115. """Cosine learning rate with warm up."""
  116. min_lr = lr * min_lr_ratio
  117. if iters <= warmup_total_iters:
  118. # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
  119. lr = (lr - warmup_lr_start) * pow(
  120. iters / float(warmup_total_iters), 2
  121. ) + warmup_lr_start
  122. elif iters >= total_iters - no_aug_iter:
  123. lr = min_lr
  124. else:
  125. lr = min_lr + 0.5 * (lr - min_lr) * (
  126. 1.0
  127. + math.cos(
  128. math.pi
  129. * (iters - warmup_total_iters)
  130. / (total_iters - warmup_total_iters - no_aug_iter)
  131. )
  132. )
  133. return lr
  134. def yolox_semi_warm_cos_lr(
  135. lr,
  136. min_lr_ratio,
  137. warmup_lr_start,
  138. total_iters,
  139. normal_iters,
  140. no_aug_iters,
  141. warmup_total_iters,
  142. semi_iters,
  143. iters_per_epoch,
  144. iters_per_epoch_semi,
  145. iters,
  146. ):
  147. """Cosine learning rate with warm up."""
  148. min_lr = lr * min_lr_ratio
  149. if iters <= warmup_total_iters:
  150. # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
  151. lr = (lr - warmup_lr_start) * pow(
  152. iters / float(warmup_total_iters), 2
  153. ) + warmup_lr_start
  154. elif iters >= normal_iters + semi_iters:
  155. lr = min_lr
  156. elif iters <= normal_iters:
  157. lr = min_lr + 0.5 * (lr - min_lr) * (
  158. 1.0
  159. + math.cos(
  160. math.pi
  161. * (iters - warmup_total_iters)
  162. / (total_iters - warmup_total_iters - no_aug_iters)
  163. )
  164. )
  165. else:
  166. lr = min_lr + 0.5 * (lr - min_lr) * (
  167. 1.0
  168. + math.cos(
  169. math.pi
  170. * (
  171. normal_iters
  172. - warmup_total_iters
  173. + (iters - normal_iters)
  174. * iters_per_epoch
  175. * 1.0
  176. / iters_per_epoch_semi
  177. )
  178. / (total_iters - warmup_total_iters - no_aug_iters)
  179. )
  180. )
  181. return lr
  182. def multistep_lr(lr, milestones, gamma, iters):
  183. """MultiStep learning rate"""
  184. for milestone in milestones:
  185. lr *= gamma if iters >= milestone else 1.0
  186. return lr