# fp16_optimizer.py

import torch
from apex.multi_tensor_apply import multi_tensor_applier


class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
    Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.
    Refer to the apex.fp16_utils documentation for more information.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        # optional arg to control dynamic loss scaling behavior
        # dynamic_loss_args={'scale_window': 500})
        # Usually, dynamic_loss_args is not necessary.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        print("\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*")
        print("To update, use updated optimizers with AMP.")
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common stuff here in case we need to add new fused optimizers later
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        self.fp16_groups = []  # model params
        self.fp32_groups = []  # master weights

        # iterate over param_groups
        for param_group in self.optimizer.param_groups:
            fp16_group = []
            fp32_group = []
            for p in param_group['params']:
                fp16_group.append(p)
                fp32_group.append(p.clone().float().detach())
            self.fp16_groups.append(fp16_group)
            self.fp32_groups.append(fp32_group)
            param_group['params'] = fp32_group

        if multi_tensor_applier.available:
            import amp_C
            self.overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        else:
            raise RuntimeError('FP16_Optimizer requires cuda extensions')

        # we may have a way of fusing dynamic scale. Do not support it for now.
        if dynamic_loss_scale:
            if dynamic_loss_args is not None:
                raise SystemError("Do not support dynamic loss scale args for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**16
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
            self.scale_window = 1000
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
        self.verbose = verbose

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default.
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        fp16_grads = []
        norm_groups = []
        skip = False

        for group in self.fp16_groups:
            fp16_grad = []
            for i, p in enumerate(group):
                fp16_grad.append(p.grad)
            fp16_grads.append(fp16_grad)

        # nan check
        self.overflow_buf.zero_()
        for fp16_grad in fp16_grads:
            if len(fp16_grad) > 0:
                norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
                                                             self.overflow_buf,
                                                             [fp16_grad], True)
                norm_groups.append(norm)
                if self.overflow_buf.item() != 0:
                    skip = True

        if skip:
            self._update_scale(skip)
            return

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=fp16_grads,
                            output_params=self.fp16_groups,
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        self._update_scale(False)
        return

    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad``
           attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            if skip:
                if self.verbose:
                    print("\nGrad overflow on iteration", self.cur_iter)
                    print("Using dynamic loss scale of", self.cur_scale)
                self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
                self.last_overflow_iter = self.cur_iter
            else:
                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor
        else:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using static loss scale of", self.cur_scale)
        self.cur_iter += 1
        return

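    # Note on the schedule implemented by _update_scale above: with the defaults set
    # in __init__ (cur_scale=2**16, scale_factor=2, scale_window=1000), the dynamic
    # loss scale is halved (but never drops below 1) on any iteration whose gradients
    # overflow, and doubled again after every 1000 consecutive overflow-free
    # iterations. With a static scale, an overflowing iteration is only reported and
    # skipped; the scale itself never changes.
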
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

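    # For instance, a hand-rolled learning-rate decay could go through the promoted
    # property (illustrative sketch only; the 0.1 factor is arbitrary):
    #
    #     for group in fp16_optimizer_instance.param_groups:
    #         group['lr'] *= 0.1
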
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups'] = self.fp32_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are
        # still out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and
        # device of their associated parameters, because it's possible those buffers might not
        # exist yet in the current optimizer instance. In our case, as long as the current
        # FP16_Optimizer has been constructed in the same way as the one whose state_dict we are
        # loading, the same master params are guaranteed to exist, so we can just copy_() from
        # the saved master params.
        for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):
            for _current, _saved in zip(current, saved):
                _current.data.copy_(_saved.data)

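
# Illustrative usage sketch, mirroring the training loop implied by the class docstring
# above. It assumes a CUDA device and an apex build that provides
# apex.contrib.optimizers.FusedSGD plus the amp_C extension; the layer sizes, learning
# rate, iteration count, and squared-error loss are arbitrary placeholders.
if __name__ == "__main__":
    import apex.contrib.optimizers

    N, D_in, D_out = 16, 64, 32
    model = torch.nn.Linear(D_in, D_out).cuda().half()
    optimizer = apex.contrib.optimizers.FusedSGD(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    for _ in range(10):
        x = torch.randn(N, D_in, device='cuda', dtype=torch.half)
        y = torch.randn(N, D_out, device='cuda', dtype=torch.half)

        optimizer.zero_grad()
        loss = (model(x).float() - y.float()).pow(2).mean()
        # loss.backward() becomes optimizer.backward(loss), which applies the loss scale
        # before backprop so fp16 gradients stay representable.
        optimizer.backward(loss)
        # step() computes grad norms, skips the update on overflow, and otherwise lets the
        # wrapped fused optimizer unscale the grads and update the fp32 master weights.
        optimizer.step()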