mlp_cuda.cu 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678
  1. #include <ATen/ATen.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <assert.h>
  4. #include <stdio.h>
  5. #include <stdlib.h>
  6. #include <string.h>
  7. #include <torch/torch.h>
  8. /* Includes, cuda */
  9. #include <cublas_v2.h>
  10. #include <cuda_runtime.h>
  11. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  12. // includes cublaslt
  13. #include <cublasLt.h>
  14. #endif
// Constants for the fused bias+relu kernels.
#define BIAS_RELU_FW_NTHREADS 128 // forward: number of threads per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward: number of threads in the feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward: number of threads in the batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward: minimal reduction length per thread
// NOTE(review): biasAdd_bprop sizes its static shared scratch array as
// BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y floats, so backward launches
// presumably must not use a larger block than this — confirm at call sites.
// move to a header later on
// ILP: number of consecutive elements of T that is_aligned()/load_store()
// treat as one aligned unit (ILP * sizeof(T) bytes).
#define ILP 4
  22. template<typename T>
  23. __host__ __device__ __forceinline__ bool is_aligned(T* p){
  24. return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
  25. }
// Copy one group of ILP consecutive T values from src to dst by reinterpreting
// both pointers as an aligned aggregate of ILP*sizeof(T) bytes (alignment
// ILP*alignof(T)), intended to let the compiler use a single wide access.
// dst_offset / src_offset count ILP-element groups, not individual T elements.
// Precondition (not checked here): both addresses are aligned to the aggregate,
// i.e. is_aligned() holds for them; callers in this file test is_aligned first.
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
// Overload for a volatile source pointer.
// NOTE(review): the C-style cast discards the volatile qualifier, so the read
// is not performed through a volatile lvalue.
template<typename T>
__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
// Overload for a volatile destination pointer.
// NOTE(review): as above, the cast discards volatile, so the write is not
// performed through a volatile lvalue.
template<typename T>
__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
  41. // Keep ReLU in float only. When using half, cast to float before calling.
  42. __device__ __inline__ float relu(float a) {
  43. float retf = max(a, 0.f);
  44. return (retf);
  45. }
  46. // Keep Sigmoid in float only. When using half, cast to float before calling.
  47. __device__ __inline__ float sigmoid(float a) {
  48. float retf = 1.f / (1.f + expf(-a));
  49. return (retf);
  50. }
  51. // FP64 Wrapper around cublas GEMMEx
  52. cublasStatus_t mlp_gemm(
  53. cublasHandle_t handle,
  54. cublasOperation_t transa,
  55. cublasOperation_t transb,
  56. int m,
  57. int n,
  58. int k,
  59. float* alpha,
  60. const double* A,
  61. int lda,
  62. const double* B,
  63. int ldb,
  64. const float* beta,
  65. double* C,
  66. int ldc) {
  67. return cublasGemmEx(
  68. handle,
  69. transa,
  70. transb,
  71. m,
  72. n,
  73. k,
  74. alpha,
  75. A,
  76. CUDA_R_64F,
  77. lda,
  78. B,
  79. CUDA_R_64F,
  80. ldb,
  81. beta,
  82. C,
  83. CUDA_R_64F,
  84. ldc,
  85. CUDA_R_64F,
  86. CUBLAS_GEMM_DEFAULT);
  87. }
  88. // FP32 Wrapper around cublas GEMMEx
  89. cublasStatus_t mlp_gemm(
  90. cublasHandle_t handle,
  91. cublasOperation_t transa,
  92. cublasOperation_t transb,
  93. int m,
  94. int n,
  95. int k,
  96. float* alpha,
  97. const float* A,
  98. int lda,
  99. const float* B,
  100. int ldb,
  101. const float* beta,
  102. float* C,
  103. int ldc) {
  104. return cublasGemmEx(
  105. handle,
  106. transa,
  107. transb,
  108. m,
  109. n,
  110. k,
  111. alpha,
  112. A,
  113. CUDA_R_32F,
  114. lda,
  115. B,
  116. CUDA_R_32F,
  117. ldb,
  118. beta,
  119. C,
  120. CUDA_R_32F,
  121. ldc,
  122. CUDA_R_32F,
  123. CUBLAS_GEMM_DEFAULT);
  124. }
  125. // FP16 Tensor core wrapper around cublas GEMMEx
  126. cublasStatus_t mlp_gemm(
  127. cublasHandle_t handle,
  128. cublasOperation_t transa,
  129. cublasOperation_t transb,
  130. int m,
  131. int n,
  132. int k,
  133. float* alpha,
  134. const at::Half* A,
  135. int lda,
  136. const at::Half* B,
  137. int ldb,
  138. float* beta,
  139. at::Half* C,
  140. int ldc) {
  141. return cublasGemmEx(
  142. handle,
  143. transa,
  144. transb,
  145. m,
  146. n,
  147. k,
  148. alpha,
  149. A,
  150. CUDA_R_16F,
  151. lda,
  152. B,
  153. CUDA_R_16F,
  154. ldb,
  155. beta,
  156. C,
  157. CUDA_R_16F,
  158. ldc,
  159. CUDA_R_32F,
  160. CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  161. }
  162. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  163. int mlp_gemm_lt(
  164. cublasLtHandle_t ltHandle,
  165. cublasOperation_t transa,
  166. cublasOperation_t transb,
  167. int m,
  168. int n,
  169. int k,
  170. float *alpha, /* host pointer */
  171. const at::Half* A,
  172. int lda,
  173. const at::Half* B,
  174. int ldb,
  175. float *beta, /* host pointer */
  176. at::Half* C,
  177. int ldc,
  178. void *workspace,
  179. size_t workspaceSize,
  180. cudaStream_t stream,
  181. bool use_bias,
  182. bool use_relu,
  183. const void* bias) {
  184. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  185. cublasLtMatmulDescOpaque_t operationDesc = {};
  186. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  187. cublasLtMatmulPreferenceOpaque_t preference = {};
  188. int returnedResults = 0;
  189. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  190. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  191. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  192. // for details about defaults; here we just set the transforms for
  193. // A and B.
  194. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  195. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  196. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  197. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  198. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  199. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  200. if (use_bias) {
  201. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  202. if (status != CUBLAS_STATUS_SUCCESS) {
  203. goto CLEANUP;
  204. }
  205. if (use_relu) {
  206. epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  207. } else {
  208. epilogue = CUBLASLT_EPILOGUE_BIAS;
  209. }
  210. } else {
  211. if (use_relu) {
  212. epilogue = CUBLASLT_EPILOGUE_RELU;
  213. }
  214. }
  215. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  216. if (status != CUBLAS_STATUS_SUCCESS) {
  217. goto CLEANUP;
  218. }
  219. // Create matrix descriptors. Not setting any extra attributes.
  220. status = cublasLtMatrixLayoutInit(
  221. &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  222. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  223. status = cublasLtMatrixLayoutInit(
  224. &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  225. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  226. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  227. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  228. // Create preference handle; In general, extra attributes can be
  229. // used here to disable tensor ops or to make sure algo selected
  230. // will work with badly aligned A, B, C. However, for simplicity
  231. // here we assume A,B,C are always well aligned (e.g., directly
  232. // come from cudaMalloc)
  233. status = cublasLtMatmulPreferenceInit(&preference);
  234. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  235. status = cublasLtMatmulPreferenceSetAttribute(
  236. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  237. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  238. // We just need the best available heuristic to try and run matmul.
  239. // There is no guarantee that this will work. For example, if A is
  240. // badly aligned, you can request more (e.g. 32) algos and try to
  241. // run them one by one until something works.
  242. status = cublasLtMatmulAlgoGetHeuristic(
  243. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  244. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  245. if (returnedResults == 0) {
  246. status = CUBLAS_STATUS_NOT_SUPPORTED;
  247. goto CLEANUP;
  248. }
  249. status = cublasLtMatmul(ltHandle,
  250. &operationDesc,
  251. alpha,
  252. A,
  253. &Adesc,
  254. B,
  255. &Bdesc,
  256. beta,
  257. C,
  258. &Cdesc,
  259. C,
  260. &Cdesc,
  261. &heuristicResult.algo,
  262. workspace,
  263. workspaceSize,
  264. stream);
  265. CLEANUP:
  266. // Descriptors are no longer needed as all GPU work was already
  267. // enqueued.
  268. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  269. }
  270. int mlp_gemm_lt(
  271. cublasLtHandle_t ltHandle,
  272. cublasOperation_t transa,
  273. cublasOperation_t transb,
  274. int m,
  275. int n,
  276. int k,
  277. float *alpha, /* host pointer */
  278. const double* A,
  279. int lda,
  280. const double* B,
  281. int ldb,
  282. float *beta, /* host pointer */
  283. double* C,
  284. int ldc,
  285. void *workspace,
  286. size_t workspaceSize,
  287. cudaStream_t stream,
  288. bool use_bias,
  289. bool use_relu,
  290. const void* bias) {
  291. return 1;
  292. }
  293. int mlp_gemm_lt(
  294. cublasLtHandle_t ltHandle,
  295. cublasOperation_t transa,
  296. cublasOperation_t transb,
  297. int m,
  298. int n,
  299. int k,
  300. float *alpha, /* host pointer */
  301. const float *A,
  302. int lda,
  303. const float *B,
  304. int ldb,
  305. float *beta, /* host pointer */
  306. float *C,
  307. int ldc,
  308. void *workspace,
  309. size_t workspaceSize,
  310. cudaStream_t stream,
  311. bool use_bias,
  312. bool use_relu,
  313. const void* bias) {
  314. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  315. cublasLtMatmulDescOpaque_t operationDesc = {};
  316. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  317. cublasLtMatmulPreferenceOpaque_t preference = {};
  318. int returnedResults = 0;
  319. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  320. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  321. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  322. // for details about defaults; here we just set the transforms for
  323. // A and B.
  324. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  325. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  326. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  327. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  328. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  329. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  330. if (use_bias) {
  331. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  332. if (status != CUBLAS_STATUS_SUCCESS) {
  333. goto CLEANUP;
  334. }
  335. if (use_relu) {
  336. epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  337. } else {
  338. epilogue = CUBLASLT_EPILOGUE_BIAS;
  339. }
  340. } else {
  341. if (use_relu) {
  342. epilogue = CUBLASLT_EPILOGUE_RELU;
  343. }
  344. }
  345. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  346. if (status != CUBLAS_STATUS_SUCCESS) {
  347. goto CLEANUP;
  348. }
  349. // Create matrix descriptors. Not setting any extra attributes.
  350. status = cublasLtMatrixLayoutInit(
  351. &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  352. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  353. status = cublasLtMatrixLayoutInit(
  354. &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  355. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  356. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  357. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  358. // Create preference handle; In general, extra attributes can be
  359. // used here to disable tensor ops or to make sure algo selected
  360. // will work with badly aligned A, B, C. However, for simplicity
  361. // here we assume A,B,C are always well aligned (e.g., directly
  362. // come from cudaMalloc)
  363. status = cublasLtMatmulPreferenceInit(&preference);
  364. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  365. status = cublasLtMatmulPreferenceSetAttribute(
  366. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  367. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  368. // We just need the best available heuristic to try and run matmul.
  369. // There is no guarantee that this will work. For example, if A is
  370. // badly aligned, you can request more (e.g. 32) algos and try to
  371. // run them one by one until something works.
  372. status = cublasLtMatmulAlgoGetHeuristic(
  373. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  374. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  375. if (returnedResults == 0) {
  376. status = CUBLAS_STATUS_NOT_SUPPORTED;
  377. goto CLEANUP;
  378. }
  379. status = cublasLtMatmul(ltHandle,
  380. &operationDesc,
  381. alpha,
  382. A,
  383. &Adesc,
  384. B,
  385. &Bdesc,
  386. beta,
  387. C,
  388. &Cdesc,
  389. C,
  390. &Cdesc,
  391. &heuristicResult.algo,
  392. workspace,
  393. workspaceSize,
  394. stream);
  395. CLEANUP:
  396. // Descriptors are no longer needed as all GPU work was already
  397. // enqueued.
  398. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  399. }
  400. #endif
  401. // Bias ADD. Assume input X is [features x batch size], column major.
  402. // Bias is one 'features' long vector, with implicit broadcast.
  403. template <typename T>
  404. __global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
  405. T r_x[ILP];
  406. T r_b[ILP];
  407. if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
  408. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  409. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  410. int row = tid % (features / ILP);
  411. load_store(r_x, X, 0 , tid);
  412. load_store(r_b, b, 0 , row);
  413. #pragma unroll
  414. for(int ii = 0; ii < ILP; ii++) {
  415. float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
  416. r_x[ii] = bias_sum;
  417. }
  418. load_store(X, r_x, tid , 0);
  419. }
  420. } else {
  421. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  422. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  423. #pragma unroll
  424. for(int ii = 0; ii < ILP; ii++) {
  425. int idx = tid + ii * blockDim.x * gridDim.x;
  426. if(idx < features * batch_size) {
  427. int row = tid % features;
  428. r_x[ii] = X[idx];
  429. r_b[ii] = b[row];
  430. }
  431. }
  432. #pragma unroll
  433. for(int ii = 0; ii < ILP; ii++) {
  434. float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
  435. r_x[ii] = bias_sum;
  436. }
  437. #pragma unroll
  438. for(int ii = 0; ii < ILP; ii++) {
  439. int idx = tid + ii * blockDim.x * gridDim.x;
  440. if(idx < features * batch_size) {
  441. X[idx] = r_x[ii];
  442. }
  443. }
  444. }
  445. }
  446. }
  447. // Bias ADD + ReLU. Assume input X is [features x batch size], column major.
  448. // Activation support fuesed ReLU. Safe to call in-place.
  449. template <typename T>
  450. __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  451. T r_x[ILP];
  452. T r_b[ILP];
  453. if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
  454. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  455. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  456. int row = tid % (features / ILP);
  457. load_store(r_x, X, 0 , tid);
  458. load_store(r_b, b, 0 , row);
  459. #pragma unroll
  460. for(int ii = 0; ii < ILP; ii++) {
  461. float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
  462. r_x[ii] = relu(bias_sum);
  463. }
  464. load_store(X, r_x, tid , 0);
  465. }
  466. } else {
  467. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  468. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  469. #pragma unroll
  470. for(int ii = 0; ii < ILP; ii++) {
  471. int idx = tid + ii * blockDim.x * gridDim.x;
  472. if(idx < features * batch_size) {
  473. int row = tid % features;
  474. r_x[ii] = X[idx];
  475. r_b[ii] = b[row];
  476. }
  477. }
  478. #pragma unroll
  479. for(int ii = 0; ii < ILP; ii++) {
  480. float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
  481. r_x[ii] = relu(bias_sum);
  482. }
  483. #pragma unroll
  484. for(int ii = 0; ii < ILP; ii++) {
  485. int idx = tid + ii * blockDim.x * gridDim.x;
  486. if(idx < features * batch_size) {
  487. X[idx] = r_x[ii];
  488. }
  489. }
  490. }
  491. }
  492. }
  493. // ReLU. Assume input X is [features x batch size], column major.
  494. // Safe to call in-place.
  495. template <typename T>
  496. __global__ void Relu_fprop(T *X, uint batch_size, uint features) {
  497. T r_x[ILP];
  498. if(is_aligned(X) && features % ILP ==0) {
  499. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  500. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  501. load_store(r_x, X, 0 , tid);
  502. #pragma unroll
  503. for(int ii = 0; ii < ILP; ii++) {
  504. r_x[ii] = relu(static_cast<float>(r_x[ii]));
  505. }
  506. load_store(X, r_x, tid , 0);
  507. }
  508. } else {
  509. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  510. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  511. #pragma unroll
  512. for(int ii = 0; ii < ILP; ii++) {
  513. int idx = tid + ii * blockDim.x * gridDim.x;
  514. if(idx < features * batch_size) {
  515. r_x[ii] = X[idx];
  516. }
  517. }
  518. #pragma unroll
  519. for(int ii = 0; ii < ILP; ii++) {
  520. r_x[ii] = relu(static_cast<float>(r_x[ii]));
  521. }
  522. #pragma unroll
  523. for(int ii = 0; ii < ILP; ii++) {
  524. int idx = tid + ii * blockDim.x * gridDim.x;
  525. if(idx < features * batch_size) {
  526. X[idx] = r_x[ii];
  527. }
  528. }
  529. }
  530. }
  531. }
  532. // Sigmoid. Assume input X is [features x batch size], column major.
  533. // Safe to call in-place.
  534. template <typename T>
  535. __global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
  536. T r_x[ILP];
  537. if(is_aligned(X) && features % ILP ==0) {
  538. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  539. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  540. load_store(r_x, X, 0 , tid);
  541. #pragma unroll
  542. for(int ii = 0; ii < ILP; ii++) {
  543. r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
  544. }
  545. load_store(X, r_x, tid , 0);
  546. }
  547. } else {
  548. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  549. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  550. #pragma unroll
  551. for(int ii = 0; ii < ILP; ii++) {
  552. int idx = tid + ii * blockDim.x * gridDim.x;
  553. if(idx < features * batch_size) {
  554. r_x[ii] = X[idx];
  555. }
  556. }
  557. #pragma unroll
  558. for(int ii = 0; ii < ILP; ii++) {
  559. r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
  560. }
  561. #pragma unroll
  562. for(int ii = 0; ii < ILP; ii++) {
  563. int idx = tid + ii * blockDim.x * gridDim.x;
  564. if(idx < features * batch_size) {
  565. X[idx] = r_x[ii];
  566. }
  567. }
  568. }
  569. }
  570. }
  571. // ReLU. Assume input X is [features x batch size], column major.
  572. // Safe to call in-place.
  573. template <typename T>
  574. __global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  575. T r_dy[ILP];
  576. T r_y[ILP];
  577. if(is_aligned(dY) &&
  578. is_aligned(Y) &&
  579. is_aligned(dX) &&
  580. features % ILP ==0) {
  581. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  582. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  583. load_store(r_dy, dY, 0 , tid);
  584. load_store(r_y, Y, 0 , tid);
  585. #pragma unroll
  586. for(int ii=0;ii<ILP;ii++){
  587. if ((float)r_y[ii] <= 0.f)
  588. r_dy[ii] = 0;
  589. }
  590. load_store(dX, r_dy, tid, 0);
  591. }
  592. } else {
  593. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  594. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  595. #pragma unroll
  596. for(int ii = 0; ii < ILP; ii++) {
  597. int idx = tid + ii * blockDim.x * gridDim.x;
  598. if(idx < features * batch_size) {
  599. r_dy[ii] = dY[idx];
  600. r_y[ii] = Y[idx];
  601. }
  602. }
  603. #pragma unroll
  604. for(int ii = 0; ii < ILP; ii++) {
  605. if ((float)r_y[ii] <= 0.f)
  606. r_dy[ii] = 0;
  607. }
  608. #pragma unroll
  609. for(int ii = 0; ii < ILP; ii++) {
  610. int idx = tid + ii * blockDim.x * gridDim.x;
  611. if(idx < features * batch_size) {
  612. dX[idx] = r_dy[ii];
  613. }
  614. }
  615. }
  616. }
  617. }
  618. // Sigmoid. Assume input X is [features x batch size], column major.
  619. // Safe to call in-place.
  620. template <typename T>
  621. __global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  622. T r_dy[ILP];
  623. T r_y[ILP];
  624. if(is_aligned(dY) &&
  625. is_aligned(Y) &&
  626. is_aligned(dX) &&
  627. features % ILP ==0) {
  628. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  629. for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
  630. load_store(r_dy, dY, 0 , tid);
  631. load_store(r_y, Y, 0 , tid);
  632. #pragma unroll
  633. for(int ii=0;ii<ILP;ii++){
  634. float grad_out = r_dy[ii];
  635. float out = r_y[ii];
  636. float grad_i = out * ( 1.f - out) * grad_out;
  637. r_dy[ii] = grad_i;
  638. }
  639. load_store(dX, r_dy, tid, 0);
  640. }
  641. } else {
  642. int tid = blockIdx.x * blockDim.x + threadIdx.x;
  643. for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
  644. #pragma unroll
  645. for(int ii = 0; ii < ILP; ii++) {
  646. int idx = tid + ii * blockDim.x * gridDim.x;
  647. if(idx < features * batch_size) {
  648. r_dy[ii] = dY[idx];
  649. r_y[ii] = Y[idx];
  650. }
  651. }
  652. #pragma unroll
  653. for(int ii = 0; ii < ILP; ii++) {
  654. float grad_out = r_dy[ii];
  655. float out = r_y[ii];
  656. float grad_i = out * ( 1.f - out) * grad_out;
  657. r_dy[ii] = grad_i;
  658. }
  659. #pragma unroll
  660. for(int ii = 0; ii < ILP; ii++) {
  661. int idx = tid + ii * blockDim.x * gridDim.x;
  662. if(idx < features * batch_size) {
  663. dX[idx] = r_dy[ii];
  664. }
  665. }
  666. }
  667. }
  668. }
  669. // Compute grid size for pointwise backward kernel.
  670. // block_x/y is total elment being handled per block, not number of threads
  671. void get_biasAddRelu_bprop_grid_size(
  672. int yfeat,
  673. int batch_size,
  674. int block_x,
  675. int block_y,
  676. int* grid_x,
  677. int* grid_y) {
  678. *grid_x = (yfeat + block_x - 1) / block_x;
  679. // Get number of SMs for efficient reduction.
  680. int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  681. // can switch to occupancy calculation. use 4 below now for sm_70
  682. int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x);
  683. // block_y should be from minimal work per thread
  684. int nRedSplits = (batch_size + block_y - 1) / block_y;
  685. // increase number of elem per thread redcution to not launch more than enough
  686. // kernel adjust work, so here we just launch max block
  687. *grid_y = std::min(nRedSplits, max_blocks_y);
  688. return;
  689. }
// Bias gradient: db[f] = sum over the batch of dY[f, n]. dY is laid out as
// [features x batch_size], column major (element (f, n) at n * features + f).
//
// The addition is deterministic, using a 2-pass approach. Each CTA writes out
// a partial sum, and the last CTA in the grid Y dimension accumulates the
// partials serially and writes the result.
//
// Launch layout (as read from the indexing below):
//   x: one thread per feature, f = blockIdx.x * blockDim.x + threadIdx.x.
//   y: the batch is split into gridDim.y chunks, one per block row. Inside a
//      block, each chunk is split again across the blockDim.y threads.
// Requirements visible from the code:
//   - blockDim.x * blockDim.y <= BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y,
//     the size of the static shared scratch array `smem`.
//   - intermediate holds at least gridDim.y * features floats. Row blockIdx.y
//     receives this block row's partial sums.
//   - semaphores has one counter per blockIdx.x. The last-block test expects
//     each counter to start at 0, and the kernel increments but never resets
//     them. NOTE(review): presumably the caller zeroes them before every
//     launch; confirm.
// Accumulation is always in FP32; only the final value is cast to T.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
    T* dY,
    int features,
    int batch_size,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;
  // Compute the span this thread is responsible for
  // For this block: batch rows [b_nStart, b_nStart + b_nSpan)
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread: batch rows [nStart, nStart + nSpan). nSpan can be <= 0
  // for trailing threads, in which case both loops below do nothing.
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
  // This block row's slot in the partial-sum buffer.
  volatile float* out = intermediate + blockIdx.y * features;
  // Flag to trigger the final (cross-block) reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle the residue that is not a multiple of UNROLL_FACTOR
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      db_local += (float)dY[flat_idx];
    }
    // Handle the bulk of the work, UNROLL_FACTOR rows at a time
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        db_local += (float)dY[flat_idx];
        // Next batch row of the same feature (column-major stride).
        flat_idx += features;
      }
    }
    // naive block reduction on y-dim: stage per-thread sums in shared memory
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      // Row threadIdx.y == 0 serially adds the other rows' sums for its column.
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }
      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  // Make this block's partial sums visible device-wide before signalling.
  __threadfence();
  __syncthreads();
  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  // NOTE(review): if f >= features for thread (0,0), isLastBlock is never
  // written. It is only read together with `f < features`, which is then false
  // for the whole block.
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();
  db_local = 0;
  // No block reduction for now, only thread (*,0) do grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      // Sum the gridDim.y partials in a fixed order, which keeps the result deterministic.
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}
  // Fused ReLU backward + bias gradient.
  //
  // Per element: dX = (Y > 0) ? dY : 0. Per feature: db[f] = sum over the batch of dX.
  // Element (n, f) is addressed as n * features + f (feature index contiguous).
  //
  // Launch layout (as used by mlp_bp): blockDim = (BIAS_RELU_BW_NTHREADS_X,
  // BIAS_RELU_BW_NTHREADS_Y); blockIdx.x/threadIdx.x select the feature, and
  // blockIdx.y/threadIdx.y split the batch. The static shared buffer assumes that block size.
  //
  // Addition is done deterministically via a 2-pass approach. Each CTA writes its partial
  // sum into row blockIdx.y of `intermediate` (gridDim.y x features floats). The last CTA
  // in the grid-Y dimension, detected through semaphores[blockIdx.x] (which must be zero
  // before launch), then accumulates the partials serially and writes db.
  //
  // All element indices are 64-bit so batch_size * features may exceed 2^31.
  template <typename T, int UNROLL_FACTOR>
  __global__ void biasAddRelu_bprop(
      T* Y,
      T* dY,
      int features,
      int batch_size,
      T* dX,
      volatile float* intermediate,
      int* semaphores,
      T* db) {
    // The feature that this thread is responsible for
    int f = blockIdx.x * blockDim.x + threadIdx.x;
    // Batch span owned by this block
    int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
    int b_nStart = blockIdx.y * b_chunkSize;
    int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
    // Batch span owned by this thread inside the block
    int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
    int nStart = threadIdx.y * chunkSize + b_nStart;
    int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
    // This block's row of partial sums.
    volatile float* out = intermediate + (int64_t)blockIdx.y * features;
    // Flag to trigger last reduction.
    __shared__ bool isLastBlock;
    // we know block size for now
    __shared__ float smem[BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y];
    // Accumulate db in FP32 always
    float db_local = 0.f;
    if (f < features) {
      int nidx = 0;
      // Handle non-multiple of UNROLL_FACTOR residue
      for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
        int64_t flat_idx = (int64_t)(nStart + nidx) * features + f;
        T y_val = Y[flat_idx];
        T dy_val = dY[flat_idx];
        T dx_val;
        if ((float)y_val > 0.f)
          dx_val = dy_val;
        else
          dx_val = 0;
        dX[flat_idx] = dx_val;
        db_local += (float)dx_val;
      }
      // Handle meat of work
      for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
        int64_t flat_idx = (int64_t)(nStart + nidx) * features + f;
  #pragma unroll 4
        for (int u = 0; u < UNROLL_FACTOR; u++) {
          T y_val = Y[flat_idx];
          T dy_val = dY[flat_idx];
          T dx_val;
          if ((float)y_val > 0.f)
            dx_val = dy_val;
          else
            dx_val = 0;
          dX[flat_idx] = dx_val;
          db_local += (float)dx_val;
          flat_idx += features;
        }
      }
      // naive block reduction on y-dim: publish this thread's partial
      smem[threadIdx.y * blockDim.x + threadIdx.x] = db_local;
    }
    __syncthreads();
    if (f < features && threadIdx.y == 0) {
      for (int yidx = 1; yidx < blockDim.y; yidx++) {
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }
      // Block result for feature f; write out the partial result.
      out[f] = db_local;
    }
    // Make the partial visible device-wide before signalling via the semaphore.
    __threadfence();
    __syncthreads();
    // Increment semaphore and check if this is the last CTA in the grid_y dimension.
    // Only thread (0,0) calls this
    if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
      unsigned int sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
      isLastBlock = (sum_idx == (gridDim.y - 1));
    }
    __syncthreads();
    // No block reduction for now, only thread (*,0) do grid reduction
    if (isLastBlock && f < features && threadIdx.y == 0) {
      float db_sum = 0.f;
      for (int n = 0; n < gridDim.y; n++) {
        db_sum += (float)(intermediate[(int64_t)n * features + f]);
      }
      db[f] = (T)db_sum;
    }
  }
  // ILP-vectorized variant of biasAddRelu_bprop (fused ReLU backward + bias gradient).
  //
  // Each thread moves ILP consecutive features at a time through load_store, so `f` below is
  // a *vector* index: this thread owns features [f * ILP, f * ILP + ILP).
  //
  // Preconditions (mlp_bp checks the first two before choosing this kernel):
  //   * features % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0. There is no f bounds check, which
  //     presumably relies on the x-grid covering exactly features / ILP threads.
  //   * Y, dY, dX and db pass is_aligned. `intermediate` is also accessed ILP floats at a
  //     time. NOTE(review): its alignment is not checked by the launcher.
  //   * blockDim == (BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y), to fit the static smem.
  //
  // Addition is done deterministically via a 2-pass approach. Each CTA writes out a partial
  // sum, and the last CTA in the grid Y dimension accumulates the partials serially and
  // writes the result. When gridDim.y == 1 the block sum is final and is written to db
  // directly, skipping the semaphore pass.
  template <typename T, int UNROLL_FACTOR>
  __global__ void biasAddRelu_bprop_aligned(
      T* Y,
      T* dY,
      int features,
      int batch_size,
      T* dX,
      volatile float* intermediate,
      int* semaphores,
      T* db) {
    // Vector (ILP-feature) index that this thread is responsible for
    int f = blockIdx.x * blockDim.x + threadIdx.x;
    // Row stride measured in ILP-wide vectors; exact because features is a multiple of ILP.
    // Indexing as row * vec_features avoids the 32-bit product row * features, which
    // overflows long before the vector index itself does.
    const int vec_features = features / ILP;
    // Batch span owned by this block
    int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
    int b_nStart = blockIdx.y * b_chunkSize;
    int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
    // Batch span owned by this thread inside the block
    int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
    int nStart = threadIdx.y * chunkSize + b_nStart;
    int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
    // This block's row of partial sums.
    volatile float* out = intermediate + (int64_t)blockIdx.y * features;
    // Flag to trigger last reduction.
    __shared__ bool isLastBlock;
    // we know block size for now
    __shared__ float smem[BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y * ILP];
    // Accumulate db in FP32 always
    float db_local[ILP];
    T r_y[ILP];
    T r_dy[ILP];
  #pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      db_local[ii] = 0.f;
    }
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int flat_idx = (nStart + nidx) * vec_features + f;
      load_store(r_y, Y, 0, flat_idx);
      load_store(r_dy, dY, 0, flat_idx);
  #pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
        db_local[ii] += (float)r_dy[ii];
      }
      load_store(dX, r_dy, flat_idx, 0);
    }
    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int flat_idx = (nStart + nidx) * vec_features + f;
  #pragma unroll
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        load_store(r_y, Y, 0, flat_idx);
        load_store(r_dy, dY, 0, flat_idx);
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          if ((float)r_y[ii] <= 0.f)
            r_dy[ii] = 0;
          db_local[ii] += (float)r_dy[ii];
        }
        load_store(dX, r_dy, flat_idx, 0);
        flat_idx += vec_features;
      }
    }
    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    float* smem_out = smem + ILP * linear_idx;
  #pragma unroll
    for (int ii = 0; ii < ILP; ii++) {
      smem_out[ii] = db_local[ii];
    }
    __syncthreads();
    if (threadIdx.y == 0) {
      for (int yidx = 1; yidx < blockDim.y; yidx++) {
        float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          db_local[ii] += smem_in[ii];
        }
      }
      // block result is in db_local now for all threadIdx.y == 0
      if (gridDim.y == 1) {
        // Single CTA along Y: the block sum is the final gradient.
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_dy[ii] = db_local[ii];  // reuse local dy buffer as T staging
        }
        load_store(db, r_dy, f, 0);
      } else {
        // Write out partial result
        load_store(out, db_local, f, 0);
      }
    }
    // gridDim.y is the same for every thread of the block, so the whole block leaves together
    // and no thread is left waiting at the barriers below.
    if (gridDim.y == 1) {
      return;
    }
    // Make the partial visible device-wide before signalling via the semaphore.
    __threadfence();
    __syncthreads();
    // Increment semaphore and check if this is the last CTA in the grid_y dimension.
    // Only thread (0,0) calls this
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      unsigned int sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
      isLastBlock = (sum_idx == (gridDim.y - 1));
    }
    __syncthreads();
    // No block reduction for now, only thread (*,0) do grid reduction
    if (isLastBlock && threadIdx.y == 0) {
      float r_db[ILP];
  #pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        db_local[ii] = 0.f;
      }
      for (int n = 0; n < gridDim.y; n++) {
        load_store(r_db, intermediate, 0, n * vec_features + f);
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          db_local[ii] += r_db[ii];
        }
      }
  #pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        r_dy[ii] = db_local[ii];  // reuse local dy buffer as T staging
      }
      load_store(db, r_dy, f, 0);
    }
  }
  // Fills y_start_offsets (num_layers entries, in elements) with the start of each layer's
  // output Y inside the fprop reserved space. Layers are packed back to back from offset 0:
  // entry i = entry (i - 1) + batch_size * output_features[i - 1]. Only the first
  // num_layers - 1 buffers actually live in reserved space; the last Y is the user's output.
  // Entry 0 is always written, even when num_layers < 1.
  void get_y_offsets(
      int batch_size,
      int num_layers,
      const int* output_features,
      int* y_start_offsets) {
    int* cursor = y_start_offsets;
    *cursor = 0;
    for (int prev = 0; prev + 1 < num_layers; ++prev, ++cursor) {
      cursor[1] = cursor[0] + batch_size * output_features[prev];
    }
  }
  // Returns the reserved space (in elements) requested for the MLP fprop:
  // the sum of output_features[l] * batch_size over ALL num_layers layers.
  // NOTE(review): mlp_fp/mlp_bp in this file place only the first num_layers - 1 outputs here
  // (the last layer's Y is the caller's buffer), so the final term is unused headroom. It is
  // kept because callers size their allocation from this value.
  size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
    size_t total = 0;
    const int* feat = output_features;
    for (int remaining = num_layers; remaining > 0; --remaining, ++feat) {
      total += *feat * batch_size;
    }
    return total;
  }
  // Total number of elements across every layer's fprop activation:
  // sum over l in [0, num_layers) of batch_size * output_features[l].
  size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
    size_t acts_size = 0;
    int l = 0;
    while (l < num_layers) {
      acts_size += batch_size * output_features[l];
      ++l;
    }
    return acts_size;
  }
  #if 0
  // (Compiled out) Element count of the bprop workspace: two activation-sized dY buffers per
  // layer, one for the output of biasReLU_bp and one for the output of the dgrad GEMM.
  // get_mlp_bp_workspace_in_bytes computes the same dY sizing in bytes, plus the reduction
  // scratch and semaphores.
  size_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {
    const size_t one_copy = get_all_activations_size(batch_size, num_layers, output_features);
    return 2 * one_copy;
  }
  #endif
  // Scratch space (in floats) needed by the bias-gradient reductions, maximized over layers.
  // The bprop reduction kernels keep one partial sum per (grid-Y block, feature), i.e.
  // output_features[l] * grid_y floats. Both launch shapes mlp_bp may pick are evaluated:
  // the scalar kernels (block_x = BIAS_RELU_BW_NTHREADS_X) and the ILP-vectorized one
  // (block_x = ILP * BIAS_RELU_BW_NTHREADS_X). The buffer therefore fits either.
  size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {
    size_t max_scratch_space = 0;
    const int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
    const int block_x_scalar = BIAS_RELU_BW_NTHREADS_X;
    const int block_x_vec = ILP * BIAS_RELU_BW_NTHREADS_X;
    // Loop over all layers to see which one needs the max scratch space
    for (int l = 0; l < num_layers; l++) {
      int grid_x_unused, grid_y_scalar, grid_y_vec;
      get_biasAddRelu_bprop_grid_size(
          output_features[l], batch_size, block_x_scalar, block_y, &grid_x_unused, &grid_y_scalar);
      get_biasAddRelu_bprop_grid_size(
          output_features[l], batch_size, block_x_vec, block_y, &grid_x_unused, &grid_y_vec);
      // Widen before multiplying: features * grid_y computed in int could overflow.
      const size_t feats = (size_t)output_features[l];
      max_scratch_space = std::max(max_scratch_space, feats * (size_t)grid_y_scalar);
      max_scratch_space = std::max(max_scratch_space, feats * (size_t)grid_y_vec);
    }
    return max_scratch_space;
  }
  // Number of ints needed for the reduction semaphores.
  // The bprop reduction kernels index semaphores by blockIdx.x (the feature-block index).
  // One semaphore per feature of the widest layer is therefore an upper bound, assuming
  // each x-block covers at least one feature.
  size_t get_semaphores_size(int num_layers, const int* output_features) {
    int widest = 0;
    for (int l = 0; l < num_layers; ++l) {
      if (output_features[l] > widest) {
        widest = output_features[l];
      }
    }
    return static_cast<size_t>(widest);
  }
  // Size in BYTES of the bprop workspace. partition_mlp_bp_workspace carves it up as
  //   DY_GEMMs (T) : DX_GEMMs (T) : DB_SCRATCH (float) : SEMAPHORES (int)
  // Each T segment stores one activation-sized dY per layer: one for the output of
  // biasReLU_bp and one for the output of the dgrad GEMM.
  template <typename T>
  size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {
    const size_t dy_bytes =
        get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T);
    const size_t scratch_bytes =
        get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float);
    const size_t semaphore_bytes = get_semaphores_size(num_layers, output_features) * sizeof(int);
    return 2 * dy_bytes + scratch_bytes + semaphore_bytes;
  }
  // Splits work_space into its four consecutive segments and returns their start pointers:
  //   DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES
  // Both T segments are get_all_activations_size(...) elements long. DB_SCRATCH is
  // get_reduction_scratch_space(...) floats; SEMAPHORES follow immediately after.
  template <typename T>
  void partition_mlp_bp_workspace(
      int batch_size,
      int num_layers,
      const int* output_features,
      void* work_space,
      T** dy_gemms,
      T** dx_gemms,
      float** db_scratch,
      int** semaphores) {
    const size_t acts = get_all_activations_size(batch_size, num_layers, output_features);
    T* const dy_base = static_cast<T*>(work_space);
    T* const dx_base = dy_base + acts;
    float* const scratch_base = reinterpret_cast<float*>(dx_base + acts);
    int* const semaphore_base = reinterpret_cast<int*>(
        scratch_base + get_reduction_scratch_space(batch_size, num_layers, output_features));
    *dy_gemms = dy_base;
    *dx_gemms = dx_base;
    *db_scratch = scratch_base;
    *semaphores = semaphore_base;
  }
  // Does a simple MLP fprop (GEMM + bias + activation).
  // Can handle num_layers layers, each with its own shape. The output of layer i is the input
  // of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers in the same
  // order, i.e. WPtr[i] and BPtr[i] are the weight and bias of layer i. BPtr is read only
  // when use_bias is set.
  //
  // activation: 0 = none, 1 = ReLU, 2 = sigmoid (per the kernel dispatch below).
  // Intermediate outputs are packed back to back in reserved_space; the last layer writes Y.
  // lt_workspace is handed to the cuBLASLt path together with the size 1 << 22
  // (presumably bytes; confirm against mlp_gemm_lt).
  // Returns 0 on success, 1 if a GEMM or a bias/activation kernel launch fails.
  template <typename T>
  int mlp_fp(
      T* X,
      int input_features,
      int batch_size,
      T** WPtr,
      int num_layers,
      int* output_features,
      T** BPtr,
      T* Y,
      T* reserved_space,
      int use_bias,
      int activation,
      void* lt_workspace) {
    T *weight, *input, *output;
    // Stays null without bias; it is still forwarded to mlp_gemm_lt, so never leave it
    // uninitialized.
    T* bias = nullptr;
    T *reserved_space_x, *reserved_space_y;
    reserved_space_x = NULL;
    reserved_space_y = reserved_space;
    // Get cublas handle from Pytorch
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // Get the stream from cublas handle to reuse for biasReLU kernel.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    for (int layer = 0; layer < num_layers; layer++) {
      weight = WPtr[layer];
      input = (layer == 0) ? X : reserved_space_x;
      output = (layer == num_layers - 1) ? Y : reserved_space_y;
      if (use_bias) {
        bias = BPtr[layer];
      }
      int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
      int ofeat = output_features[layer];
      float one = 1.f;
      float zero = 0.f;
      // Try cuBLASLt first for the supported case. On success, use_bias/bias are handled
      // there and the separate bias/activation kernels below are skipped.
      int cublaslt_status = 1;
  #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
      if (activation < 1) {
        cublaslt_status = mlp_gemm_lt(
            (cublasLtHandle_t)handle,
            CUBLAS_OP_T,
            CUBLAS_OP_N,
            ofeat,
            batch_size,
            ifeat,
            &one,
            weight,
            ifeat,
            input,
            ifeat,
            &zero,
            output,
            ofeat,
            lt_workspace,
            1 << 22,
            stream,
            use_bias == 1,
            activation == 1,
            bias);
      }
  #endif
      // if cublaslt failed or not executed, fallback to cublas
      if (cublaslt_status != 0) {
        cublasStatus_t cublas_status;
        // Call GEMM: fprop is Y = W'X
        cublas_status = mlp_gemm(
            handle,
            CUBLAS_OP_T,
            CUBLAS_OP_N,
            ofeat,
            batch_size,
            ifeat,
            &one,
            weight,
            ifeat,
            input,
            ifeat,
            &zero,
            output,
            ofeat);
        if (cublas_status != CUBLAS_STATUS_SUCCESS) {
          printf("GEMM fprop failed with %d\n", cublas_status);
          return 1;
        }
        const uint input_size = ofeat;
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        // Call biasReLU
        if (use_bias == 1) {
          if (activation == 0) {  // no activation
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          } else if (activation == 1) {  // relu
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            biasAddRelu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          } else if (activation == 2) {  // sigmoid
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
          }
        } else {
          // don't need to do anything in case of no activation and no bias
          if (activation == 1) {  // relu
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            Relu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
          } else if (activation == 2) {  // sigmoid
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
            Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
          }
        }
        // Kernel launches do not return errors; surface bad launch configurations here.
        cudaError_t launch_status = cudaGetLastError();
        if (launch_status != cudaSuccess) {
          printf("bias/activation fprop kernel failed with %s\n", cudaGetErrorString(launch_status));
          return 1;
        }
      }
      // Set current output as next layer input
      reserved_space_x = reserved_space_y;
      // Set next layer output (64-bit: ofeat * batch_size may exceed INT_MAX)
      reserved_space_y += (int64_t)ofeat * batch_size;
    }
    return 0;
  }
  // Does a simple MLP bprop (GEMM + bias + activation backward).
  // Needs reserved space to come back exactly as it was populated in fprop: layer i's output
  // sits at element offset batch_size * sum(output_features[0..i-1]), and the last layer's
  // output is Y. Layers are processed from last to first. For each layer it runs
  // activation/bias backward, then the dgrad GEMM, then the wgrad GEMM. The dgrad GEMM is
  // skipped for layer 0 unless requires_grad is set.
  // work_space is split by partition_mlp_bp_workspace, so it must be at least
  // get_mlp_bp_workspace_in_bytes<T>(...) bytes.
  // Returns 0 on success, 1 if a GEMM or a bias/activation kernel launch fails.
  template <typename T>
  int mlp_bp(
      T* X,
      T* Y,
      int input_features,
      int batch_size,
      T** WPtr,
      int num_layers,
      int* output_features,
      T* dY,
      T* reserved_space,
      T* work_space,
      T* dX,
      T** dwPtr,
      T** dbPtr,
      bool requires_grad,
      int use_bias,
      int activation) {
    T* weight;
    T *dweight, *dx, *dy, *dbias;
    T *x, *y;
    // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away
    // after bp call.
    T* dy_gemm_base;
    // Where the dx after GEMM is stored.
    T* dx_gemm_base;
    // Where partial reduction results are stored.
    float* db_scratch;
    // Semaphores for reduction.
    int* semaphores;
    partition_mlp_bp_workspace<T>(
        batch_size,
        num_layers,
        output_features,
        work_space,
        &dy_gemm_base,
        &dx_gemm_base,
        &db_scratch,
        &semaphores);
    size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);
    // Get cublas handle from Pytorch
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // Get the stream from cublas handle to reuse for biasReLU kernel.
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    // Element offset of the current layer's output, using the same layout as get_y_offsets.
    // It is kept in 64 bits and walked backwards incrementally, so no heap buffer is
    // allocated (and none can leak on the early error returns). It starts at the last layer.
    int64_t y_offset = 0;
    for (int l = 0; l + 1 < num_layers; l++) {
      y_offset += (int64_t)batch_size * output_features[l];
    }
    for (int layer = num_layers - 1; layer >= 0; layer--) {
      // Offset of this layer's input (= previous layer's output); unused for layer 0.
      const int64_t x_offset =
          (layer == 0) ? 0 : y_offset - (int64_t)batch_size * output_features[layer - 1];
      weight = WPtr[layer];
      dweight = dwPtr[layer];
      // x is read from reserved space
      x = (layer == 0) ? X : reserved_space + x_offset;
      // dx is written in workspace for all but layer==0
      dx = (layer == 0) ? dX : dx_gemm_base + x_offset;
      // y is read from reserved space
      y = (layer == num_layers - 1) ? Y : reserved_space + y_offset;
      // dx from layer+1
      dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offset;
      // dy_gemm is written to and read immediately
      T* dy_gemm = dy_gemm_base + y_offset;
      dbias = dbPtr[layer];
      int xfeat = (layer == 0) ? input_features : output_features[layer - 1];
      int yfeat = output_features[layer];
      float one = 1.f;
      float zero = 0.f;
      if (use_bias == 1) {
        if (activation == 0) {  // no activation
          // bgrad
          dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
          int grid_x, grid_y;
          cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
          int block_x = BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
              dy, yfeat, batch_size, db_scratch, semaphores, dbias);
          // bypass dgrad through reset pointer
          dy_gemm = dy;
        } else if (activation == 1) {  // relu
          dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
          int grid_x, grid_y;
          cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
          if (yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
              is_aligned(y) &&
              is_aligned(dy) &&
              is_aligned(dy_gemm) &&
              is_aligned(dbias)) {
            int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
            int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
            get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
            dim3 grid(grid_x, grid_y);
            biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
                y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
          } else {
            int block_x = BIAS_RELU_BW_NTHREADS_X;
            int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
            get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
            dim3 grid(grid_x, grid_y);
            biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
                y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
          }
        } else if (activation == 2) {  // sigmoid
          // activation backward
          int num_blocks = 0;
          int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
          // bgrad, from dy_gemm
          dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
          int grid_x, grid_y;
          cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
          int block_x = BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
              dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
        }
      } else {  // no bias below
        if (activation == 0) {
          // bypass dgrad through reset pointer
          dy_gemm = dy;
        } else if (activation == 1) {  // relu
          int num_blocks = 0;
          int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Relu_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
        } else if (activation == 2) {  // sigmoid
          int num_blocks = 0;
          int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
        }
      }
      // Kernel launches do not return errors; surface bad launch configurations here.
      cudaError_t launch_status = cudaGetLastError();
      if (launch_status != cudaSuccess) {
        printf("bias/activation bprop kernel failed with %s\n", cudaGetErrorString(launch_status));
        return 1;
      }
      cublasStatus_t cublas_status;
      // Call GEMM dgrad
      if (layer > 0 || requires_grad) {
        cublas_status = mlp_gemm(
            handle,
            CUBLAS_OP_N,
            CUBLAS_OP_N,
            xfeat,
            batch_size,
            yfeat,
            &one,
            weight,
            xfeat,
            dy_gemm,
            yfeat,
            &zero,
            dx,
            xfeat);
        if (cublas_status != CUBLAS_STATUS_SUCCESS) {
          printf("GEMM dgrad failed with %d\n", cublas_status);
          return 1;
        }
      }
      // Call GEMM wgrad
      cublas_status = mlp_gemm(
          handle,
          CUBLAS_OP_N,
          CUBLAS_OP_T,
          xfeat,
          yfeat,
          batch_size,
          &one,
          x,
          xfeat,
          dy_gemm,
          yfeat,
          &zero,
          dweight,
          xfeat);
      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM wgrad failed with %d\n", cublas_status);
        return 1;
      }
      // The previous layer's output is this layer's input.
      y_offset = x_offset;
    }
    return 0;
  }
  // Explicit instantiations for the supported element types (float, at::Half, double).
  // For each type, this instantiates the forward pass, the backward pass and the bprop
  // workspace size query.

  // ---- float ----
  template int mlp_fp<float>(
      float* X, int input_features, int batch_size, float** WPtr, int num_layers,
      int* output_features, float** BPtr, float* Y, float* reserved_space,
      int use_bias, int activation, void* lt_workspace);
  template int mlp_bp<float>(
      float* X, float* Y, int input_features, int batch_size, float** WPtr,
      int num_layers, int* output_features, float* dY, float* reserved_space,
      float* work_space, float* dX, float** dwPtr, float** dbPtr,
      bool requires_grad, int use_bias, int activation);
  template size_t get_mlp_bp_workspace_in_bytes<float>(
      int batch_size, int num_layers, const int* output_features);

  // ---- at::Half ----
  template int mlp_fp<at::Half>(
      at::Half* X, int input_features, int batch_size, at::Half** WPtr, int num_layers,
      int* output_features, at::Half** BPtr, at::Half* Y, at::Half* reserved_space,
      int use_bias, int activation, void* lt_workspace);
  template int mlp_bp<at::Half>(
      at::Half* X, at::Half* Y, int input_features, int batch_size, at::Half** WPtr,
      int num_layers, int* output_features, at::Half* dY, at::Half* reserved_space,
      at::Half* work_space, at::Half* dX, at::Half** dwPtr, at::Half** dbPtr,
      bool requires_grad, int use_bias, int activation);
  template size_t get_mlp_bp_workspace_in_bytes<at::Half>(
      int batch_size, int num_layers, const int* output_features);

  // ---- double ----
  template int mlp_fp<double>(
      double* X, int input_features, int batch_size, double** WPtr, int num_layers,
      int* output_features, double** BPtr, double* Y, double* reserved_space,
      int use_bias, int activation, void* lt_workspace);
  template int mlp_bp<double>(
      double* X, double* Y, int input_features, int batch_size, double** WPtr,
      int num_layers, int* output_features, double* dY, double* reserved_space,
      double* work_space, double* dX, double** dwPtr, double** dbPtr,
      bool requires_grad, int use_bias, int activation);
  template size_t get_mlp_bp_workspace_in_bytes<double>(
      int batch_size, int num_layers, const int* output_features);