// NOTE(review): the header names inside <...> and several template argument
// lists were missing from the text under review; they have been restored from
// the symbols this file uses. Confirm the include list against the repository.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

#include <climits>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"

// BF16 Tensor core wrapper around cublas GEMMEx.
// A and B are BF16, C is FP32 and the computation type is FP32.
// Any non-success cuBLAS status is reported through TORCH_CUDABLAS_CHECK.
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// FP16 Tensor core wrapper around cublas GEMMEx.
// A and B are FP16, C is FP32 and the computation type is FP32.
// Any non-success cuBLAS status is reported through TORCH_CUDABLAS_CHECK.
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// FP32 wrapper around cublas GEMMEx.
// A, B and C are FP32 and the computation type is FP32.
// Any non-success cuBLAS status is reported through TORCH_CUDABLAS_CHECK.
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha,
    float *A,
    int lda,
    float *B,
    int ldb,
    const float *beta,
    float *C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

// Accumulates a weight gradient into an FP32 buffer with a single GEMM:
//
//     d_weight += d_output^T * input        (alpha = beta = 1)
//
// cuBLAS is column-major, so the row-major operands are passed as follows:
//   input    : passed as A, op N, m = in_dim,  k = hidden_dim, lda = in_dim
//   d_output : passed as B, op T, n = out_dim, k = hidden_dim, ldb = out_dim
//   d_weight : passed as C, in_dim x out_dim,  ldc = in_dim
// i.e. row-major [hidden_dim, in_dim], [hidden_dim, out_dim] and
// [out_dim, in_dim] buffers respectively, all densely packed.
//
// Note that `hidden_dim` is the reduction (k) dimension of the GEMM.
// The GEMM is issued on the handle returned by
// at::cuda::getCurrentCUDABlasHandle().
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

    // beta = 1 makes the GEMM accumulate into d_weight instead of overwriting it.
    const float alpha = 1.0f;
    const float beta = 1.0f;

    gemmex_wrapper(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_dim,
        out_dim,
        hidden_dim,
        &alpha,
        input,
        in_dim,
        d_output,
        out_dim,
        &beta,
        d_weight,
        in_dim);
}

// Explicit instantiations for the three supported input element types.
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

// Tensor-level entry point: flattens `input` and `d_output` to 2-D, validates
// the operands and dispatches on the input dtype to
// wgrad_gemm_accum_fp32_cuda, which accumulates into `d_weight` in place.
//
//   input    : [..., in_dim]  float / half / bfloat16, CUDA, contiguous
//   d_output : [..., out_dim] same dtype and same flattened row count as input
//   d_weight : [out_dim, in_dim] float32, CUDA, contiguous (updated in place)
//
// Raises (via TORCH_CHECK) when devices, dtypes, ranks, shapes or memory
// layouts do not satisfy the requirements above; the GEMM indexes raw
// pointers using these sizes, so a mismatch would otherwise read or write out
// of bounds. If any GEMM dimension is zero the call is a no-op.
void wgrad_gemm_accum_fp32_cuda_stub(
  at::Tensor &input,
  at::Tensor &d_output,
  at::Tensor &d_weight
) {
    TORCH_CHECK(input.is_cuda() && d_output.is_cuda() && d_weight.is_cuda(),
                "wgrad_gemm_accum_fp32: input, d_output and d_weight must be CUDA tensors");
    TORCH_CHECK(input.dim() >= 2 && d_output.dim() >= 2,
                "wgrad_gemm_accum_fp32: input and d_output must have at least 2 dimensions");
    TORCH_CHECK(d_weight.dim() == 2,
                "wgrad_gemm_accum_fp32: d_weight must be 2-dimensional");
    TORCH_CHECK(d_output.scalar_type() == input.scalar_type(),
                "wgrad_gemm_accum_fp32: input and d_output must have the same dtype");
    TORCH_CHECK(d_weight.scalar_type() == at::ScalarType::Float,
                "wgrad_gemm_accum_fp32: d_weight must be float32");

    at::Tensor input_2d, d_output_2d;

    // input tensor: collapse all leading dims into the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }

    // d_output tensor: collapse all leading dims into the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    // The GEMM passes the row lengths as leading dimensions, which is only
    // correct for densely packed buffers.
    TORCH_CHECK(input_2d.is_contiguous() && d_output_2d.is_contiguous() && d_weight.is_contiguous(),
                "wgrad_gemm_accum_fp32: input, d_output and d_weight must be contiguous");

    const int64_t hidden_dim64 = input_2d.size(0);
    const int64_t in_dim64 = input_2d.size(1);
    const int64_t out_dim64 = d_weight.size(0);

    TORCH_CHECK(d_output_2d.size(0) == hidden_dim64,
                "wgrad_gemm_accum_fp32: input and d_output must have the same number of rows after flattening");
    TORCH_CHECK(d_output_2d.size(1) == out_dim64,
                "wgrad_gemm_accum_fp32: last dim of d_output must equal d_weight.size(0)");
    TORCH_CHECK(d_weight.size(1) == in_dim64,
                "wgrad_gemm_accum_fp32: last dim of input must equal d_weight.size(1)");
    // cuBLAS takes 32-bit dimensions; reject sizes that would be truncated.
    TORCH_CHECK(hidden_dim64 <= INT_MAX && in_dim64 <= INT_MAX && out_dim64 <= INT_MAX,
                "wgrad_gemm_accum_fp32: dimensions exceed the int range supported by cuBLAS");

    // Accumulating an empty product leaves d_weight unchanged; skip the call
    // so that a zero leading dimension is never handed to cuBLAS.
    if (hidden_dim64 == 0 || in_dim64 == 0 || out_dim64 == 0) {
        return;
    }

    const int hidden_dim = static_cast<int>(hidden_dim64);
    const int in_dim = static_cast<int>(in_dim64);
    const int out_dim = static_cast<int>(out_dim64);

    // NOTE(review): `scalar_t_0` is presumably the element-type alias that the
    // dispatch macro from type_shim.h defines for level 0 — confirm.
    DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
        wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
            input_2d.data_ptr<scalar_t_0>(),
            d_output_2d.data_ptr<scalar_t_0>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}