# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import bisect

import torch

from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler


def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Build the LR scheduler selected by config.TRAIN.LR_SCHEDULER.NAME.

    Epoch-based settings from the config are converted to optimizer-update counts,
    since every scheduler below is constructed with t_in_epochs=False.
    """
    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
    multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]

    lr_scheduler = None
    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,
            t_mul=1.,  # renamed to `cycle_mul` in newer timm releases
            lr_min=config.TRAIN.MIN_LR,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
            warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
        lr_scheduler = LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            milestones=multi_steps,
            gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    return lr_scheduler
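
# Usage sketch (illustration only, not part of the original file): because every
# scheduler above is built with t_in_epochs=False, the training loop is expected to
# advance it once per optimizer update via timm's `step_update`, roughly:
#
#     lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch)
#     for epoch in range(config.TRAIN.EPOCHS):
#         for idx, batch in enumerate(data_loader):
#             ...  # forward / backward / optimizer.step()
#             lr_scheduler.step_update(epoch * n_iter_per_epoch + idx)
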

class LinearLRScheduler(Scheduler):
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            # linear warmup from warmup_lr_init up to each group's base lr
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            # linear decay from the base lr down to base_lr * lr_min_rate at t_initial
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

class MultiStepLRScheduler(Scheduler):
    def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
        super().__init__(optimizer, param_group_field="lr")

        self.milestones = milestones
        self.gamma = gamma
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

        assert self.warmup_t <= min(self.milestones)

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            # multiply the base lr by gamma once for every milestone already passed
            lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
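

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file: exercise LinearLRScheduler on
    # a throwaway SGD optimizer. All numbers below (t_initial=100, warmup_t=10, ...)
    # are arbitrary illustration values, not Swin defaults.
    dummy_params = [torch.nn.Parameter(torch.zeros(1))]
    dummy_opt = torch.optim.SGD(dummy_params, lr=0.1)
    sched = LinearLRScheduler(
        dummy_opt,
        t_initial=100,
        lr_min_rate=0.01,
        warmup_t=10,
        warmup_lr_init=1e-4,
        t_in_epochs=False,
    )
    # Expect a linear warmup to lr=0.1 over the first 10 updates, then a linear
    # decay towards 0.1 * 0.01 = 0.001 at update 100.
    for step in (0, 5, 10, 50, 100):
        print(step, sched.get_update_values(step))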