.. role:: hidden
    :class: hidden-section

apex.amp
===================================

This page documents the updated API for Amp (Automatic Mixed Precision),
a tool to enable Tensor Core-accelerated training in only 3 lines of Python.

A `runnable, comprehensive Imagenet example`_ demonstrating good practices can be found
on the GitHub page.

GANs are a tricky case that many people have asked about. A `comprehensive DCGAN example`_
is under construction.

If you have already implemented Amp based on the instructions below but it isn't behaving as expected,
please review `Advanced Amp Usage`_ to see if any topics match your use case. If that doesn't help,
`file an issue`_.

.. _`file an issue`:
    https://github.com/NVIDIA/apex/issues

``opt_level``\ s and Properties
-------------------------------

Amp allows users to easily experiment with different pure and mixed precision modes.
Commonly used default modes are chosen by
selecting an "optimization level" or ``opt_level``; each ``opt_level`` establishes a set of
properties that govern Amp's implementation of pure or mixed precision training.
Finer-grained control of how a given ``opt_level`` behaves can be achieved by passing values for
particular properties directly to ``amp.initialize``. These manually specified values
override the defaults established by the ``opt_level``.

Example::

    import torch
    from apex import amp

    # Declare model and optimizer as usual, with default (FP32) precision
    D_in, D_out = 1024, 1024  # example layer sizes
    model = torch.nn.Linear(D_in, D_out).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Allow Amp to perform casts as required by the opt_level
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    ...
    # loss.backward() becomes:
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    ...

Users **should not** manually cast their model or data to ``.half()``, regardless of which ``opt_level``
or properties are chosen. Amp intends that users start with an existing default (FP32) script,
add the three lines corresponding to the Amp API, and begin training with mixed precision.
Amp can also be disabled, in which case the original script will behave exactly as it used to.
In this way, there's no risk in adhering to the Amp API, and a lot of potential performance benefit.

.. note::
    Because it's never necessary to manually cast your model (aside from the call to ``amp.initialize``)
    or input data, a script that adheres to the new API
    can switch between different ``opt_level``\ s without having to make any other changes.

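Since only the ``opt_level`` (and any overrides) changes between runs, one convenient pattern is to expose it as a command-line flag. The following is a minimal sketch under stated assumptions: the flag names and the helpers ``parse_loss_scale``, ``parse_bool``, ``build_parser``, and ``amp_kwargs_from_args`` are hypothetical, and the sketch only builds the keyword arguments for ``amp.initialize``. The call itself is left as a comment because it requires Apex and a CUDA device.

.. code-block:: python

    import argparse


    def parse_loss_scale(value):
        """Accept the string "dynamic" or a number for --loss-scale."""
        return value if value == "dynamic" else float(value)


    def parse_bool(value):
        """Accept the strings "True" or "False" for boolean overrides."""
        if value not in ("True", "False"):
            raise argparse.ArgumentTypeError("expected True or False")
        return value == "True"


    def build_parser():
        parser = argparse.ArgumentParser(description="Select an Amp opt_level")
        parser.add_argument("--opt-level", default="O1",
                            choices=["O0", "O1", "O2", "O3"])
        # Optional overrides; None means "keep the opt_level default".
        parser.add_argument("--keep-batchnorm-fp32", type=parse_bool, default=None)
        parser.add_argument("--loss-scale", type=parse_loss_scale, default=None)
        return parser


    def amp_kwargs_from_args(args):
        """Collect keyword arguments for amp.initialize, skipping unset overrides."""
        kwargs = {"opt_level": args.opt_level}
        if args.keep_batchnorm_fp32 is not None:
            kwargs["keep_batchnorm_fp32"] = args.keep_batchnorm_fp32
        if args.loss_scale is not None:
            kwargs["loss_scale"] = args.loss_scale
        return kwargs


    # A real script would call parse_args() with no arguments to read sys.argv.
    args = build_parser().parse_args(["--opt-level", "O2", "--loss-scale", "128"])
    amp_kwargs = amp_kwargs_from_args(args)
    print(amp_kwargs)  # {'opt_level': 'O2', 'loss_scale': 128.0}
    # With Apex installed and a CUDA model and optimizer defined:
    # model, optimizer = amp.initialize(model, optimizer, **amp_kwargs)

Leaving an override unset keeps it out of the keyword arguments entirely, so the defaults of the chosen ``opt_level`` apply unless the user asks otherwise.
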
.. _`runnable, comprehensive Imagenet example`:
    https://github.com/NVIDIA/apex/tree/master/examples/imagenet

.. _`comprehensive DCGAN example`:
    https://github.com/NVIDIA/apex/tree/master/examples/dcgan

.. _`Advanced Amp Usage`:
    https://nvidia.github.io/apex/advanced.html

Properties
**********

Currently, the under-the-hood properties that govern pure or mixed precision training are the following:

- ``cast_model_type``: Casts your model's parameters and buffers to the desired type.
- ``patch_torch_functions``: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32.
- ``keep_batchnorm_fp32``: To enhance precision and enable cuDNN batchnorm (which improves performance), it's often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
- ``master_weights``: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
- ``loss_scale``: If ``loss_scale`` is a float value, use this value as the static (fixed) loss scale. If ``loss_scale`` is the string ``"dynamic"``, adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically.

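To see why a loss scale is needed at all, the rounding can be checked numerically without a GPU. The sketch below assumes nothing about Amp's internals: it uses the ``'e'`` (IEEE 754 half precision) format of Python's ``struct`` module to round values to FP16, and the gradient value and the scale of ``2.0 ** 16`` are illustrative choices.

.. code-block:: python

    import struct


    def to_fp16(x):
        """Round a Python float to the nearest IEEE 754 half-precision value."""
        return struct.unpack("<e", struct.pack("<e", x))[0]


    grad = 1e-8             # a small gradient value (illustrative)
    loss_scale = 2.0 ** 16  # a static loss scale (illustrative)

    # Unscaled: the value is below half of the smallest FP16 subnormal
    # (2**-24, about 6e-8), so it rounds to zero and the update is lost.
    unscaled = to_fp16(grad)

    # Multiplying the loss by loss_scale multiplies every gradient by the same
    # factor (chain rule), so the FP16 backward pass sees grad * loss_scale.
    scaled = to_fp16(grad * loss_scale)

    # Unscaling before the optimizer step recovers a nonzero value near grad.
    recovered = scaled / loss_scale

    print(unscaled)  # 0.0
    print(recovered)

A scale that is too large causes the opposite problem, with values overflowing to ``inf``, which is what ``loss_scale="dynamic"`` guards against by adjusting the scale over time.
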
Again, you often don't need to specify these properties by hand. Instead, select an ``opt_level``,
which will set them up for you. After selecting an ``opt_level``, you can optionally pass property
kwargs as manual overrides.

If you attempt to override a property that does not make sense for the selected ``opt_level``,
Amp will raise an error with an explanation. For example, selecting ``opt_level="O1"`` combined with
the override ``master_weights=True`` does not make sense. ``O1`` inserts casts
around Torch functions rather than model weights. Data, activations, and weights are recast
out-of-place on the fly as they flow through patched functions. Therefore, the model weights themselves
can (and should) remain FP32, and there is no need to maintain separate FP32 master weights.

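The override rule can be pictured as a dictionary merge: start from the defaults of the chosen ``opt_level`` (the values listed in the ``opt_level`` sections of this page) and replace any property passed explicitly. This is a hypothetical sketch, not Apex's implementation: ``OPT_LEVEL_DEFAULTS`` and ``resolve_properties`` are illustrative names, dtypes are written as strings so the example runs without PyTorch, and only the single invalid combination named above is checked.

.. code-block:: python

    OPT_LEVEL_DEFAULTS = {
        "O0": dict(cast_model_type="torch.float32", patch_torch_functions=False,
                   keep_batchnorm_fp32=None, master_weights=False, loss_scale=1.0),
        "O1": dict(cast_model_type=None, patch_torch_functions=True,
                   keep_batchnorm_fp32=None, master_weights=None, loss_scale="dynamic"),
        "O2": dict(cast_model_type="torch.float16", patch_torch_functions=False,
                   keep_batchnorm_fp32=True, master_weights=True, loss_scale="dynamic"),
        "O3": dict(cast_model_type="torch.float16", patch_torch_functions=False,
                   keep_batchnorm_fp32=False, master_weights=False, loss_scale=1.0),
    }


    def resolve_properties(opt_level, **overrides):
        """Merge user overrides into the opt_level defaults (hypothetical sketch)."""
        if opt_level not in OPT_LEVEL_DEFAULTS:
            raise ValueError(f"Unrecognized opt_level {opt_level!r}")
        unknown = set(overrides) - set(OPT_LEVEL_DEFAULTS[opt_level])
        if unknown:
            raise ValueError(f"Unknown properties: {sorted(unknown)}")
        # O1 keeps model weights in FP32, so separate FP32 master weights make no sense.
        if opt_level == "O1" and overrides.get("master_weights"):
            raise ValueError("O1 keeps model weights in FP32; "
                             "master_weights=True is not applicable")
        props = dict(OPT_LEVEL_DEFAULTS[opt_level])  # copy so the defaults stay untouched
        props.update(overrides)                      # explicitly passed values win
        return props


    print(resolve_properties("O3", keep_batchnorm_fp32=True))

The ``O3`` call mirrors the "speed of light" configuration described for ``O3``: every other property keeps its ``O3`` default, and only ``keep_batchnorm_fp32`` changes.
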
``opt_level``\ s
****************

Recognized ``opt_level``\ s are ``"O0"``, ``"O1"``, ``"O2"``, and ``"O3"``.

``O0`` and ``O3`` are not true mixed precision, but they are useful for establishing accuracy and
speed baselines, respectively.

``O1`` and ``O2`` are different implementations of mixed precision. Try both, and see
what gives the best speedup and accuracy for your model.

``O0``: FP32 training
^^^^^^^^^^^^^^^^^^^^^^

Your incoming model should be FP32 already, so this is likely a no-op.
``O0`` can be useful to establish an accuracy baseline.

| Default properties set by ``O0``:
| ``cast_model_type=torch.float32``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=None`` (effectively, "not applicable," everything is FP32)
| ``master_weights=False``
| ``loss_scale=1.0``
|
|

``O1``: Mixed Precision (recommended for typical use)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Patch all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist
model. Whitelist ops (for example, Tensor Core-friendly ops like GEMMs and convolutions) are performed
in FP16. Blacklist ops that benefit from FP32 precision (for example, softmax)
are performed in FP32. ``O1`` also uses dynamic loss scaling, unless overridden.

| Default properties set by ``O1``:
| ``cast_model_type=None`` (not applicable)
| ``patch_torch_functions=True``
| ``keep_batchnorm_fp32=None`` (again, not applicable; all model weights remain FP32)
| ``master_weights=None`` (not applicable; model weights remain FP32)
| ``loss_scale="dynamic"``
|
|

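The whitelist-blacklist decision that ``O1`` makes for each patched call can be sketched as a lookup. This is a simplified illustration under stated assumptions: ``FP16_OPS``, ``FP32_OPS``, and ``cast_inputs`` are hypothetical names, the op lists are short stand-ins for Amp's real lists, dtypes are modeled as strings so the example runs without PyTorch, and ops on neither list are simply left unchanged here.

.. code-block:: python

    # Hypothetical, abbreviated op lists; Amp's real lists are longer.
    FP16_OPS = {"linear", "matmul", "conv2d"}  # whitelist: Tensor Core-friendly ops
    FP32_OPS = {"softmax", "log_softmax"}      # blacklist: ops that benefit from FP32


    def cast_inputs(op_name, input_dtypes):
        """Return the dtypes an O1-style patched op would run with (sketch)."""
        if op_name in FP16_OPS:
            return ["float16"] * len(input_dtypes)
        if op_name in FP32_OPS:
            return ["float32"] * len(input_dtypes)
        return list(input_dtypes)  # simplification: other ops keep their inputs


    print(cast_inputs("linear", ["float32", "float32"]))  # ['float16', 'float16']
    print(cast_inputs("softmax", ["float16"]))            # ['float32']

Because the casts happen at call time, the model's own parameters never change dtype, which is why ``O1`` leaves ``master_weights`` and ``keep_batchnorm_fp32`` as "not applicable".
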
``O2``: "Almost FP16" Mixed Precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``O2`` casts the model weights to FP16,
patches the model's ``forward`` method to cast input
data to FP16, keeps batchnorms in FP32, maintains FP32 master weights,
updates the optimizer's ``param_groups`` so that ``optimizer.step()``
acts directly on the FP32 master weights (followed by copies from the FP32 master weights
to the FP16 model weights, if necessary),
and implements dynamic loss scaling (unless overridden).
Unlike ``O1``, ``O2`` does not patch Torch functions or Tensor methods.

| Default properties set by ``O2``:
| ``cast_model_type=torch.float16``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=True``
| ``master_weights=True``
| ``loss_scale="dynamic"``
|
|

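Why ``O2`` keeps FP32 master weights can be seen with a toy scalar weight. In the sketch below (plain Python, using the ``'e'`` and ``'f'`` formats of the ``struct`` module to round to FP16 and FP32), an update of ``1e-4`` applied to a weight of ``1.0`` is lost entirely when the weight is stored and stepped in FP16, but accumulates in an FP32 master copy that is then copied back to the FP16 model weight. The weight, update size, and step count are illustrative assumptions, not values Amp uses.

.. code-block:: python

    import struct


    def to_fp16(x):
        """Round to the nearest IEEE 754 half-precision value."""
        return struct.unpack("<e", struct.pack("<e", x))[0]


    def to_fp32(x):
        """Round to the nearest IEEE 754 single-precision value."""
        return struct.unpack("<f", struct.pack("<f", x))[0]


    update = 1e-4  # per-step update, e.g. lr * grad (illustrative)
    steps = 10

    # Stepping FP16 weights directly: FP16 values near 1.0 are spaced 2**-10
    # (about 9.8e-4) apart, so 1.0 + 1e-4 rounds back to 1.0 on every step.
    fp16_only = to_fp16(1.0)
    for _ in range(steps):
        fp16_only = to_fp16(fp16_only + update)

    # O2-style: the optimizer steps an FP32 master copy, and the FP16 model
    # weight is refreshed from the master copy after each step.
    master = to_fp32(1.0)
    model_weight = to_fp16(master)
    for _ in range(steps):
        master = to_fp32(master + update)
        model_weight = to_fp16(master)

    print(fp16_only)     # 1.0
    print(model_weight)  # 1.0009765625

The FP16 model weight still moves in coarse steps, but it tracks the accumulated FP32 value instead of discarding every small update.
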
``O3``: FP16 training
^^^^^^^^^^^^^^^^^^^^^^

``O3`` may not achieve the stability of the true mixed precision options ``O1`` and ``O2``.
However, it can be useful to establish a speed baseline for your model, against which
the performance of ``O1`` and ``O2`` can be compared. If your model uses batch normalization,
to establish "speed of light" you can try ``O3`` with the additional property override
``keep_batchnorm_fp32=True`` (which enables cuDNN batchnorm, as stated earlier).

| Default properties set by ``O3``:
| ``cast_model_type=torch.float16``
| ``patch_torch_functions=False``
| ``keep_batchnorm_fp32=False``
| ``master_weights=False``
| ``loss_scale=1.0``
|
|

Unified API
-----------

.. automodule:: apex.amp
.. currentmodule:: apex.amp

.. autofunction:: initialize

.. autofunction:: scale_loss

.. autofunction:: master_params

Checkpointing
-------------

To properly save and load your Amp training state, we introduce ``amp.state_dict()``, which contains all ``loss_scaler``\ s and their corresponding unskipped steps, as well as ``amp.load_state_dict()`` to restore these attributes.

To resume training with bitwise accuracy, we recommend the following workflow::

    import torch
    from apex import amp

    # Initialization
    opt_level = 'O1'
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    # Train your model
    ...

    # Save checkpoint
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict(),
    }
    torch.save(checkpoint, 'amp_checkpoint.pt')
    ...

    # Restore
    model = ...
    optimizer = ...
    checkpoint = torch.load('amp_checkpoint.pt')

    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    amp.load_state_dict(checkpoint['amp'])

    # Continue training
    ...

We recommend restoring the model with the same ``opt_level`` that was used when the checkpoint was saved, and calling the ``load_state_dict`` methods after ``amp.initialize``.

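Saving the Amp state matters because a dynamic loss scaler carries state between iterations: the current scale and the number of consecutive unskipped steps. The class below is a hypothetical, simplified scaler, not Apex's implementation; the initial scale of ``2.0 ** 16``, the factor of 2, and the growth interval are assumptions chosen for illustration. It halves the scale when a gradient overflows, doubles it after a run of clean steps, and round-trips its state through a plain dictionary.

.. code-block:: python

    import copy
    import math


    class DynamicLossScaler:
        """Simplified dynamic loss scaler (hypothetical sketch, not Apex's class)."""

        def __init__(self, init_scale=2.0 ** 16, factor=2.0, growth_interval=2000):
            self.scale = init_scale
            self.factor = factor
            self.growth_interval = growth_interval
            self.unskipped = 0  # consecutive steps without overflow

        def update(self, grads):
            """Adjust the scale after a backward pass; return True if the step is skipped."""
            if not all(math.isfinite(g) for g in grads):
                self.scale /= self.factor  # overflow: shrink the scale, skip the step
                self.unskipped = 0
                return True
            self.unskipped += 1
            if self.unskipped == self.growth_interval:
                self.scale *= self.factor  # long run of clean steps: try a larger scale
                self.unskipped = 0
            return False

        def state_dict(self):
            return {"scale": self.scale, "unskipped": self.unskipped}

        def load_state_dict(self, state):
            self.scale = state["scale"]
            self.unskipped = state["unskipped"]


    scaler = DynamicLossScaler(growth_interval=3)
    scaler.update([float("inf")])  # overflow: scale 65536 -> 32768
    scaler.update([0.5])
    scaler.update([0.25])          # two clean steps so far
    saved = copy.deepcopy(scaler.state_dict())

    restored = DynamicLossScaler(growth_interval=3)
    restored.load_state_dict(saved)
    restored.update([0.1])         # third clean step: scale grows back to 65536
    print(restored.scale, restored.unskipped)  # 65536.0 0

A fresh scaler would restart at the initial scale with a zero counter, so a resumed run would follow a different loss-scale trajectory; restoring the saved state lets it continue exactly where the original run stopped.
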
Advanced use cases
------------------

The unified Amp API supports gradient accumulation across iterations,
multiple backward passes per iteration, multiple models/optimizers,
custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also
require special treatment, but this treatment does not need to change
for different ``opt_level``\ s. Further details can be found here:

.. toctree::
   :maxdepth: 1

   advanced

Transition guide for old API users
----------------------------------

We strongly encourage moving to the new Amp API, because it's more versatile, easier to use, and future-proof. The original :class:`FP16_Optimizer` and the old "Amp" API are deprecated, and subject to removal at any time.

For users of the old "Amp" API
******************************

In the new API, ``opt_level="O1"`` performs the same patching of the Torch namespace as the old API
called "Amp."
However, the new API allows static or dynamic loss scaling, while the old API only allowed dynamic loss scaling.

In the new API, the old call to ``amp_handle = amp.init()``, and the returned ``amp_handle``, are no
longer exposed or necessary. The new ``amp.initialize()`` does the duty of ``amp.init()`` (and more).
Therefore, any existing calls to ``amp_handle = amp.init()`` should be deleted.

The functions formerly exposed through ``amp_handle`` are now free
functions accessible through the ``amp`` module.

The backward context manager must be changed accordingly::

    # old API
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    ->
    # new API
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

For now, the deprecated "Amp" API documentation can still be found on the GitHub README: https://github.com/NVIDIA/apex/tree/master/apex/amp. The old API calls that `annotate user functions`_ to run
with a particular precision are still honored by the new API.

.. _`annotate user functions`:
    https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions

For users of the old FP16_Optimizer
***********************************

``opt_level="O2"`` is equivalent to :class:`FP16_Optimizer` with ``dynamic_loss_scale=True``.
Once again, the backward pass must be changed to the unified version::

    # old API
    optimizer.backward(loss)
    ->
    # new API
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

One annoying aspect of FP16_Optimizer was that the user had to manually convert their model to half
(either by calling ``.half()`` on it, or using a function or module wrapper from
``apex.fp16_utils``), and also manually call ``.half()`` on input data. **Neither of these is
necessary in the new API. No matter which opt_level
you choose, you can and should simply build your model and pass input data in the default FP32 format.**
The new Amp API will perform the right conversions during
``model, optimizer = amp.initialize(model, optimizer, opt_level=....)`` based on the ``opt_level``
and any overridden flags. Floating point input data may be FP32 or FP16, but you may as well just
leave it as FP32, because the ``model`` returned by ``amp.initialize`` will have its ``forward``
method patched to cast the input data appropriately.