- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <assert.h>
- #include <stdio.h>
- #include <stdlib.h>
- #include <string.h>
- #include <torch/torch.h>
- /* Includes, cuda */
- #include <cublas_v2.h>
- #include <cuda_runtime.h>
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
- // includes cublaslt
- #include <cublasLt.h>
- #endif
- // constants for fused bias+relu kernel
- #define BIAS_RELU_FW_NTHREADS 128 // forward: number of threads per block
- #define BIAS_RELU_BW_NTHREADS_X 32 // backward: number of threads along the feature dim
- #define BIAS_RELU_BW_NTHREADS_Y 16 // backward: number of threads along the batch dim
- #define BIAS_RELU_RED_PER_THREAD 16 // backward: minimum reduction length per thread
- // move to a header later on
- #define ILP 4
- template<typename T>
- __host__ __device__ __forceinline__ bool is_aligned(T* p){
- return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
- }
- template<typename T>
- __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
- typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
- ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
- }
- template<typename T>
- __device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){
- typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
- ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
- }
- template<typename T>
- __device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){
- typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
- ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
- }
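- // load_store copies ILP consecutive elements of T in one aligned, vector-width transaction
- // by reinterpreting both pointers as an ILP*sizeof(T)-byte aligned_storage type.
- // Illustrative example (assumes T = float and ILP == 4): one load_store moves 16 bytes,
- // i.e. the same data as
- //   for (int i = 0; i < ILP; i++) dst[dst_offset*ILP + i] = src[src_offset*ILP + i];
- // but lets the compiler emit a single 128-bit load/store instead of four scalar ones.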
- // Keep ReLU in float only. When using half, cast to float before calling.
- __device__ __inline__ float relu(float a) {
- float retf = max(a, 0.f);
- return (retf);
- }
- // Keep Sigmoid in float only. When using half, cast to float before calling.
- __device__ __inline__ float sigmoid(float a) {
- float retf = 1.f / (1.f + expf(-a));
- return (retf);
- }
- // FP64 Wrapper around cublas GEMMEx
- cublasStatus_t mlp_gemm(
- cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float* alpha,
- const double* A,
- int lda,
- const double* B,
- int ldb,
- const float* beta,
- double* C,
- int ldc) {
- // computeType CUDA_R_64F requires double scaling factors; widen the host-side floats.
- double alpha64 = static_cast<double>(*alpha);
- double beta64 = static_cast<double>(*beta);
- return cublasGemmEx(
- handle,
- transa,
- transb,
- m,
- n,
- k,
- &alpha64,
- A,
- CUDA_R_64F,
- lda,
- B,
- CUDA_R_64F,
- ldb,
- &beta64,
- C,
- CUDA_R_64F,
- ldc,
- CUDA_R_64F,
- CUBLAS_GEMM_DEFAULT);
- }
- // FP32 Wrapper around cublas GEMMEx
- cublasStatus_t mlp_gemm(
- cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float* alpha,
- const float* A,
- int lda,
- const float* B,
- int ldb,
- const float* beta,
- float* C,
- int ldc) {
- return cublasGemmEx(
- handle,
- transa,
- transb,
- m,
- n,
- k,
- alpha,
- A,
- CUDA_R_32F,
- lda,
- B,
- CUDA_R_32F,
- ldb,
- beta,
- C,
- CUDA_R_32F,
- ldc,
- CUDA_R_32F,
- CUBLAS_GEMM_DEFAULT);
- }
- // FP16 Tensor core wrapper around cublas GEMMEx
- cublasStatus_t mlp_gemm(
- cublasHandle_t handle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float* alpha,
- const at::Half* A,
- int lda,
- const at::Half* B,
- int ldb,
- float* beta,
- at::Half* C,
- int ldc) {
- return cublasGemmEx(
- handle,
- transa,
- transb,
- m,
- n,
- k,
- alpha,
- A,
- CUDA_R_16F,
- lda,
- B,
- CUDA_R_16F,
- ldb,
- beta,
- C,
- CUDA_R_16F,
- ldc,
- CUDA_R_32F,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP);
- }
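- // The three mlp_gemm overloads above let the templated fprop/bprop code call a single
- // name and get the right cublasGemmEx data/compute types: FP64 and FP32 accumulate in
- // their own precision, while FP16 inputs accumulate in FP32 and request the tensor-op
- // algorithm.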
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
- int mlp_gemm_lt(
- cublasLtHandle_t ltHandle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float *alpha, /* host pointer */
- const at::Half* A,
- int lda,
- const at::Half* B,
- int ldb,
- float *beta, /* host pointer */
- at::Half* C,
- int ldc,
- void *workspace,
- size_t workspaceSize,
- cudaStream_t stream,
- bool use_bias,
- bool use_relu,
- const void* bias) {
- cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
- cublasLtMatmulDescOpaque_t operationDesc = {};
- cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
- cublasLtMatmulPreferenceOpaque_t preference = {};
- int returnedResults = 0;
- cublasLtMatmulHeuristicResult_t heuristicResult = {};
- cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
- // Create operation descriptor; see cublasLtMatmulDescAttributes_t
- // for details about defaults; here we just set the transforms for
- // A and B.
- status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- if (use_bias) {
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
- if (status != CUBLAS_STATUS_SUCCESS) {
- goto CLEANUP;
- }
- if (use_relu) {
- epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
- } else {
- epilogue = CUBLASLT_EPILOGUE_BIAS;
- }
- } else {
- if (use_relu) {
- epilogue = CUBLASLT_EPILOGUE_RELU;
- }
- }
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
- if (status != CUBLAS_STATUS_SUCCESS) {
- goto CLEANUP;
- }
- // Create matrix descriptors. Not setting any extra attributes.
- status = cublasLtMatrixLayoutInit(
- &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatrixLayoutInit(
- &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- // Create preference handle; In general, extra attributes can be
- // used here to disable tensor ops or to make sure algo selected
- // will work with badly aligned A, B, C. However, for simplicity
- // here we assume A,B,C are always well aligned (e.g., directly
- // come from cudaMalloc)
- status = cublasLtMatmulPreferenceInit(&preference);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulPreferenceSetAttribute(
- &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- // We just need the best available heuristic to try and run matmul.
- // There is no guarantee that this will work. For example, if A is
- // badly aligned, you can request more (e.g. 32) algos and try to
- // run them one by one until something works.
- status = cublasLtMatmulAlgoGetHeuristic(
- ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- if (returnedResults == 0) {
- status = CUBLAS_STATUS_NOT_SUPPORTED;
- goto CLEANUP;
- }
- status = cublasLtMatmul(ltHandle,
- &operationDesc,
- alpha,
- A,
- &Adesc,
- B,
- &Bdesc,
- beta,
- C,
- &Cdesc,
- C,
- &Cdesc,
- &heuristicResult.algo,
- workspace,
- workspaceSize,
- stream);
- CLEANUP:
- // Descriptors are no longer needed as all GPU work was already
- // enqueued.
- return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
- }
- int mlp_gemm_lt(
- cublasLtHandle_t ltHandle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float *alpha, /* host pointer */
- const double* A,
- int lda,
- const double* B,
- int ldb,
- float *beta, /* host pointer */
- double* C,
- int ldc,
- void *workspace,
- size_t workspaceSize,
- cudaStream_t stream,
- bool use_bias,
- bool use_relu,
- const void* bias) {
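- // The cublasLt fused path is only implemented for the FP16 and FP32 overloads; report
- // failure (non-zero) so the caller falls back to the plain cublas GEMM plus separate
- // bias/activation kernels.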
- return 1;
- }
- int mlp_gemm_lt(
- cublasLtHandle_t ltHandle,
- cublasOperation_t transa,
- cublasOperation_t transb,
- int m,
- int n,
- int k,
- float *alpha, /* host pointer */
- const float *A,
- int lda,
- const float *B,
- int ldb,
- float *beta, /* host pointer */
- float *C,
- int ldc,
- void *workspace,
- size_t workspaceSize,
- cudaStream_t stream,
- bool use_bias,
- bool use_relu,
- const void* bias) {
- cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
- cublasLtMatmulDescOpaque_t operationDesc = {};
- cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
- cublasLtMatmulPreferenceOpaque_t preference = {};
- int returnedResults = 0;
- cublasLtMatmulHeuristicResult_t heuristicResult = {};
- cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
- // Create operation descriptor; see cublasLtMatmulDescAttributes_t
- // for details about defaults; here we just set the transforms for
- // A and B.
- status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- if (use_bias) {
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
- if (status != CUBLAS_STATUS_SUCCESS) {
- goto CLEANUP;
- }
- if (use_relu) {
- epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
- } else {
- epilogue = CUBLASLT_EPILOGUE_BIAS;
- }
- } else {
- if (use_relu) {
- epilogue = CUBLASLT_EPILOGUE_RELU;
- }
- }
- status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
- if (status != CUBLAS_STATUS_SUCCESS) {
- goto CLEANUP;
- }
- // Create matrix descriptors. Not setting any extra attributes.
- status = cublasLtMatrixLayoutInit(
- &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatrixLayoutInit(
- &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- // Create preference handle; In general, extra attributes can be
- // used here to disable tensor ops or to make sure algo selected
- // will work with badly aligned A, B, C. However, for simplicity
- // here we assume A,B,C are always well aligned (e.g., directly
- // come from cudaMalloc)
- status = cublasLtMatmulPreferenceInit(&preference);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- status = cublasLtMatmulPreferenceSetAttribute(
- &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- // We just need the best available heuristic to try and run matmul.
- // There is no guarantee that this will work. For example, if A is
- // badly aligned, you can request more (e.g. 32) algos and try to
- // run them one by one until something works.
- status = cublasLtMatmulAlgoGetHeuristic(
- ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
- if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
- if (returnedResults == 0) {
- status = CUBLAS_STATUS_NOT_SUPPORTED;
- goto CLEANUP;
- }
- status = cublasLtMatmul(ltHandle,
- &operationDesc,
- alpha,
- A,
- &Adesc,
- B,
- &Bdesc,
- beta,
- C,
- &Cdesc,
- C,
- &Cdesc,
- &heuristicResult.algo,
- workspace,
- workspaceSize,
- stream);
- CLEANUP:
- // Descriptors are no longer needed as all GPU work was already
- // enqueued.
- return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
- }
- #endif
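- // The pointwise fprop kernels and the ReLU/Sigmoid bprop kernels below all follow the
- // same pattern: a fast path that uses ILP-wide vector loads/stores when the pointers are
- // aligned and 'features' is a multiple of ILP, and a scalar grid-stride fallback otherwise.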
- // Bias ADD. Assume input X is [features x batch size], column major.
- // Bias is a single vector of length 'features', broadcast across the batch.
- template <typename T>
- __global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
- T r_x[ILP];
- T r_b[ILP];
- if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- int row = tid % (features / ILP);
- load_store(r_x, X, 0 , tid);
- load_store(r_b, b, 0 , row);
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
- r_x[ii] = bias_sum;
- }
- load_store(X, r_x, tid , 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- int row = idx % features;
- r_x[ii] = X[idx];
- r_b[ii] = b[row];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
- r_x[ii] = bias_sum;
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- X[idx] = r_x[ii];
- }
- }
- }
- }
- }
- // Bias ADD + ReLU. Assume input X is [features x batch size], column major.
- // The bias add and ReLU activation are fused. Safe to call in-place.
- template <typename T>
- __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
- T r_x[ILP];
- T r_b[ILP];
- if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- int row = tid % (features / ILP);
- load_store(r_x, X, 0 , tid);
- load_store(r_b, b, 0 , row);
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
- r_x[ii] = relu(bias_sum);
- }
- load_store(X, r_x, tid , 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- int row = idx % features;
- r_x[ii] = X[idx];
- r_b[ii] = b[row];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
- r_x[ii] = relu(bias_sum);
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- X[idx] = r_x[ii];
- }
- }
- }
- }
- }
- // ReLU. Assume input X is [features x batch size], column major.
- // Safe to call in-place.
- template <typename T>
- __global__ void Relu_fprop(T *X, uint batch_size, uint features) {
- T r_x[ILP];
- if(is_aligned(X) && features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- load_store(r_x, X, 0 , tid);
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- r_x[ii] = relu(static_cast<float>(r_x[ii]));
- }
- load_store(X, r_x, tid , 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- r_x[ii] = X[idx];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- r_x[ii] = relu(static_cast<float>(r_x[ii]));
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- X[idx] = r_x[ii];
- }
- }
- }
- }
- }
- // Sigmoid. Assume input X is [features x batch size], column major.
- // Safe to call in-place.
- template <typename T>
- __global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
- T r_x[ILP];
- if(is_aligned(X) && features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- load_store(r_x, X, 0 , tid);
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
- }
- load_store(X, r_x, tid , 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- r_x[ii] = X[idx];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- X[idx] = r_x[ii];
- }
- }
- }
- }
- }
- // ReLU backward. Assume inputs dY and Y are [features x batch size], column major.
- // Safe to call in-place (dX may alias dY).
- template <typename T>
- __global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
- T r_dy[ILP];
- T r_y[ILP];
- if(is_aligned(dY) &&
- is_aligned(Y) &&
- is_aligned(dX) &&
- features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- load_store(r_dy, dY, 0 , tid);
- load_store(r_y, Y, 0 , tid);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- if ((float)r_y[ii] <= 0.f)
- r_dy[ii] = 0;
- }
- load_store(dX, r_dy, tid, 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- r_dy[ii] = dY[idx];
- r_y[ii] = Y[idx];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- if ((float)r_y[ii] <= 0.f)
- r_dy[ii] = 0;
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- dX[idx] = r_dy[ii];
- }
- }
- }
- }
- }
- // Sigmoid backward. Assume inputs dY and Y are [features x batch size], column major.
- // Safe to call in-place (dX may alias dY).
- template <typename T>
- __global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
- T r_dy[ILP];
- T r_y[ILP];
- if(is_aligned(dY) &&
- is_aligned(Y) &&
- is_aligned(dX) &&
- features % ILP ==0) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
- load_store(r_dy, dY, 0 , tid);
- load_store(r_y, Y, 0 , tid);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- float grad_out = r_dy[ii];
- float out = r_y[ii];
- float grad_i = out * ( 1.f - out) * grad_out;
- r_dy[ii] = grad_i;
- }
- load_store(dX, r_dy, tid, 0);
- }
- } else {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- r_dy[ii] = dY[idx];
- r_y[ii] = Y[idx];
- }
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- float grad_out = r_dy[ii];
- float out = r_y[ii];
- float grad_i = out * ( 1.f - out) * grad_out;
- r_dy[ii] = grad_i;
- }
- #pragma unroll
- for(int ii = 0; ii < ILP; ii++) {
- int idx = tid + ii * blockDim.x * gridDim.x;
- if(idx < features * batch_size) {
- dX[idx] = r_dy[ii];
- }
- }
- }
- }
- }
- // Compute grid size for the pointwise backward kernels.
- // block_x/y are the total number of elements handled per block, not the number of threads.
- void get_biasAddRelu_bprop_grid_size(
- int yfeat,
- int batch_size,
- int block_x,
- int block_y,
- int* grid_x,
- int* grid_y) {
- *grid_x = (yfeat + block_x - 1) / block_x;
- // Get number of SMs for efficient reduction.
- int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
- // Could switch to an occupancy calculation; for now assume 4 resident blocks per SM (tuned for sm_70).
- int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x);
- // block_y is derived from the minimum reduction work assigned to each thread.
- int nRedSplits = (batch_size + block_y - 1) / block_y;
- // Cap grid_y so we never launch more blocks than needed; the kernel adjusts the
- // per-thread reduction length to cover whatever work remains.
- *grid_y = std::min(nRedSplits, max_blocks_y);
- return;
- }
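- // Illustrative sizing (assumed values, e.g. a GPU with 80 SMs): yfeat = 1024,
- // batch_size = 8192, block_x = 32, block_y = 16 * 16 = 256 gives grid_x = 32,
- // max_blocks_y = (80 * 4 + 31) / 32 = 10, nRedSplits = 32, so grid_y = min(32, 10) = 10.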
- // Addition done deterministically via a 2-pass approach. Each CTA writes out partial
- // sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
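- // Pass 1: every CTA reduces its slice of the batch into its own row of 'intermediate'
- // (gridDim.y rows of 'features' floats). Pass 2: the CTA that increments the per-blockIdx.x
- // semaphore last re-reads all partial rows and writes the final db. The __threadfence()
- // before the semaphore update makes the partial sums visible to that last CTA.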
- template <typename T, int UNROLL_FACTOR>
- __global__ void biasAdd_bprop(
- T* dY,
- int features,
- int batch_size,
- volatile float* intermediate,
- int* semaphores,
- T* db) {
- // The feature that this thread is responsible for
- int f = blockIdx.x * blockDim.x + threadIdx.x;
- // Compute the span this thread is responsible for
- // For this block
- int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
- int b_nStart = blockIdx.y * b_chunkSize;
- int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
- // For this thread
- int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
- int nStart = threadIdx.y * chunkSize + b_nStart;
- int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
- volatile float* out = intermediate + blockIdx.y * features;
- // Flag to trigger last reduction.
- __shared__ bool isLastBlock;
- // Block dimensions are fixed, so the shared-memory size is known at compile time.
- __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
- // Accumulate db in FP32 always
- float db_local = 0;
- if (f < features) {
- int nidx = 0;
- // Handle non-multiple of UNROLL_FACTOR residue
- for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
- int64_t row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features + row;
- db_local += (float)dY[flat_idx];
- }
- // Handle meat of work
- for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
- int64_t row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features + row;
- #pragma unroll 4
- for (int u = 0; u < UNROLL_FACTOR; u++) {
- db_local += (float)dY[flat_idx];
- flat_idx += features;
- }
- }
- // naive block reduction on y-dim
- int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
- smem[linear_idx] = db_local;
- }
- __syncthreads();
- if (f < features) {
- if(threadIdx.y == 0) {
- for(int yidx = 1; yidx < blockDim.y; yidx++){
- db_local += smem[yidx * blockDim.x + threadIdx.x];
- }
- // block result is in db_local now for all threadIdx.y == 0
- // Write out partial result
- out[f] = db_local;
- }
- }
- __threadfence();
- __syncthreads();
- // Increment semaphore and check if this is the last CTA in the grid_y dimension.
- // Only thread (0,0) calls this
- if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
- unsigned int sum_idx;
- sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
- isLastBlock = (sum_idx == (gridDim.y - 1));
- }
- __syncthreads();
- db_local = 0;
- // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction.
- if (isLastBlock && f < features) {
- if(threadIdx.y == 0) {
- for (int n = 0; n < gridDim.y; n++) {
- int row, col;
- row = f;
- col = n;
- db_local += (float)(intermediate[col * features + row]);
- }
- db[f] = (T)db_local;
- }
- }
- }
- // Addition done deterministically via a 2-pass approach. Each CTA writes out partial
- // sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
- template <typename T, int UNROLL_FACTOR>
- __global__ void biasAddRelu_bprop(
- T* Y,
- T* dY,
- int features,
- int batch_size,
- T* dX,
- volatile float* intermediate,
- int* semaphores,
- T* db) {
- // The feature that this thread is responsible for
- int f = blockIdx.x * blockDim.x + threadIdx.x;
- // Compute the span this thread is responsible for
- // For this block
- int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
- int b_nStart = blockIdx.y * b_chunkSize;
- int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
- // For this thread
- int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
- int nStart = threadIdx.y * chunkSize + b_nStart;
- int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
- volatile float* out = intermediate + blockIdx.y * features;
- // Flag to trigger last reduction.
- __shared__ bool isLastBlock;
- // Block dimensions are fixed, so the shared-memory size is known at compile time.
- __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
- // Accumulate db in FP32 always
- float db_local = 0;
- if (f < features) {
- int nidx = 0;
- // Handle non-multiple of UNROLL_FACTOR residue
- for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
- int row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features + row;
- T y_val = Y[flat_idx];
- T dy_val = dY[flat_idx];
- T dx_val;
- if ((float)y_val > 0.f)
- dx_val = dy_val;
- else
- dx_val = 0;
- dX[flat_idx] = dx_val;
- db_local += (float)dx_val;
- }
- // Handle meat of work
- for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
- int row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features + row;
- #pragma unroll 4
- for (int u = 0; u < UNROLL_FACTOR; u++) {
- T y_val = Y[flat_idx];
- T dy_val = dY[flat_idx];
- T dx_val;
- if ((float)y_val > 0.f)
- dx_val = dy_val;
- else
- dx_val = 0;
- dX[flat_idx] = dx_val;
- db_local += (float)dx_val;
- flat_idx += features;
- }
- }
- // naive block reduction on y-dim
- int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
- smem[linear_idx] = db_local;
- }
- __syncthreads();
- if (f < features) {
- if(threadIdx.y == 0) {
- for(int yidx = 1; yidx < blockDim.y; yidx++){
- db_local += smem[yidx * blockDim.x + threadIdx.x];
- }
- // block result is in db_local now for all threadIdx.y == 0
- // Write out partial result
- out[f] = db_local;
- }
- }
- __threadfence();
- __syncthreads();
- // Increment semaphore and check if this is the last CTA in the grid_y dimension.
- // Only thread (0,0) calls this
- if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
- unsigned int sum_idx;
- sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
- isLastBlock = (sum_idx == (gridDim.y - 1));
- }
- __syncthreads();
- db_local = 0;
- // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction.
- if (isLastBlock && f < features) {
- if(threadIdx.y == 0) {
- for (int n = 0; n < gridDim.y; n++) {
- int row, col;
- row = f;
- col = n;
- db_local += (float)(intermediate[col * features + row]);
- }
- db[f] = (T)db_local;
- }
- }
- }
- // Addition done deterministically via a 2-pass approach. Each CTA writes out partial
- // sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
- template <typename T, int UNROLL_FACTOR>
- __global__ void biasAddRelu_bprop_aligned(
- T* Y,
- T* dY,
- int features,
- int batch_size,
- T* dX,
- volatile float* intermediate,
- int* semaphores,
- T* db) {
- // The feature that this thread is responsible for
- int f = blockIdx.x * blockDim.x + threadIdx.x;
- // Compute the span this thread is responsible for
- // For this block
- int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
- int b_nStart = blockIdx.y * b_chunkSize;
- int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
- // For this thread
- int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
- int nStart = threadIdx.y * chunkSize + b_nStart;
- int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
- volatile float* out = intermediate + blockIdx.y * features;
- // Flag to trigger last reduction.
- __shared__ bool isLastBlock;
- // Accumulate db in FP32 always
- float db_local[ILP];
- T r_y[ILP];
- T r_dy[ILP];
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- db_local[ii] = 0.f;
- }
- // f is always in range here (the launch requires features % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0), so no bounds check:
- //if (f < features) {
- int nidx = 0;
- // Handle non-multiple of UNROLL_FACTOR residue
- for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
- int row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features / ILP + row;
- load_store(r_y, Y, 0, flat_idx);
- load_store(r_dy, dY, 0, flat_idx);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- if ((float)r_y[ii] <= 0.f)
- r_dy[ii] = 0;
- db_local[ii] += (float)r_dy[ii];
- }
- load_store(dX, r_dy, flat_idx, 0);
- }
- // Handle meat of work
- for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
- int row, col, flat_idx;
- row = f;
- col = nStart + nidx;
- flat_idx = col * features / ILP + row; // total threads in x == features/ILP
- #pragma unroll
- for (int u = 0; u < UNROLL_FACTOR; u++) {
- load_store(r_y, Y, 0, flat_idx);
- load_store(r_dy, dY, 0, flat_idx);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- if ((float)r_y[ii] <= 0.f)
- r_dy[ii] = 0;
- db_local[ii] += (float)r_dy[ii];
- }
- load_store(dX, r_dy, flat_idx, 0);
- flat_idx += features/ILP;
- }
- }
- // Block dimensions are fixed, so the shared-memory size is known at compile time.
- __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
- // naive block reduction on y-dim
- int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
- float* smem_out = smem + ILP * linear_idx;
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- smem_out[ii] = db_local[ii]; // stage this thread's partial sums in shared memory
- }
- __syncthreads();
- if(threadIdx.y == 0) {
- for(int yidx = 1; yidx < blockDim.y; yidx++){
- float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- db_local[ii] += smem_in[ii]; // accumulate the other rows' partial sums
- }
- }
- // block result is in db_local now for all threadIdx.y == 0
- if(gridDim.y == 1) {
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- r_dy[ii] = db_local[ii]; // reuse local dy buffer
- }
- load_store(db, r_dy, f, 0);
- return;
- }
- // Write out partial result
- load_store(out, db_local, f, 0);
- }
- __threadfence();
- __syncthreads();
- // Increment semaphore and check if this is the last CTA in the grid_y dimension.
- // Only thread (0,0) calls this
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- unsigned int sum_idx;
- sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
- isLastBlock = (sum_idx == (gridDim.y - 1));
- }
- __syncthreads();
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- db_local[ii] = 0.f;
- }
- float r_db[ILP];
- // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction.
- if (isLastBlock) {
- if(threadIdx.y == 0){
- for (int n = 0; n < gridDim.y; n++) {
- int row, col;
- row = f;
- col = n;
- load_store(r_db, intermediate, 0, col * features / ILP + row);
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- db_local[ii] += r_db[ii];
- }
- }
- #pragma unroll
- for(int ii=0;ii<ILP;ii++){
- r_dy[ii] = db_local[ii]; // reuse local dy buffer
- }
- load_store(db, r_dy, f, 0);
- }
- }
- }
- // Lists where each layer's intermediate Y buffer starts in the fprop reserved space, beginning at
- // offset 0. The last Y value is, of course, stored in the user provided output buffer.
- void get_y_offsets(
- int batch_size,
- int num_layers,
- const int* output_features,
- int* y_start_offsets) {
- y_start_offsets[0] = 0;
- for (int i = 1; i < num_layers; i++) {
- y_start_offsets[i] = y_start_offsets[i - 1] + batch_size * output_features[i - 1];
- }
- }
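- // Illustrative example (assumed shapes): with batch_size = 64 and output_features =
- // {128, 64, 10}, the offsets are {0, 64*128, 64*128 + 64*64} = {0, 8192, 12288}.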
- // Returns the reserved space (in elements) needed for the MLP
- size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
- size_t res_space = 0;
- // Need to store the output of every MLP layer - size equal to output_features[i] * batch_size
- // for all 'i' in [0, num_layers). The last layer's output also lives in the user-provided
- // buffer, so this slightly over-allocates.
- for (int l = 0; l < num_layers; l++) {
- res_space += output_features[l] * batch_size;
- }
- return res_space;
- }
- // Returns the size of all fprop activations combined
- size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
- size_t acts_size = 0;
- for (int l = 0; l < num_layers; l++) {
- acts_size += output_features[l] * batch_size;
- }
- return acts_size;
- }
- #if 0
- // Returns the work space (in elements) needed for the MLP bprop.
- size_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {
- /*
- Workspace is partitioned as
- DY_GEMMs : DX_GEMMs
- */
- size_t work_space = 0;
- // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
- // of biasReLU_bp and one for o/p of dgrad GEMM).
- work_space += 2*get_all_activations_size(batch_size, num_layers, output_features);
- return work_space;
- }
- #endif
- // Scratch space needed for reductions in number of elements
- size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {
- size_t max_scratch_space = 0;
- // Loop over all layers to see which one needs the max scratch space
- for (int l = 0; l < num_layers; l++) {
- // need to find max(aligned, not_aligned)
- int tmp, res0, res1;
- int block_x = BIAS_RELU_BW_NTHREADS_X;
- int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
- get_biasAddRelu_bprop_grid_size(
- output_features[l], batch_size, block_x, block_y, &tmp, &res0);
- block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
- get_biasAddRelu_bprop_grid_size(
- output_features[l], batch_size, block_x, block_y, &tmp, &res1);
- max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));
- max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));
- }
- return max_scratch_space;
- }
- // Buffer for semaphores
- size_t get_semaphores_size(int num_layers, const int* output_features) {
- // Upper bound on semaphores is one per feature for the layer
- // with the most features.
- int max_features = 0;
- for (int l = 0; l < num_layers; l++) {
- max_features = std::max(max_features, output_features[l]);
- }
- return (size_t)max_features;
- }
- // Returns the work space (in elements) needed for the MLP bprop.
- template <typename T>
- size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {
- size_t work_space = 0;
- // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
- // of biasReLU_bp and one for o/p of dgrad GEMM).
- work_space += 2 * get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T);
- work_space +=
- get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float);
- work_space += get_semaphores_size(num_layers, output_features) * sizeof(int);
- return work_space;
- }
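- // Illustrative breakdown (same assumed shapes as above, T = float): the dY/dX GEMM buffers
- // take 2 * (8192 + 4096 + 640) * 4 bytes; the db scratch size also depends on the device's
- // SM count via the grid-size heuristic above, and the semaphore buffer is one int per
- // feature of the widest layer (here 128).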
- // Returns pointers to each segment of the workspace
- template <typename T>
- void partition_mlp_bp_workspace(
- int batch_size,
- int num_layers,
- const int* output_features,
- void* work_space,
- T** dy_gemms,
- T** dx_gemms,
- float** db_scratch,
- int** semaphores) {
- /*
- Workspace is partitioned as
- DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES
- */
- // Start address where dy_gemm tensors are stored
- *dy_gemms = reinterpret_cast<T*>(work_space);
- // Start address where dx_gemm tensors are stored
- *dx_gemms = *dy_gemms + get_all_activations_size(batch_size, num_layers, output_features);
- // Start address where db intermediate tensors are stored
- *db_scratch = reinterpret_cast<float*>(
- *dx_gemms + get_all_activations_size(batch_size, num_layers, output_features));
- // Start address of semaphores
- *semaphores = reinterpret_cast<int*>(
- *db_scratch + get_reduction_scratch_space(batch_size, num_layers, output_features));
- return;
- }
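- // Typical call pattern (sketch): allocate get_mlp_bp_workspace_in_bytes<T>(...) bytes once, then
- //   T *dy_gemm_base, *dx_gemm_base; float* db_scratch; int* semaphores;
- //   partition_mlp_bp_workspace<T>(batch_size, num_layers, output_features,
- //       work_space, &dy_gemm_base, &dx_gemm_base, &db_scratch, &semaphores);
- // which is how mlp_bp below carves up its caller-provided work_space.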
- // Does a simple MLP fprop (GEMM+bias+ReLU).
- // Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed
- // to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and
- // must be in the same order i.e. WPtr[i] and BPtr[i] are respectively the weight and bias of layer
- // 'i'.
- template <typename T>
- int mlp_fp(
- T* X,
- int input_features,
- int batch_size,
- T** WPtr,
- int num_layers,
- int* output_features,
- T** BPtr,
- T* Y,
- T* reserved_space,
- int use_bias,
- int activation,
- void* lt_workspace) {
- T *weight, *input, *output, *bias = NULL;
- T *reserved_space_x, *reserved_space_y;
- reserved_space_x = NULL;
- reserved_space_y = reserved_space;
- // Get cublas handle from Pytorch
- cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
- // Get the stream from cublas handle to reuse for biasReLU kernel.
- cudaStream_t stream;
- cublasGetStream(handle, &stream);
- for (int layer = 0; layer < num_layers; layer++) {
- weight = WPtr[layer];
- input = (layer == 0) ? X : reserved_space_x;
- output = (layer == num_layers - 1) ? Y : reserved_space_y;
- if (use_bias) {
- bias = BPtr[layer];
- }
- int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
- int ofeat = output_features[layer];
- float one = 1.f;
- float zero = 0.f;
- // Try the cublasLt path first for supported cases with a valid handle.
- int cublaslt_status = 1;
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
- if(activation < 1){
- cublaslt_status = mlp_gemm_lt(
- //ltHandle,
- (cublasLtHandle_t)handle,
- CUBLAS_OP_T,
- CUBLAS_OP_N,
- ofeat,
- batch_size,
- ifeat,
- &one,
- weight,
- ifeat,
- input,
- ifeat,
- &zero,
- output,
- ofeat,
- lt_workspace,
- 1 << 22,
- stream,
- use_bias == 1,
- activation == 1,
- bias);
- }
- #endif
- // If cublasLt failed or was not attempted, fall back to cublas.
- if (cublaslt_status != 0) {
- cublasStatus_t cublas_status;
- // Call GEMM: fprop is Y = W'X
- cublas_status = mlp_gemm(
- handle,
- CUBLAS_OP_T,
- CUBLAS_OP_N,
- ofeat,
- batch_size,
- ifeat,
- &one,
- weight,
- ifeat,
- input,
- ifeat,
- &zero,
- output,
- ofeat);
- if (cublas_status != CUBLAS_STATUS_SUCCESS) {
- printf("GEMM fprop failed with %d\n", cublas_status);
- return 1;
- }
- const uint &input_size = ofeat;
- int num_blocks = 0;
- int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
- // Call biasReLU
- if(use_bias == 1) {
- if (activation == 0) { // no activation
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
- } else if (activation == 1) { // relu
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
- } else if (activation == 2) { // sigmoid
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
- }
- } else {
- // don't need to do anything in case of no activation and no bias
- if (activation == 1) { // relu
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
- } else if (activation == 2) { // sigmoid
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
- }
- }
- }
- // Set current output as next layer input
- reserved_space_x = reserved_space_y;
- // Set next layer output
- reserved_space_y += ofeat * batch_size;
- }
- return 0;
- }
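- // Illustrative host-side call (the surrounding setup is assumed, not defined here): with
- // per-layer weight/bias pointers gathered into WPtr/BPtr and a reserved-space buffer of
- // get_mlp_reserved_space(batch, num_layers, out_feats) elements, a caller could invoke
- //   mlp_fp<float>(X_ptr, in_feat, batch, WPtr, num_layers, out_feats, BPtr,
- //                 Y_ptr, reserved_ptr, /*use_bias=*/1, /*activation=*/1, lt_ws_ptr);
- // where lt_ws_ptr points at scratch of at least 1 << 22 bytes for the cublasLt path.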
- // Does a simple MLP bprop (GEMM+bias+ReLU).
- // Needs reserved space to come back exactly as it was populated in fprop.
- // Does dgrad and wgrad sequentially.
- template <typename T>
- int mlp_bp(
- T* X,
- T* Y,
- int input_features,
- int batch_size,
- T** WPtr,
- int num_layers,
- int* output_features,
- T* dY,
- T* reserved_space,
- T* work_space,
- T* dX,
- T** dwPtr,
- T** dbPtr,
- bool requires_grad,
- int use_bias,
- int activation) {
- T* weight;
- T *dweight, *dx, *dy, *dbias;
- T *x, *y;
- // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away
- // after bp call.
- T* dy_gemm_base;
- // Where the dx after GEMM is stored.
- T* dx_gemm_base;
- // Where partial reduction results are stored.
- float* db_scratch;
- // Semaphores for reduction.
- int* semaphores;
- partition_mlp_bp_workspace<T>(
- batch_size,
- num_layers,
- output_features,
- work_space,
- &dy_gemm_base,
- &dx_gemm_base,
- &db_scratch,
- &semaphores);
- size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);
- // Get cublas handle from Pytorch
- cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
- // Get the stream from cublas handle to reuse for biasReLU kernel.
- cudaStream_t stream;
- cublasGetStream(handle, &stream);
- int* y_offsets = (int*)malloc(num_layers * sizeof(int));
- get_y_offsets(batch_size, num_layers, output_features, y_offsets);
- for (int layer = num_layers - 1; layer >= 0; layer--) {
- weight = WPtr[layer];
- dweight = dwPtr[layer];
- // x is read from reserved space
- x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1];
- // dx is written in workspace for all but layer==0
- dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1];
- // y is read from reserved space
- y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer];
- // dx from layer+1
- dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer];
- // dy_gemm is written to and read immediately
- T* dy_gemm = dy_gemm_base + y_offsets[layer];
- dbias = dbPtr[layer];
- int xfeat = (layer == 0) ? input_features : output_features[layer - 1];
- int yfeat = output_features[layer];
- float one = 1.f;
- float zero = 0.f;
- if (use_bias == 1) {
- if (activation == 0) { // no activation
- // bgrad
- dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
- int grid_x, grid_y;
- cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
- int block_x = BIAS_RELU_BW_NTHREADS_X;
- int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
- get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
- dim3 grid(grid_x, grid_y);
- biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
- dy, yfeat, batch_size, db_scratch, semaphores, dbias);
- // No activation backward needed; point dy_gemm at dy so the dgrad GEMM reads it directly.
- dy_gemm = dy;
- } else if (activation == 1) { // relu
- dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
- int grid_x, grid_y;
- cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
- if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
- is_aligned(y) &&
- is_aligned(dy) &&
- is_aligned(dy_gemm) &&
- is_aligned(dbias)){
- int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
- int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
- get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
- dim3 grid(grid_x, grid_y);
- biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
- y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
- } else {
- int block_x = BIAS_RELU_BW_NTHREADS_X;
- int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
- get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
- dim3 grid(grid_x, grid_y);
- biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
- y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
- }
- } else if (activation == 2) { // sigmoid
- // activation backward
- int num_blocks = 0;
- int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
- // bgrad, from dy_gemm
- dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
- int grid_x, grid_y;
- cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
- int block_x = BIAS_RELU_BW_NTHREADS_X;
- int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
- get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
- dim3 grid(grid_x, grid_y);
- biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
- dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
- }
- } else { // no bias below
- if (activation == 0) {
- // No activation backward needed; point dy_gemm at dy so the dgrad GEMM reads it directly.
- dy_gemm = dy;
- } else if (activation == 1) { // relu
- int num_blocks = 0;
- int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
- } else if (activation == 2) { // sigmoid
- int num_blocks = 0;
- int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
- Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
- }
- }
- cublasStatus_t cublas_status;
- // Call GEMM dgrad: dX (xfeat x batch) = W (xfeat x yfeat) * dy_gemm (yfeat x batch)
- if (layer > 0 || requires_grad == 1) {
- cublas_status = mlp_gemm(
- handle,
- CUBLAS_OP_N,
- CUBLAS_OP_N,
- xfeat,
- batch_size,
- yfeat,
- &one,
- weight,
- xfeat,
- dy_gemm,
- yfeat,
- &zero,
- dx,
- xfeat);
- if (cublas_status != CUBLAS_STATUS_SUCCESS) {
- printf("GEMM dgrad failed with %d\n", cublas_status);
- free(y_offsets);
- return 1;
- }
- }
- // Call GEMM wgrad: dW (xfeat x yfeat) = X (xfeat x batch) * dy_gemm^T (batch x yfeat)
- cublas_status = mlp_gemm(
- handle,
- CUBLAS_OP_N,
- CUBLAS_OP_T,
- xfeat,
- yfeat,
- batch_size,
- &one,
- x,
- xfeat,
- dy_gemm,
- yfeat,
- &zero,
- dweight,
- xfeat);
- if (cublas_status != CUBLAS_STATUS_SUCCESS) {
- printf("GEMM wgrad failed with %d\n", cublas_status);
- free(y_offsets);
- return 1;
- }
- }
- free(y_offsets);
- return 0;
- }
- // Instantiate for floating point types
- template int mlp_fp<float>(
- float* X,
- int input_features,
- int batch_size,
- float** WPtr,
- int num_layers,
- int* output_features,
- float** BPtr,
- float* Y,
- float* reserved_space,
- int use_bias,
- int activation,
- void* lt_workspace);
- template int mlp_bp<float>(
- float* X,
- float* Y,
- int input_features,
- int batch_size,
- float** WPtr,
- int num_layers,
- int* output_features,
- float* dY,
- float* reserved_space,
- float* work_space,
- float* dX,
- float** dwPtr,
- float** dbPtr,
- bool requires_grad,
- int use_bias,
- int activation);
- template int mlp_fp<at::Half>(
- at::Half* X,
- int input_features,
- int batch_size,
- at::Half** WPtr,
- int num_layers,
- int* output_features,
- at::Half** BPtr,
- at::Half* Y,
- at::Half* reserved_space,
- int use_bias,
- int activation,
- void* lt_workspace);
- template int mlp_bp<at::Half>(
- at::Half* X,
- at::Half* Y,
- int input_features,
- int batch_size,
- at::Half** WPtr,
- int num_layers,
- int* output_features,
- at::Half* dY,
- at::Half* reserved_space,
- at::Half* work_space,
- at::Half* dX,
- at::Half** dwPtr,
- at::Half** dbPtr,
- bool requires_grad,
- int use_bias,
- int activation);
- template int mlp_fp<double>(
- double* X,
- int input_features,
- int batch_size,
- double** WPtr,
- int num_layers,
- int* output_features,
- double** BPtr,
- double* Y,
- double* reserved_space,
- int use_bias,
- int activation,
- void* lt_workspace);
- template int mlp_bp<double>(
- double* X,
- double* Y,
- int input_features,
- int batch_size,
- double** WPtr,
- int num_layers,
- int* output_features,
- double* dY,
- double* reserved_space,
- double* work_space,
- double* dX,
- double** dwPtr,
- double** dbPtr,
- bool requires_grad,
- int use_bias,
- int activation);
- template size_t get_mlp_bp_workspace_in_bytes<float>(
- int batch_size,
- int num_layers,
- const int* output_features);
- template size_t get_mlp_bp_workspace_in_bytes<at::Half>(
- int batch_size,
- int num_layers,
- const int* output_features);
- template size_t get_mlp_bp_workspace_in_bytes<double>(
- int batch_size,
- int num_layers,
- const int* output_features);