// fused_weight_gradient_dense_cuda.cu

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"
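
// This file implements the weight-gradient GEMM for a dense (linear) layer:
// the activations (`input`) and output gradients (`d_output`) may be FP32,
// FP16, or BF16, and the product is accumulated directly into an FP32
// `d_weight` buffer via cublasGemmEx with an FP32 compute type.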

// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha,
    float *A,
    int lda,
    float *B,
    int ldb,
    const float *beta,
    float *C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
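
// Weight-gradient GEMM with FP32 accumulation. cuBLAS is column-major while
// the tensors here are row-major, so the row-major input [hidden_dim, in_dim]
// and d_output [hidden_dim, out_dim] are read by cuBLAS as their transposes.
// With op(A) = N on input and op(B) = T on d_output, the call below computes,
// in row-major terms,
//     d_weight[out_dim, in_dim] += d_output^T * input,
// and beta = 1.0 makes cublasGemmEx accumulate into the existing FP32
// d_weight buffer rather than overwrite it.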
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta = 1.0;

  gemmex_wrapper(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_dim,
      out_dim,
      hidden_dim,
      &alpha,
      input,
      in_dim,
      d_output,
      out_dim,
      &beta,
      d_weight,
      in_dim);
}
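
// Explicit instantiations for the element types dispatched from the stub below.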
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
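
// Host-side entry point: flattens any leading (e.g. batch/sequence) dimensions
// of `input` and `d_output` down to 2D, then dispatches on the input dtype to
// the FP32-accumulating GEMM above. `d_weight` must already be an FP32 tensor
// of shape [out_dim, in_dim].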
void wgrad_gemm_accum_fp32_cuda_stub(
    at::Tensor &input,
    at::Tensor &d_output,
    at::Tensor &d_weight
) {
  at::Tensor input_2d, d_output_2d;
  // input tensor: collapse all leading dims into the first dim so it is 2D
  auto in_sizes = input.sizes();
  if (input.dim() > 2) {
    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
  } else {
    input_2d = input;
  }
  // d_output tensor: collapse all leading dims into the first dim so it is 2D
  auto d_out_sizes = d_output.sizes();
  if (d_output.dim() > 2) {
    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
  } else {
    d_output_2d = d_output;
  }

  const int hidden_dim = input_2d.size(0);
  const int in_dim = input_2d.size(1);
  const int out_dim = d_weight.size(0);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
      wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
          input_2d.data_ptr<scalar_t_0>(),
          d_output_2d.data_ptr<scalar_t_0>(),
          d_weight.data_ptr<float>(),
          in_dim,
          hidden_dim,
          out_dim);
  );
}
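
// A minimal binding sketch, not part of the original file: in a typical
// PyTorch C++ extension build the stub above would be exposed to Python from
// a companion .cpp along these lines (the Python-facing function name below
// is an illustrative assumption):
//
//   void wgrad_gemm_accum_fp32_cuda_stub(
//       at::Tensor &input, at::Tensor &d_output, at::Tensor &d_weight);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub,
//           "weight gradient GEMM with FP32 accumulation");
//   }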