.. role:: hidden
    :class: hidden-section

Advanced Amp Usage
===================================
GANs
----

GANs are an interesting synthesis of several of the topics below.  A `comprehensive example`_
is under construction.

.. _`comprehensive example`:
    https://github.com/NVIDIA/apex/tree/master/examples/dcgan
  11. Gradient clipping
  12. -----------------
  13. Amp calls the params owned directly by the optimizer's ``param_groups`` the "master params."
  14. These master params may be fully or partially distinct from ``model.parameters()``.
  15. For example, with `opt_level="O2"`_, ``amp.initialize`` casts most model params to FP16,
  16. creates an FP32 master param outside the model for each newly-FP16 model param,
  17. and updates the optimizer's ``param_groups`` to point to these FP32 params.
  18. The master params owned by the optimizer's ``param_groups`` may also fully coincide with the
  19. model params, which is typically true for ``opt_level``\s ``O0``, ``O1``, and ``O3``.
  20. In all cases, correct practice is to clip the gradients of the params that are guaranteed to be
  21. owned **by the optimizer's** ``param_groups``, instead of those retrieved via ``model.parameters()``.
  22. Also, if Amp uses loss scaling, gradients must be clipped after they have been unscaled
  23. (which occurs during exit from the ``amp.scale_loss`` context manager).
The following pattern should be correct for any ``opt_level``::

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Gradients are unscaled during context manager exit.
    # Now it's safe to clip.  Replace
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # with
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
    # or
    torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), max_)
Note the use of the utility function ``amp.master_params(optimizer)``,
which returns a generator expression that iterates over the
params in the optimizer's ``param_groups``.

Also note that ``clip_grad_norm_(amp.master_params(optimizer), max_norm)`` is invoked
*instead of*, not *in addition to*, ``clip_grad_norm_(model.parameters(), max_norm)``.

.. _`opt_level="O2"`:
    https://nvidia.github.io/apex/amp.html#o2-fast-mixed-precision
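For intuition about what ``amp.master_params`` iterates over, here is a minimal plain-Python sketch (not Apex's actual implementation; ``FakeOptimizer`` is a made-up stand-in that only mimics the ``param_groups`` layout of a torch optimizer):

```python
def master_params(optimizer):
    # Yield every param in every entry of the optimizer's param_groups,
    # in order.  Under O2 these are the FP32 master copies.
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

class FakeOptimizer:
    """Stand-in with the same param_groups layout as a torch optimizer."""
    def __init__(self, groups):
        self.param_groups = groups

opt = FakeOptimizer([{"params": ["w0", "w1"]}, {"params": ["b0"]}])
print(list(master_params(opt)))  # ['w0', 'w1', 'b0']
```

Passing this generator to ``clip_grad_norm_`` therefore clips exactly the params the optimizer will step, whether or not they coincide with ``model.parameters()``.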
Custom/user-defined autograd functions
--------------------------------------

The old Amp API for `registering user functions`_ is still considered correct.  Functions must
be registered before calling ``amp.initialize``.

.. _`registering user functions`:
    https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions
Forcing particular layers/functions to a desired type
-----------------------------------------------------

I'm still working on a generalizable exposure for this that won't require user-side code divergence
across different ``opt_level``\ s.
Multiple models/optimizers/losses
---------------------------------

Initialization with multiple models/optimizers
**********************************************

``amp.initialize``'s optimizer argument may be a single optimizer or a list of optimizers,
as long as the output you accept has the same type.
Similarly, the ``model`` argument may be a single model or a list of models, as long as the accepted
output matches.  The following calls are all legal::

    model, optim = amp.initialize(model, optim, ...)
    model, [optim0, optim1] = amp.initialize(model, [optim0, optim1], ...)
    [model0, model1], optim = amp.initialize([model0, model1], optim, ...)
    [model0, model1], [optim0, optim1] = amp.initialize([model0, model1], [optim0, optim1], ...)
Backward passes with multiple optimizers
****************************************

Whenever you invoke a backward pass, the ``amp.scale_loss`` context manager must receive
**all the optimizers that own any params for which the current backward pass is creating gradients.**
This is true even if each optimizer owns only some, but not all, of the params that are about to
receive gradients.

If, for a given backward pass, there's only one optimizer whose params are about to receive gradients,
you may pass that optimizer directly to ``amp.scale_loss``.  Otherwise, you must pass the
list of optimizers whose params are about to receive gradients.  Example with 3 losses and 2 optimizers::

    # loss0 accumulates gradients only into params owned by optim0:
    with amp.scale_loss(loss0, optim0) as scaled_loss:
        scaled_loss.backward()

    # loss1 accumulates gradients only into params owned by optim1:
    with amp.scale_loss(loss1, optim1) as scaled_loss:
        scaled_loss.backward()

    # loss2 accumulates gradients into some params owned by optim0
    # and some params owned by optim1:
    with amp.scale_loss(loss2, [optim0, optim1]) as scaled_loss:
        scaled_loss.backward()
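The selection rule above can be expressed as a tiny sketch (all names here are hypothetical, not part of Apex): given which params a loss's backward pass will touch and which params each optimizer owns, the optimizers to pass to ``amp.scale_loss`` are those whose ownership overlaps the touched set at all.

```python
def optimizers_for_loss(loss_params, optimizer_owned):
    """Return the optimizers owning any param this backward pass touches.

    loss_params: set of param names the loss's backward pass writes grads into.
    optimizer_owned: dict mapping optimizer name -> set of owned param names.
    """
    return [opt for opt, owned in optimizer_owned.items()
            if owned & loss_params]  # any overlap means the optimizer must be passed

owned = {"optim0": {"w0", "w1"}, "optim1": {"w2"}}
print(optimizers_for_loss({"w0"}, owned))        # ['optim0']
print(optimizers_for_loss({"w1", "w2"}, owned))  # ['optim0', 'optim1']
```

Note that partial ownership is enough: ``optim1`` owning only ``w2`` still places it in the list for the second loss.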
Optionally have Amp use a different loss scaler per-loss
********************************************************

By default, Amp maintains a single global loss scaler that will be used for all backward passes
(all invocations of ``with amp.scale_loss(...)``).  No additional arguments to ``amp.initialize``
or ``amp.scale_loss`` are required to use the global loss scaler.  The code snippets above with
multiple optimizers/backward passes use the single global loss scaler under the hood,
and they should "just work."

However, you can optionally tell Amp to maintain a loss scaler per-loss, which gives Amp increased
numerical flexibility.  This is accomplished by supplying the ``num_losses`` argument to
``amp.initialize`` (which tells Amp how many backward passes you plan to invoke, and therefore
how many loss scalers Amp should create), then supplying the ``loss_id`` argument to each of your
backward passes (which tells Amp which loss scaler to use for this particular backward pass)::

    model, [optim0, optim1] = amp.initialize(model, [optim0, optim1], ..., num_losses=3)

    with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss:
        scaled_loss.backward()

    with amp.scale_loss(loss1, optim1, loss_id=1) as scaled_loss:
        scaled_loss.backward()

    with amp.scale_loss(loss2, [optim0, optim1], loss_id=2) as scaled_loss:
        scaled_loss.backward()

``num_losses`` and ``loss_id``\ s should be specified purely based on the set of
losses/backward passes.  The use of multiple optimizers, or association of single or
multiple optimizers with each backward pass, is unrelated.
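The bookkeeping ``num_losses``/``loss_id`` implies can be sketched as follows (this is an illustrative model, not Apex's real scaler; the initial scale and the halve-on-overflow policy are assumptions for the example).  The key point is that each ``loss_id`` adjusts its scale independently of the others:

```python
class PerLossScalers:
    """Illustrative per-loss dynamic scaling bookkeeping (not Apex's code)."""
    def __init__(self, num_losses, init_scale=2.0**16):
        # One independent scale per loss_id, as implied by num_losses.
        self.scales = [init_scale] * num_losses

    def scale_loss(self, loss_id, loss):
        # Multiply this loss by its own current scale before backward().
        return loss * self.scales[loss_id]

    def update(self, loss_id, found_overflow):
        # Back off this loss_id's scale on overflow; others are untouched.
        # (The halving schedule here is illustrative.)
        if found_overflow:
            self.scales[loss_id] /= 2

scalers = PerLossScalers(num_losses=3)
scalers.update(1, found_overflow=True)   # only loss_id=1 backs off
print(scalers.scales)  # [65536.0, 32768.0, 65536.0]
```

With a single global scaler, an overflow in any one backward pass would shrink the scale used by all of them; per-loss scalers isolate that adjustment, which is the "increased numerical flexibility" mentioned above.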
Gradient accumulation across iterations
---------------------------------------

The following should "just work," and properly accommodate multiple models/optimizers/losses, as well as
gradient clipping via the `instructions above`_::

    # If your intent is to simulate a larger batch size using gradient accumulation,
    # you can divide the loss by the number of accumulation iterations (so that gradients
    # will be averaged over that many iterations):
    loss = loss / iters_to_accumulate

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Every iters_to_accumulate iterations, call step() and reset gradients:
    if iter % iters_to_accumulate == 0:
        # Gradient clipping if desired:
        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
        optimizer.step()
        optimizer.zero_grad()

As a minor performance optimization, you can pass ``delay_unscale=True``
to ``amp.scale_loss`` until you're ready to ``step()``.  You should only attempt ``delay_unscale=True``
if you're sure you know what you're doing, because the interaction with gradient clipping and
multiple models/optimizers/losses can become tricky::

    if iter % iters_to_accumulate == 0:
        # Every iters_to_accumulate iterations, unscale and step
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    else:
        # Otherwise, accumulate gradients, don't unscale or step.
        with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
            scaled_loss.backward()

.. _`instructions above`:
    https://nvidia.github.io/apex/advanced.html#gradient-clipping
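A tiny numeric check of why dividing the loss by ``iters_to_accumulate`` averages the accumulated gradient (plain floats stand in for gradients here; torch and Amp are not involved, and the gradient values are made up for the example):

```python
iters_to_accumulate = 4
per_iter_grads = [1.0, 2.0, 3.0, 6.0]  # hypothetical per-iteration gradients

# Because grad(loss / N) == grad(loss) / N, accumulating over N iterations...
accumulated = sum(g / iters_to_accumulate for g in per_iter_grads)

# ...produces the mean of the per-iteration gradients,
# as if one batch N times larger had been used.
print(accumulated)  # 3.0
mean = sum(per_iter_grads) / len(per_iter_grads)
print(mean)         # 3.0
```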
Custom data batch types
-----------------------

The intention of Amp is that you never need to cast your input data manually, regardless of
``opt_level``.  Amp accomplishes this by patching any models' ``forward`` methods to cast
incoming data appropriately for the ``opt_level``.  But to cast incoming data,
Amp needs to know how.  The patched ``forward`` will recognize and cast floating-point Tensors
(non-floating-point Tensors like IntTensors are not touched) and
Python containers of floating-point Tensors.  However, if you wrap your Tensors in a custom class,
the casting logic doesn't know how to drill
through the tough custom shell to access and cast the juicy Tensor meat within.  You need to tell
Amp how to cast your custom batch class, by assigning it a ``to`` method that accepts a ``torch.dtype``
(e.g., ``torch.float16`` or ``torch.float32``) and returns an instance of the custom batch cast to
``dtype``.  The patched ``forward`` checks for the presence of your ``to`` method, and will
invoke it with the correct type for the ``opt_level``.

Example::

    class CustomData(object):
        def __init__(self):
            self.tensor = torch.cuda.FloatTensor([1, 2, 3])

        def to(self, dtype):
            self.tensor = self.tensor.to(dtype)
            return self
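The duck-typed dispatch described above can be sketched without torch at all (``maybe_cast`` is a hypothetical helper, not Apex's real code, and a string stands in for ``torch.dtype``): the only contract is "if the argument has a ``to`` method, call it with the target dtype; otherwise pass the argument through untouched."

```python
class CustomBatch:
    """Custom wrapper exposing the to(dtype) contract Amp looks for."""
    def __init__(self, value, dtype="float32"):
        self.value, self.dtype = value, dtype

    def to(self, dtype):
        # Record the requested dtype and return the (cast) batch itself.
        self.dtype = dtype
        return self

def maybe_cast(arg, dtype):
    # Mirror of the check the patched forward performs: cast only
    # arguments that expose a to() method.
    return arg.to(dtype) if hasattr(arg, "to") else arg

batch = maybe_cast(CustomBatch([1, 2, 3]), "float16")
print(batch.dtype)                # float16
print(maybe_cast(42, "float16"))  # 42  (non-castable args pass through)
```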
.. warning::
    Amp also forwards numpy ndarrays without casting them.  If you send input data as a raw, unwrapped
    ndarray, then later use it to create a Tensor within your ``model.forward``, this Tensor's type will
    not depend on the ``opt_level``, and may or may not be correct.  Users are encouraged to pass
    castable data inputs (Tensors, collections of Tensors, or custom classes with a ``to`` method)
    wherever possible.

.. note::
    Amp does not call ``.cuda()`` on any Tensors for you.  Amp assumes that your original script
    is already set up to move Tensors from the host to the device as needed.