#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

/* CUDA / cuBLAS includes */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"
// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        alpha,
        A,
        CUDA_R_16BF,
        lda,
        B,
        CUDA_R_16BF,
        ldb,
        beta,
        C,
        CUDA_R_32F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        alpha,
        A,
        CUDA_R_16F,
        lda,
        B,
        CUDA_R_16F,
        ldb,
        beta,
        C,
        CUDA_R_32F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha,
    float *A,
    int lda,
    float *B,
    int ldb,
    const float *beta,
    float *C,
    int ldc) {
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        alpha,
        A,
        CUDA_R_32F,
        lda,
        B,
        CUDA_R_32F,
        ldb,
        beta,
        C,
        CUDA_R_32F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
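
// The three overloads above differ only in the element type of A and B;
// C and the compute type stay CUDA_R_32F in every case, so the reduction
// over k always happens in FP32. CUBLAS_GEMM_DEFAULT_TENSOR_OP lets
// cuBLAS pick a Tensor Core algorithm when one is available. The
// templated driver below selects the right overload through ordinary
// overload resolution on T.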
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // cublasGemmEx launches on the stream already bound to the handle;
    // the stream is queried here but not otherwise used.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    // alpha = beta = 1: accumulate the new gradient on top of d_weight.
    const float alpha = 1.0;
    const float beta = 1.0;
    gemmex_wrapper(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_dim,
        out_dim,
        hidden_dim,
        &alpha,
        input,
        in_dim,
        d_output,
        out_dim,
        &beta,
        d_weight,
        in_dim);
}
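
// Layout note: cuBLAS is column-major while PyTorch tensors are
// row-major, so a row-major [rows, cols] tensor reads as a column-major
// cols x rows matrix. With m = in_dim, n = out_dim, k = hidden_dim, the
// call above computes, in column-major terms,
//
//   C(in_dim x out_dim) += A(in_dim x hidden_dim) * B(out_dim x hidden_dim)^T
//
// which is exactly
//
//   d_weight[out_dim, in_dim] += d_output[hidden_dim, out_dim]^T * input[hidden_dim, in_dim]
//
// in row-major terms: the standard weight-gradient GEMM, accumulated
// into the FP32 d_weight buffer.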
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
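
// The explicit instantiations above cover exactly the three dtypes that
// DISPATCH_FLOAT_HALF_AND_BFLOAT can select below, and keep the template
// linkable from other translation units (e.g. a separate binding file).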
void wgrad_gemm_accum_fp32_cuda_stub(
    at::Tensor &input,
    at::Tensor &d_output,
    at::Tensor &d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse all leading dimensions into one, e.g.
    // [seq, batch, in_dim] -> [seq * batch, in_dim]
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to 2D the same way
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    // hidden_dim is the flattened leading dimension shared by both inputs.
    const int hidden_dim = input_2d.size(0);
    const int in_dim = input_2d.size(1);
    const int out_dim = d_weight.size(0);

    DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
        wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
            input_2d.data_ptr<scalar_t_0>(),
            d_output_2d.data_ptr<scalar_t_0>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}
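
// For reference only: a minimal unfused ATen equivalent of the stub's
// GEMM, handy for checking numerics against the fused path. This helper
// is an illustrative sketch with a hypothetical name, not part of the
// extension's API.
void wgrad_accum_fp32_reference(
    const at::Tensor &input_2d,     // [hidden_dim, in_dim], half/bf16/float
    const at::Tensor &d_output_2d,  // [hidden_dim, out_dim], same dtype
    at::Tensor &d_weight) {         // [out_dim, in_dim], float32
    // d_weight += d_output_2d^T @ input_2d, accumulated in FP32.
    d_weight.add_(at::matmul(d_output_2d.t().to(at::kFloat),
                             input_2d.to(at::kFloat)));
}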