# test_fused_novograd.py

import torch
from torch.optim import Optimizer
import math
import apex
import unittest

from test_fused_optimizer import TestFusedOptimizer
from itertools import product


class Novograd(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): whether to scale the current gradient
            contribution to the first moment by (1 - beta1) (default: False)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """
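
    # The per-tensor update performed by step() below, restated in equation
    # form: with g the gradient, v a scalar second-moment accumulator, and m
    # the first moment,
    #     v <- beta2 * v + (1 - beta2) * ||g||^2
    #         (v is seeded with ||g||^2 on the first step)
    #     g <- g / (sqrt(v) + eps) + weight_decay * p
    #     m <- beta1 * m + g
    #         (g is scaled by (1 - beta1) if grad_averaging)
    #     p <- p - lr * m
    # With amsgrad=True, a running maximum of v is used in the denominator
    # instead of v itself.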

    def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)
        super(Novograd, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Novograd, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    # (a scalar per parameter tensor, not per element)
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Squared L2 norm of the whole gradient tensor
                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Normalize the gradient, then apply weight decay and optional
                # gradient averaging before updating the first moment
                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(exp_avg, alpha=-group['lr'])

        return loss
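

# A minimal, illustrative sketch (not part of the original test) of how the
# reference Novograd above is typically driven. The model, data, and function
# name are made up for illustration; the tests below never call this.
def _novograd_usage_sketch():
    model = torch.nn.Linear(16, 4)
    optimizer = Novograd(model.parameters(), lr=1e-3, betas=(0.95, 0))
    inputs = torch.randn(8, 16)
    targets = torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()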


class TestFusedNovoGrad(TestFusedOptimizer):

    def __init__(self, *args, **kwargs):
        super(TestFusedNovoGrad, self).__init__(*args, **kwargs)

        # The options for NovoGrad and FusedNovoGrad are very specific if they
        # are expected to behave the same.
        self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
                        'weight_decay':0, 'grad_averaging':False, 'amsgrad':False}

        self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
                            'weight_decay':0, 'grad_averaging':False, 'amsgrad':False,
                            'bias_correction':False, 'reg_inside_moment':True,
                            'norm_type':2, 'init_zero':False, 'set_grad_none':True}
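
        # Presumed intent of the FusedNovoGrad-only flags above (the original
        # test does not spell this out): bias_correction=False and
        # init_zero=False are assumed to mirror the reference, which applies no
        # bias correction and seeds the second moment from the first squared
        # norm; norm_type=2 is assumed to select the L2 norm the reference
        # computes; reg_inside_moment=True and set_grad_none=True are treated
        # here as FusedNovoGrad-specific settings.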

        self.ref_optim = Novograd
        self.fused_optim = apex.optimizers.FusedNovoGrad

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:1", "cuda:0")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                torch.cuda.synchronize()
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]

        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, self.options, self.tst_options
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)


if __name__ == '__main__':
    unittest.main()
  137. unittest.main()