// fused_weight_gradient_dense_16bit_prec_cuda.cu

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"

// BF16 inputs and BF16 gradient accumulation (the GEMM itself computes in FP32).
void gemmex_wrapper_fp16(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    at::BFloat16* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_16BF,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 inputs and FP16 gradient accumulation (the GEMM itself computes in FP32).
void gemmex_wrapper_fp16(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    at::Half* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// Weight-gradient GEMM with accumulation. In the row-major (PyTorch) view this computes
//   d_weight[out_dim, in_dim] += d_output[hidden_dim, out_dim]^T * input[hidden_dim, in_dim].
// cuBLAS is column-major, so the same buffers are passed below as
//   C(in_dim x out_dim) += A(in_dim x hidden_dim) * B^T(hidden_dim x out_dim),
// with A = input, B = d_output, C = d_weight.
template <typename T>
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);  // current stream of the handle (not used further here)
  const float alpha = 1.0;
  const float beta = 1.0;  // beta = 1 accumulates into the existing d_weight buffer

  gemmex_wrapper_fp16(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_dim,
      out_dim,
      hidden_dim,
      &alpha,
      input,
      in_dim,
      d_output,
      out_dim,
      &beta,
      d_weight,
      in_dim);
}
template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
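
// A minimal host-side sanity check of the layout mapping above; an illustrative sketch,
// not part of the extension's interface. It assumes libtorch with a CUDA device, and the
// helper name check_wgrad_fp16_layout is hypothetical. It accumulates into a zeroed FP16
// buffer and compares against the row-major reference d_weight = d_output^T * input.
void check_wgrad_fp16_layout() {
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto input    = at::randn({128, 64}, opts);  // [hidden_dim, in_dim]
  auto d_output = at::randn({128, 32}, opts);  // [hidden_dim, out_dim]
  auto d_weight = at::zeros({32, 64}, opts);   // [out_dim, in_dim]
  wgrad_gemm_accum_fp16_cuda<at::Half>(
      input.data_ptr<at::Half>(),
      d_output.data_ptr<at::Half>(),
      d_weight.data_ptr<at::Half>(),
      /*in_dim=*/64, /*hidden_dim=*/128, /*out_dim=*/32);
  // Reference in FP32; the FP16-stored result only needs to match loosely.
  auto ref = at::matmul(d_output.t().to(at::kFloat), input.to(at::kFloat));
  TORCH_CHECK(at::allclose(d_weight.to(at::kFloat), ref, /*rtol=*/1e-2, /*atol=*/1e-2),
              "unexpected wgrad layout");
}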
void wgrad_gemm_accum_fp16_cuda_stub(
    at::Tensor &input,
    at::Tensor &d_output,
    at::Tensor &d_weight
) {
  at::Tensor input_2d, d_output_2d;
  // input tensor: collapse all leading dims into the first dim
  auto in_sizes = input.sizes();
  if (input.dim() > 2) {
    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
  } else {
    input_2d = input;
  }
  // d_output tensor: collapse all leading dims into the first dim
  auto d_out_sizes = d_output.sizes();
  if (d_output.dim() > 2) {
    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
  } else {
    d_output_2d = d_output;
  }

  const int hidden_dim = input_2d.size(0);
  const int in_dim = input_2d.size(1);
  const int out_dim = d_weight.size(0);

  DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
      wgrad_gemm_accum_fp16_cuda<scalar_t>(
          input_2d.data_ptr<scalar_t>(),
          d_output_2d.data_ptr<scalar_t>(),
          d_weight.data_ptr<scalar_t>(),
          in_dim,
          hidden_dim,
          out_dim);
  );
}
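
// A minimal sketch of how the stub could be exposed to Python as a torch extension,
// assuming this file is built with torch.utils.cpp_extension; in projects of this kind
// the binding usually lives in a separate host .cpp, so it is left commented out here.
// The exposed name "wgrad_gemm_accum_fp16" is an assumption, not a confirmed API.
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub,
//         "weight-gradient GEMM with 16-bit gradient accumulation (CUDA)");
// }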