// amp_C_frontend.cpp

#include <torch/extension.h>
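
// Fused overflow check + scale: each tensor in the first list is multiplied by
// `scale` into the corresponding tensor of the second list; `noop_flag` is set
// if a non-finite value is encountered.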
void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);
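
// Fused SGD step (weight decay, momentum, dampening, optional Nesterov) for
// lists of contiguous tensors.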
void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale);
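
// Elementwise out = a*x + b*y over lists of contiguous tensors; `arg_to_check`
// selects which input list is checked for Inf/NaN.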
void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b,
  int arg_to_check);
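
// L2 norm of a list of contiguous tensors; returns the overall norm and, when
// `per_tensor_python` is true, the per-tensor norms as well.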
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);
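
// Mixed-precision ("_mp") counterpart of multi_tensor_l2norm_cuda.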
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);
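
// Computes the L2 norm of a list of contiguous tensors and scales the tensors
// by `scale` in the same pass.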
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale,
  at::optional<bool> per_tensor_python);
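
// L2 norm of a list of contiguous tensors computed after unscaling by
// `inv_scale`; the unscaling is used only for the norm computation, the
// tensors themselves are not modified.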
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor inv_scale,
  at::optional<bool> per_tensor_python);
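
// LAMB optimizer, stage 1: computes the update term from the gradients and the
// Adam-style moments; `max_global_grad_norm` caps the global gradient norm.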
void multi_tensor_lamb_stage1_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_decay,
  const int step,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor global_grad_norm,
  const float max_global_grad_norm);
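
// LAMB optimizer, stage 2: applies the update to the parameters, using the
// per-tensor parameter and update norms to form the trust ratio.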
void multi_tensor_lamb_stage2_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float lr,
  const float weight_decay,
  at::optional<bool> use_nvlamb_python);
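
// Fused Adam step for lists of contiguous tensors; `mode` selects how weight
// decay is applied (L2-style vs. decoupled, AdamW-style).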
void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay);
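
// CUDA-graph-capturable Adam: `lr`, `step` and `inv_scale` are device tensors
// so that LR scheduling and gradient unscaling can be captured and replayed.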
void multi_tensor_adam_capturable_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int mode,
  const int bias_correction,
  const float weight_decay,
  at::Tensor inv_scale);
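
// Same as multi_tensor_adam_capturable_cuda, additionally maintaining FP32
// master weights.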
void multi_tensor_adam_capturable_master_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int mode,
  const int bias_correction,
  const float weight_decay,
  at::Tensor inv_scale);
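
// Fused Adagrad step for lists of contiguous tensors.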
void multi_tensor_adagrad_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float epsilon,
  const int mode,
  const float weight_decay);
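
// Fused NovoGrad step for lists of contiguous tensors; `grad_norms` holds the
// per-layer gradient norms used by the update.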
void multi_tensor_novograd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor grad_norms,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const int norm_type);
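
// Fused LAMB step for lists of contiguous tensors (single-call counterpart of
// the staged variant above).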
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python);
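
// Mixed-precision variant of multi_tensor_lamb_cuda: hyperparameters and
// scaler state (`found_inf`, `inv_scale`) are passed as tensors.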
void multi_tensor_lamb_mp_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  at::Tensor max_grad_norm,
  at::optional<bool> use_nvlamb_python,
  at::Tensor found_inf,
  at::Tensor inv_scale);
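
// Updates the gradient scaler's scale: grows by `growth_factor` after
// `growth_interval` overflow-free steps, and backs off by `backoff_factor`
// only after repeated overflows, as tracked by `hysteresis_tracker`.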
at::Tensor update_scale_hysteresis_cuda(
  at::Tensor current_scale,
  at::Tensor growth_tracker,
  at::Tensor hysteresis_tracker,
  at::Tensor found_inf,
  const double growth_factor,
  const double backoff_factor,
  const int64_t growth_interval,
  const int hysteresis);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
        "Computes L2 norm for a list of contiguous tensors and does scaling");
  m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only performed for L2 norm computation, and tensors are not updated)");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling");
  m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and FP32 master weights");
  197. m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
  198. "Compute and apply gradient update to parameters for Adam optimizer");
  199. m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
  200. "Compute and apply gradient update to parameters for Adam optimizer");
  201. m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
  202. "Computes and apply update for LAMB optimizer");
  203. m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
  204. "Computes and apply update for LAMB optimizer");
  205. m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda,
  206. "Updates scale while accounting for hysteresis");
  207. }
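
// Usage sketch (illustrative, not part of the original extension): these
// bindings are normally driven from Python via apex's multi_tensor_applier,
// but the declarations above can also be called directly from C++. The chunk
// size (2048 * 32) mirrors the value commonly passed on the Python side; the
// gradient lists and `loss_scale` below are placeholders.
//
//   at::Tensor noop_flag = at::zeros(
//       {1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
//   std::vector<at::Tensor> fp16_grads = /* ... */;
//   std::vector<at::Tensor> fp32_grads = /* ... */;
//   multi_tensor_scale_cuda(2048 * 32, noop_flag,
//                           {fp16_grads, fp32_grads}, 1.0f / loss_scale);
//   // noop_flag is nonzero afterwards if any gradient contained Inf/NaN.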