test_lamb.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. import unittest
  2. import os
  3. import torch
  4. from torch.optim import Optimizer
  5. import apex
  6. from apex.multi_tensor_apply import multi_tensor_applier
  7. from itertools import product
  8. class RefLAMB(Optimizer):
  9. r"""Implements Lamb algorithm.
  10. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
  11. Arguments:
  12. params (iterable): iterable of parameters to optimize or dicts defining
  13. parameter groups
  14. lr (float, optional): learning rate (default: 1e-3)
  15. betas (Tuple[float, float], optional): coefficients used for computing
  16. running averages of gradient and its square (default: (0.9, 0.999))
  17. eps (float, optional): term added to the denominator to improve
  18. numerical stability (default: 1e-6)
  19. weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
  20. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
  21. https://arxiv.org/abs/1904.00962
  22. """
  23. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
  24. if not 0.0 <= lr:
  25. raise ValueError("Invalid learning rate: {}".format(lr))
  26. if not 0.0 <= eps:
  27. raise ValueError("Invalid epsilon value: {}".format(eps))
  28. if not 0.0 <= betas[0] < 1.0:
  29. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  30. if not 0.0 <= betas[1] < 1.0:
  31. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  32. defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
  33. super(RefLAMB, self).__init__(params, defaults)
  34. if multi_tensor_applier.available:
  35. import amp_C
  36. self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
  37. # Skip buffer
  38. self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
  39. self.multi_tensor_lamb = amp_C.multi_tensor_lamb
  40. else:
  41. raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
  42. def step(self, closure=None):
  43. """Performs a single optimization step.
  44. Arguments:
  45. closure (callable, optional): A closure that reevaluates the model
  46. and returns the loss.
  47. """
  48. loss = None
  49. if closure is not None:
  50. loss = closure()
  51. # create separate grad lists for fp32, fp16, and bf16 params
  52. g_all_32, g_all_16, g_all_bf16 = [], [], []
  53. for group in self.param_groups:
  54. for p in group['params']:
  55. if p.grad is None:
  56. continue
  57. if p.dtype == torch.float32:
  58. g_all_32.append(p.grad.data)
  59. elif p.dtype == torch.float16:
  60. g_all_16.append(p.grad.data)
  61. elif p.dtype == torch.bfloat16:
  62. g_all_bf16.append(p.grad.data)
  63. else:
  64. raise RuntimeError('FusedLAMB only support fp16, fp32, and bf16.')
  65. device = self.param_groups[0]["params"][0].device
  66. g_norm_32, g_norm_16, g_norm_bf16 = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  67. # compute grad norm for two lists
  68. if len(g_all_32) > 0:
  69. g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
  70. self._dummy_overflow_buf,
  71. [g_all_32], False)[0]
  72. if len(g_all_16) > 0:
  73. g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
  74. self._dummy_overflow_buf,
  75. [g_all_16], False)[0]
  76. if len(g_all_bf16) > 0:
  77. g_norm_bf16 = multi_tensor_applier(self.multi_tensor_l2norm,
  78. self._dummy_overflow_buf,
  79. [g_all_bf16], False)[0]
  80. # blend two grad norms to get global grad norm
  81. global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
  82. self._dummy_overflow_buf,
  83. [[g_norm_32, g_norm_16, g_norm_bf16]],
  84. False)[0]
  85. max_grad_norm = 1.0
  86. clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)
  87. for group in self.param_groups:
  88. for p in group['params']:
  89. if p.grad is None:
  90. continue
  91. p.grad.data *= clipped_ratio
  92. grad = p.grad.data
  93. if grad.is_sparse:
  94. raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
  95. state = self.state[p]
  96. # State initialization
  97. if len(state) == 0:
  98. state['step'] = 0
  99. # Exponential moving average of gradient values
  100. state['m'] = torch.zeros_like(p.data)
  101. # Exponential moving average of squared gradient values
  102. state['v'] = torch.zeros_like(p.data)
  103. m_t, v_t = state['m'], state['v']
  104. beta1, beta2 = group['betas']
  105. state['step'] += 1
  106. # m_t = beta1 * m + (1 - beta1) * g_t
  107. m_t.mul_(beta1).add_(grad, alpha=1-beta1)
  108. # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  109. if len(g_all_16) > 0:
  110. v_t.mul_(beta2)
  111. v_t = v_t.to(torch.float32)
  112. grad32 = grad.to(torch.float32)
  113. v_t.addcmul_(grad32, grad32, value=1-beta2)
  114. else:
  115. v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
  116. # Debiasing
  117. m_t_hat = m_t / (1.0 - beta1 ** state['step'])
  118. v_t_hat = v_t / (1.0 - beta2 ** state['step'])
  119. update = m_t_hat / v_t_hat.sqrt().add(group['eps'])
  120. if group['weight_decay'] != 0:
  121. update.add_(p.data, alpha=group['weight_decay'])
  122. trust_ratio = 1.0
  123. w_norm = p.data.to(torch.float32).pow(2).sum().sqrt()
  124. g_norm = update.pow(2).sum().sqrt()
  125. if w_norm > 0 and g_norm > 0:
  126. trust_ratio = w_norm / g_norm
  127. state['w_norm'] = w_norm
  128. state['g_norm'] = g_norm
  129. state['trust_ratio'] = trust_ratio
  130. step_size = group['lr']
  131. p.data.add_(update, alpha=-step_size*trust_ratio)
  132. return loss
  133. class TestLamb(unittest.TestCase):
  134. def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
  135. self.max_abs_diff = max_abs_diff
  136. self.max_rel_diff = max_rel_diff
  137. self.iters = iters
  138. torch.cuda.manual_seed(9876)
  139. def tearDown(self):
  140. pass
  141. def gen_param_optim(self, tensors, lamb_option):
  142. ref_param = []
  143. tst_param = []
  144. for tensor in tensors:
  145. ref_param.append(torch.nn.Parameter(tensor.clone()))
  146. tst_param.append(torch.nn.Parameter(tensor.clone()))
  147. ref_optim = self.ref_optim(ref_param, **lamb_option)
  148. tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option)
  149. return (ref_param, tst_param, ref_optim, tst_optim)
  150. def gen_grad(self, ref_param, tst_param):
  151. for p_ref, p_tst in zip(ref_param, tst_param):
  152. p_ref.grad = torch.rand_like(p_ref)
  153. p_tst.grad = p_ref.grad
  154. def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
  155. half_grads = []
  156. for p_ref, _ in zip(ref_param, tst_param):
  157. half_grads.append(torch.rand_like(p_ref).half())
  158. p_ref.grad = half_grads[-1].float() / scale
  159. return half_grads
  160. def get_max_diff(self, ref_param, tst_param):
  161. max_abs_diff = max_rel_diff = 0
  162. for p_ref, p_tst in zip(ref_param, tst_param):
  163. max_abs_diff_p = (p_ref - p_tst).abs().max().item()
  164. max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
  165. if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
  166. if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
  167. return max_abs_diff, max_rel_diff
  168. def gen_single_type_test(self, param_type=torch.float, device="cuda"):
  169. nelem = 18011
  170. tensor = torch.rand(nelem, dtype=param_type, device=device)
  171. weight_decay = [0, 0.01]
  172. for wd in weight_decay:
  173. lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
  174. ref_param, tst_param, ref_optim, tst_optim = \
  175. self.gen_param_optim([tensor], lamb_option)
  176. if isinstance(tst_optim, apex.optimizers.FusedMixedPrecisionLamb):
  177. if param_type != torch.float:
  178. # joseli: This parameter is usually passed into the constructor,
  179. # but I do not want to change the testing interface.
  180. # As long as this parameter is set before the first call to step(),
  181. # then it should act normally.
  182. tst_optim.reduced_precision_dtype = param_type
  183. for i in range(self.iters):
  184. self.gen_grad(ref_param, tst_param)
  185. ref_optim.step()
  186. torch.cuda.synchronize()
  187. tst_optim.step()
  188. torch.cuda.synchronize()
  189. torch.testing.assert_close(tst_param, ref_param)
  190. class TestFusedLAMB(TestLamb):
  191. def __init__(self, *args, **kwargs):
  192. super(TestLamb, self).__init__(*args, **kwargs)
  193. self.ref_optim = RefLAMB
  194. self.tst_optim = apex.optimizers.FusedLAMB
  195. def test_float(self):
  196. self.gen_single_type_test(param_type=torch.float)
  197. @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
  198. def test_half(self):
  199. self.gen_single_type_test(param_type=torch.float16)
  200. @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
  201. def test_multi_device(self):
  202. devices = ("cuda:0", "cuda:1")
  203. for current_dev, tensor_dev in product(devices, devices):
  204. with torch.cuda.device(current_dev):
  205. self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
  206. def test_multi_params(self):
  207. sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
  208. weight_decay = [0, 0.01]
  209. for wd in weight_decay:
  210. lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
  211. tensors = []
  212. for size in sizes:
  213. tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
  214. ref_param, tst_param, ref_optim, tst_optim = \
  215. self.gen_param_optim(tensors, lamb_option)
  216. for i in range(self.iters):
  217. self.gen_grad(ref_param, tst_param)
  218. ref_optim.step()
  219. tst_optim.step()
  220. max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
  221. self.assertLessEqual(max_abs_diff, self.max_abs_diff)
  222. self.assertLessEqual(max_rel_diff, self.max_rel_diff)
  223. def test_lamb_option(self):
  224. nelem = 1
  225. tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
  226. weight_decay = [0, 0.01]
  227. for wd in weight_decay:
  228. lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
  229. ref_param, tst_param, ref_optim, tst_optim = \
  230. self.gen_param_optim([tensor], lamb_option)
  231. for i in range(self.iters):
  232. self.gen_grad(ref_param, tst_param)
  233. ref_optim.step()
  234. tst_optim.step()
  235. max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
  236. self.assertLessEqual(max_abs_diff, self.max_abs_diff)
  237. self.assertLessEqual(max_rel_diff, self.max_rel_diff)
  238. class TestFusedMixedPrecisionLamb(TestLamb):
  239. def __init__(self, *args, **kwargs):
  240. super(TestLamb, self).__init__(*args, **kwargs)
  241. self.ref_optim = RefLAMB
  242. self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb
  243. def test_float(self):
  244. self.gen_single_type_test(param_type=torch.float)
  245. def test_bfloat16(self):
  246. self.iters = 4
  247. self.gen_single_type_test(param_type=torch.bfloat16)
  248. def test_half(self):
  249. self.iters = 1
  250. self.gen_single_type_test(param_type=torch.float16)
  251. @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
  252. def test_multi_device(self):
  253. devices = ("cuda:0", "cuda:1")
  254. for current_dev, tensor_dev in product(devices, devices):
  255. with torch.cuda.device(current_dev):
  256. self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
  257. def test_multi_params(self):
  258. sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
  259. weight_decay = [0, 0.01]
  260. for wd in weight_decay:
  261. lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
  262. tensors = []
  263. for size in sizes:
  264. tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
  265. ref_param, tst_param, ref_optim, tst_optim = \
  266. self.gen_param_optim(tensors, lamb_option)
  267. for i in range(self.iters):
  268. self.gen_grad(ref_param, tst_param)
  269. ref_optim.step()
  270. tst_optim.step()
  271. max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
  272. self.assertLessEqual(max_abs_diff, self.max_abs_diff)
  273. self.assertLessEqual(max_rel_diff, self.max_rel_diff)
  274. def test_lamb_option(self):
  275. nelem = 1
  276. tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
  277. weight_decay = [0, 0.01]
  278. for wd in weight_decay:
  279. lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
  280. ref_param, tst_param, ref_optim, tst_optim = \
  281. self.gen_param_optim([tensor], lamb_option)
  282. for i in range(self.iters):
  283. self.gen_grad(ref_param, tst_param)
  284. ref_optim.step()
  285. tst_optim.step()
  286. max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
  287. self.assertLessEqual(max_abs_diff, self.max_abs_diff)
  288. self.assertLessEqual(max_rel_diff, self.max_rel_diff)
  289. if __name__ == '__main__':
  290. script_path = os.path.dirname(os.path.realpath(__file__))
  291. unittest.main()