// fused_dense_cuda.cu
// cuBLAS / cuBLASLt host-side wrappers for fused dense layers:
// plain GEMM (gemm_bias), GEMM with fused bias epilogue (gemm_bias_lt),
// and GEMM with fused bias + GELU epilogue (gemm_bias_gelu_lt).
  1. #include <ATen/ATen.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <assert.h>
  4. #include <stdio.h>
  5. #include <stdlib.h>
  6. #include <string.h>
  7. #include <torch/torch.h>
  8. /* Includes, cuda */
  9. #include <cublas_v2.h>
  10. #include <cuda_runtime.h>
  11. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  12. // includes cublaslt
  13. #include <cublasLt.h>
  14. #endif
  15. // FP64 Wrapper around cublas GEMMEx
  16. cublasStatus_t gemm_bias(
  17. cublasHandle_t handle,
  18. cublasOperation_t transa,
  19. cublasOperation_t transb,
  20. int m,
  21. int n,
  22. int k,
  23. const float* alpha,
  24. double* A,
  25. int lda,
  26. double* B,
  27. int ldb,
  28. const float* beta,
  29. double* C,
  30. int ldc) {
  31. return cublasGemmEx(
  32. handle,
  33. transa,
  34. transb,
  35. m,
  36. n,
  37. k,
  38. alpha,
  39. A,
  40. CUDA_R_64F,
  41. lda,
  42. B,
  43. CUDA_R_64F,
  44. ldb,
  45. beta,
  46. C,
  47. CUDA_R_64F,
  48. ldc,
  49. CUDA_R_64F,
  50. CUBLAS_GEMM_DEFAULT);
  51. }
  52. // FP32 Wrapper around cublas GEMMEx
  53. cublasStatus_t gemm_bias(
  54. cublasHandle_t handle,
  55. cublasOperation_t transa,
  56. cublasOperation_t transb,
  57. int m,
  58. int n,
  59. int k,
  60. const float* alpha,
  61. float* A,
  62. int lda,
  63. float* B,
  64. int ldb,
  65. const float* beta,
  66. float* C,
  67. int ldc) {
  68. return cublasGemmEx(
  69. handle,
  70. transa,
  71. transb,
  72. m,
  73. n,
  74. k,
  75. alpha,
  76. A,
  77. CUDA_R_32F,
  78. lda,
  79. B,
  80. CUDA_R_32F,
  81. ldb,
  82. beta,
  83. C,
  84. CUDA_R_32F,
  85. ldc,
  86. CUDA_R_32F,
  87. CUBLAS_GEMM_DEFAULT);
  88. }
  89. // FP16 Tensor core wrapper around cublas GEMMEx
  90. cublasStatus_t gemm_bias(
  91. cublasHandle_t handle,
  92. cublasOperation_t transa,
  93. cublasOperation_t transb,
  94. int m,
  95. int n,
  96. int k,
  97. const float* alpha,
  98. at::Half* A,
  99. int lda,
  100. at::Half* B,
  101. int ldb,
  102. const float* beta,
  103. at::Half* C,
  104. int ldc) {
  105. return cublasGemmEx(
  106. handle,
  107. transa,
  108. transb,
  109. m,
  110. n,
  111. k,
  112. alpha,
  113. A,
  114. CUDA_R_16F,
  115. lda,
  116. B,
  117. CUDA_R_16F,
  118. ldb,
  119. beta,
  120. C,
  121. CUDA_R_16F,
  122. ldc,
  123. CUDA_R_32F,
  124. CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  125. }
  126. // BF16 Tensor core wrapper around cublas GEMMEx
  127. cublasStatus_t gemm_bias(
  128. cublasHandle_t handle,
  129. cublasOperation_t transa,
  130. cublasOperation_t transb,
  131. int m,
  132. int n,
  133. int k,
  134. const float* alpha,
  135. at::BFloat16* A,
  136. int lda,
  137. at::BFloat16* B,
  138. int ldb,
  139. const float* beta,
  140. at::BFloat16* C,
  141. int ldc) {
  142. return cublasGemmEx(
  143. handle,
  144. transa,
  145. transb,
  146. m,
  147. n,
  148. k,
  149. alpha,
  150. A,
  151. CUDA_R_16BF,
  152. lda,
  153. B,
  154. CUDA_R_16BF,
  155. ldb,
  156. beta,
  157. C,
  158. CUDA_R_16BF,
  159. ldc,
  160. CUDA_R_32F,
  161. CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  162. }
  163. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  164. int gemm_bias_lt(
  165. cublasLtHandle_t ltHandle,
  166. cublasOperation_t transa,
  167. cublasOperation_t transb,
  168. int m,
  169. int n,
  170. int k,
  171. const float *alpha, /* host pointer */
  172. at::Half* A,
  173. int lda,
  174. at::Half* B,
  175. int ldb,
  176. const float *beta, /* host pointer */
  177. at::Half* C,
  178. int ldc,
  179. void *workspace,
  180. size_t workspaceSize,
  181. cudaStream_t stream,
  182. bool use_bias,
  183. const void* bias) {
  184. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  185. cublasLtMatmulDescOpaque_t operationDesc = {};
  186. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  187. cublasLtMatmulPreferenceOpaque_t preference = {};
  188. int returnedResults = 0;
  189. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  190. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  191. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  192. // for details about defaults; here we just set the transforms for
  193. // A and B.
  194. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  195. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  196. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  197. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  198. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  199. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  200. if (use_bias) {
  201. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  202. if (status != CUBLAS_STATUS_SUCCESS) {
  203. goto CLEANUP;
  204. }
  205. epilogue = CUBLASLT_EPILOGUE_BIAS;
  206. }
  207. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  208. if (status != CUBLAS_STATUS_SUCCESS) {
  209. goto CLEANUP;
  210. }
  211. // Create matrix descriptors. Not setting any extra attributes.
  212. status = cublasLtMatrixLayoutInit(
  213. &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  214. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  215. status = cublasLtMatrixLayoutInit(
  216. &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  217. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  218. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  219. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  220. // Create preference handle; In general, extra attributes can be
  221. // used here to disable tensor ops or to make sure algo selected
  222. // will work with badly aligned A, B, C. However, for simplicity
  223. // here we assume A,B,C are always well aligned (e.g., directly
  224. // come from cudaMalloc)
  225. status = cublasLtMatmulPreferenceInit(&preference);
  226. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  227. status = cublasLtMatmulPreferenceSetAttribute(
  228. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  229. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  230. // We just need the best available heuristic to try and run matmul.
  231. // There is no guarantee that this will work. For example, if A is
  232. // badly aligned, you can request more (e.g. 32) algos and try to
  233. // run them one by one until something works.
  234. status = cublasLtMatmulAlgoGetHeuristic(
  235. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  236. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  237. if (returnedResults == 0) {
  238. status = CUBLAS_STATUS_NOT_SUPPORTED;
  239. goto CLEANUP;
  240. }
  241. status = cublasLtMatmul(ltHandle,
  242. &operationDesc,
  243. alpha,
  244. A,
  245. &Adesc,
  246. B,
  247. &Bdesc,
  248. beta,
  249. C,
  250. &Cdesc,
  251. C,
  252. &Cdesc,
  253. //&heuristicResult.algo,
  254. NULL,
  255. workspace,
  256. workspaceSize,
  257. stream);
  258. CLEANUP:
  259. // Descriptors are no longer needed as all GPU work was already
  260. // enqueued.
  261. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  262. }
  263. int gemm_bias_lt(
  264. cublasLtHandle_t ltHandle,
  265. cublasOperation_t transa,
  266. cublasOperation_t transb,
  267. int m,
  268. int n,
  269. int k,
  270. const float *alpha, /* host pointer */
  271. at::BFloat16* A,
  272. int lda,
  273. at::BFloat16* B,
  274. int ldb,
  275. const float *beta, /* host pointer */
  276. at::BFloat16* C,
  277. int ldc,
  278. void *workspace,
  279. size_t workspaceSize,
  280. cudaStream_t stream,
  281. bool use_bias,
  282. const void* bias) {
  283. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  284. cublasLtMatmulDescOpaque_t operationDesc = {};
  285. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  286. cublasLtMatmulPreferenceOpaque_t preference = {};
  287. int returnedResults = 0;
  288. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  289. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  290. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  291. // for details about defaults; here we just set the transforms for
  292. // A and B.
  293. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  294. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  295. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  296. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  297. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  298. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  299. if (use_bias) {
  300. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  301. if (status != CUBLAS_STATUS_SUCCESS) {
  302. goto CLEANUP;
  303. }
  304. epilogue = CUBLASLT_EPILOGUE_BIAS;
  305. }
  306. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  307. if (status != CUBLAS_STATUS_SUCCESS) {
  308. goto CLEANUP;
  309. }
  310. // Create matrix descriptors. Not setting any extra attributes.
  311. status = cublasLtMatrixLayoutInit(
  312. &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  313. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  314. status = cublasLtMatrixLayoutInit(
  315. &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  316. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  317. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
  318. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  319. // Create preference handle; In general, extra attributes can be
  320. // used here to disable tensor ops or to make sure algo selected
  321. // will work with badly aligned A, B, C. However, for simplicity
  322. // here we assume A,B,C are always well aligned (e.g., directly
  323. // come from cudaMalloc)
  324. status = cublasLtMatmulPreferenceInit(&preference);
  325. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  326. status = cublasLtMatmulPreferenceSetAttribute(
  327. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  328. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  329. // We just need the best available heuristic to try and run matmul.
  330. // There is no guarantee that this will work. For example, if A is
  331. // badly aligned, you can request more (e.g. 32) algos and try to
  332. // run them one by one until something works.
  333. status = cublasLtMatmulAlgoGetHeuristic(
  334. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  335. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  336. if (returnedResults == 0) {
  337. status = CUBLAS_STATUS_NOT_SUPPORTED;
  338. goto CLEANUP;
  339. }
  340. status = cublasLtMatmul(ltHandle,
  341. &operationDesc,
  342. alpha,
  343. A,
  344. &Adesc,
  345. B,
  346. &Bdesc,
  347. beta,
  348. C,
  349. &Cdesc,
  350. C,
  351. &Cdesc,
  352. //&heuristicResult.algo,
  353. NULL,
  354. workspace,
  355. workspaceSize,
  356. stream);
  357. CLEANUP:
  358. // Descriptors are no longer needed as all GPU work was already
  359. // enqueued.
  360. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  361. }
// FP64 overload of gemm_bias_lt: the fused cuBLASLt path is not provided
// for double precision, so this stub always returns 1 (failure), letting
// the caller fall back to the plain cublasGemmEx wrapper. All parameters
// are intentionally unused.
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha, /* host pointer */
    double* A,
    int lda,
    double* B,
    int ldb,
    const float *beta, /* host pointer */
    double* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    const void* bias) {
  // Non-zero return == "not supported" in this file's convention.
  return 1;
}
  384. int gemm_bias_lt(
  385. cublasLtHandle_t ltHandle,
  386. cublasOperation_t transa,
  387. cublasOperation_t transb,
  388. int m,
  389. int n,
  390. int k,
  391. const float *alpha, /* host pointer */
  392. float *A,
  393. int lda,
  394. float *B,
  395. int ldb,
  396. const float *beta, /* host pointer */
  397. float *C,
  398. int ldc,
  399. void *workspace,
  400. size_t workspaceSize,
  401. cudaStream_t stream,
  402. bool use_bias,
  403. const void* bias) {
  404. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  405. cublasLtMatmulDescOpaque_t operationDesc = {};
  406. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  407. cublasLtMatmulPreferenceOpaque_t preference = {};
  408. int returnedResults = 0;
  409. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  410. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  411. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  412. // for details about defaults; here we just set the transforms for
  413. // A and B.
  414. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  415. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  416. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  417. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  418. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  419. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  420. if (use_bias) {
  421. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  422. if (status != CUBLAS_STATUS_SUCCESS) {
  423. goto CLEANUP;
  424. }
  425. epilogue = CUBLASLT_EPILOGUE_BIAS;
  426. }
  427. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  428. if (status != CUBLAS_STATUS_SUCCESS) {
  429. goto CLEANUP;
  430. }
  431. // Create matrix descriptors. Not setting any extra attributes.
  432. status = cublasLtMatrixLayoutInit(
  433. &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  434. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  435. status = cublasLtMatrixLayoutInit(
  436. &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  437. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  438. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  439. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  440. // Create preference handle; In general, extra attributes can be
  441. // used here to disable tensor ops or to make sure algo selected
  442. // will work with badly aligned A, B, C. However, for simplicity
  443. // here we assume A,B,C are always well aligned (e.g., directly
  444. // come from cudaMalloc)
  445. status = cublasLtMatmulPreferenceInit(&preference);
  446. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  447. status = cublasLtMatmulPreferenceSetAttribute(
  448. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  449. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  450. // We just need the best available heuristic to try and run matmul.
  451. // There is no guarantee that this will work. For example, if A is
  452. // badly aligned, you can request more (e.g. 32) algos and try to
  453. // run them one by one until something works.
  454. status = cublasLtMatmulAlgoGetHeuristic(
  455. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  456. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  457. if (returnedResults == 0) {
  458. status = CUBLAS_STATUS_NOT_SUPPORTED;
  459. goto CLEANUP;
  460. }
  461. status = cublasLtMatmul(ltHandle,
  462. &operationDesc,
  463. alpha,
  464. A,
  465. &Adesc,
  466. B,
  467. &Bdesc,
  468. beta,
  469. C,
  470. &Cdesc,
  471. C,
  472. &Cdesc,
  473. &heuristicResult.algo,
  474. workspace,
  475. workspaceSize,
  476. stream);
  477. CLEANUP:
  478. // Descriptors are no longer needed as all GPU work was already
  479. // enqueued.
  480. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  481. }
  482. int gemm_bias_gelu_lt(
  483. cublasLtHandle_t ltHandle,
  484. cublasOperation_t transa,
  485. cublasOperation_t transb,
  486. int m,
  487. int n,
  488. int k,
  489. const float *alpha, /* host pointer */
  490. at::Half* A,
  491. int lda,
  492. at::Half* B,
  493. int ldb,
  494. const float *beta, /* host pointer */
  495. at::Half* C,
  496. int64_t ldc,
  497. void *workspace,
  498. size_t workspaceSize,
  499. cudaStream_t stream,
  500. bool use_bias,
  501. const void* gelu_in,
  502. const void* bias) {
  503. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  504. cublasLtMatmulDescOpaque_t operationDesc = {};
  505. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  506. cublasLtMatmulPreferenceOpaque_t preference = {};
  507. int returnedResults = 0;
  508. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  509. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
  510. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  511. // for details about defaults; here we just set the transforms for
  512. // A and B.
  513. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  514. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  515. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  516. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  517. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  518. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  519. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  520. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  521. if (use_bias) {
  522. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  523. if (status != CUBLAS_STATUS_SUCCESS) {
  524. goto CLEANUP;
  525. }
  526. epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
  527. }
  528. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  529. if (status != CUBLAS_STATUS_SUCCESS) {
  530. goto CLEANUP;
  531. }
  532. // Create matrix descriptors. Not setting any extra attributes.
  533. status = cublasLtMatrixLayoutInit(
  534. &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  535. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  536. status = cublasLtMatrixLayoutInit(
  537. &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  538. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  539. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  540. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  541. // Create preference handle; In general, extra attributes can be
  542. // used here to disable tensor ops or to make sure algo selected
  543. // will work with badly aligned A, B, C. However, for simplicity
  544. // here we assume A,B,C are always well aligned (e.g., directly
  545. // come from cudaMalloc)
  546. status = cublasLtMatmulPreferenceInit(&preference);
  547. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  548. status = cublasLtMatmulPreferenceSetAttribute(
  549. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  550. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  551. // We just need the best available heuristic to try and run matmul.
  552. // There is no guarantee that this will work. For example, if A is
  553. // badly aligned, you can request more (e.g. 32) algos and try to
  554. // run them one by one until something works.
  555. status = cublasLtMatmulAlgoGetHeuristic(
  556. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  557. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  558. if (returnedResults == 0) {
  559. status = CUBLAS_STATUS_NOT_SUPPORTED;
  560. goto CLEANUP;
  561. }
  562. status = cublasLtMatmul(ltHandle,
  563. &operationDesc,
  564. alpha,
  565. A,
  566. &Adesc,
  567. B,
  568. &Bdesc,
  569. beta,
  570. C,
  571. &Cdesc,
  572. C,
  573. &Cdesc,
  574. //&heuristicResult.algo,
  575. NULL,
  576. workspace,
  577. workspaceSize,
  578. stream);
  579. CLEANUP:
  580. // Descriptors are no longer needed as all GPU work was already
  581. // enqueued.
  582. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  583. }
  584. int gemm_bias_gelu_lt(
  585. cublasLtHandle_t ltHandle,
  586. cublasOperation_t transa,
  587. cublasOperation_t transb,
  588. int m,
  589. int n,
  590. int k,
  591. const float *alpha, /* host pointer */
  592. at::BFloat16* A,
  593. int lda,
  594. at::BFloat16* B,
  595. int ldb,
  596. const float *beta, /* host pointer */
  597. at::BFloat16* C,
  598. int64_t ldc,
  599. void *workspace,
  600. size_t workspaceSize,
  601. cudaStream_t stream,
  602. bool use_bias,
  603. const void* gelu_in,
  604. const void* bias) {
  605. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  606. cublasLtMatmulDescOpaque_t operationDesc = {};
  607. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  608. cublasLtMatmulPreferenceOpaque_t preference = {};
  609. int returnedResults = 0;
  610. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  611. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
  612. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  613. // for details about defaults; here we just set the transforms for
  614. // A and B.
  615. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  616. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  617. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  618. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  619. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  620. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  621. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  622. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  623. if (use_bias) {
  624. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  625. if (status != CUBLAS_STATUS_SUCCESS) {
  626. goto CLEANUP;
  627. }
  628. epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
  629. }
  630. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  631. if (status != CUBLAS_STATUS_SUCCESS) {
  632. goto CLEANUP;
  633. }
  634. // Create matrix descriptors. Not setting any extra attributes.
  635. status = cublasLtMatrixLayoutInit(
  636. &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  637. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  638. status = cublasLtMatrixLayoutInit(
  639. &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  640. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  641. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
  642. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  643. // Create preference handle; In general, extra attributes can be
  644. // used here to disable tensor ops or to make sure algo selected
  645. // will work with badly aligned A, B, C. However, for simplicity
  646. // here we assume A,B,C are always well aligned (e.g., directly
  647. // come from cudaMalloc)
  648. status = cublasLtMatmulPreferenceInit(&preference);
  649. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  650. status = cublasLtMatmulPreferenceSetAttribute(
  651. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  652. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  653. // We just need the best available heuristic to try and run matmul.
  654. // There is no guarantee that this will work. For example, if A is
  655. // badly aligned, you can request more (e.g. 32) algos and try to
  656. // run them one by one until something works.
  657. status = cublasLtMatmulAlgoGetHeuristic(
  658. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  659. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  660. if (returnedResults == 0) {
  661. status = CUBLAS_STATUS_NOT_SUPPORTED;
  662. goto CLEANUP;
  663. }
  664. status = cublasLtMatmul(ltHandle,
  665. &operationDesc,
  666. alpha,
  667. A,
  668. &Adesc,
  669. B,
  670. &Bdesc,
  671. beta,
  672. C,
  673. &Cdesc,
  674. C,
  675. &Cdesc,
  676. //&heuristicResult.algo,
  677. NULL,
  678. workspace,
  679. workspaceSize,
  680. stream);
  681. CLEANUP:
  682. // Descriptors are no longer needed as all GPU work was already
  683. // enqueued.
  684. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  685. }
  686. int gemm_bias_gelu_lt(
  687. cublasLtHandle_t ltHandle,
  688. cublasOperation_t transa,
  689. cublasOperation_t transb,
  690. int m,
  691. int n,
  692. int k,
  693. const float *alpha, /* host pointer */
  694. double* A,
  695. int lda,
  696. double* B,
  697. int ldb,
  698. const float *beta, /* host pointer */
  699. double* C,
  700. int ldc,
  701. void *workspace,
  702. size_t workspaceSize,
  703. cudaStream_t stream,
  704. bool use_bias,
  705. const void *gelu_in,
  706. const void* bias) {
  707. return 1;
  708. }
  709. int gemm_bias_gelu_lt(
  710. cublasLtHandle_t ltHandle,
  711. cublasOperation_t transa,
  712. cublasOperation_t transb,
  713. int m,
  714. int n,
  715. int k,
  716. const float *alpha, /* host pointer */
  717. float *A,
  718. int lda,
  719. float *B,
  720. int ldb,
  721. const float *beta, /* host pointer */
  722. float *C,
  723. int64_t ldc,
  724. void *workspace,
  725. size_t workspaceSize,
  726. cudaStream_t stream,
  727. bool use_bias,
  728. const void* gelu_in,
  729. const void* bias) {
  730. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  731. cublasLtMatmulDescOpaque_t operationDesc = {};
  732. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  733. cublasLtMatmulPreferenceOpaque_t preference = {};
  734. int returnedResults = 0;
  735. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  736. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
  737. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  738. // for details about defaults; here we just set the transforms for
  739. // A and B.
  740. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  741. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  742. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  743. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  744. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  745. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  746. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  747. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  748. if (use_bias) {
  749. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
  750. if (status != CUBLAS_STATUS_SUCCESS) {
  751. goto CLEANUP;
  752. }
  753. epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
  754. }
  755. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  756. if (status != CUBLAS_STATUS_SUCCESS) {
  757. goto CLEANUP;
  758. }
  759. // Create matrix descriptors. Not setting any extra attributes.
  760. status = cublasLtMatrixLayoutInit(
  761. &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  762. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  763. status = cublasLtMatrixLayoutInit(
  764. &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  765. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  766. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  767. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  768. // Create preference handle; In general, extra attributes can be
  769. // used here to disable tensor ops or to make sure algo selected
  770. // will work with badly aligned A, B, C. However, for simplicity
  771. // here we assume A,B,C are always well aligned (e.g., directly
  772. // come from cudaMalloc)
  773. status = cublasLtMatmulPreferenceInit(&preference);
  774. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  775. status = cublasLtMatmulPreferenceSetAttribute(
  776. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  777. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  778. // We just need the best available heuristic to try and run matmul.
  779. // There is no guarantee that this will work. For example, if A is
  780. // badly aligned, you can request more (e.g. 32) algos and try to
  781. // run them one by one until something works.
  782. status = cublasLtMatmulAlgoGetHeuristic(
  783. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  784. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  785. if (returnedResults == 0) {
  786. status = CUBLAS_STATUS_NOT_SUPPORTED;
  787. goto CLEANUP;
  788. }
  789. status = cublasLtMatmul(ltHandle,
  790. &operationDesc,
  791. alpha,
  792. A,
  793. &Adesc,
  794. B,
  795. &Bdesc,
  796. beta,
  797. C,
  798. &Cdesc,
  799. C,
  800. &Cdesc,
  801. //&heuristicResult.algo,
  802. NULL,
  803. workspace,
  804. workspaceSize,
  805. stream);
  806. CLEANUP:
  807. // Descriptors are no longer needed as all GPU work was already
  808. // enqueued.
  809. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  810. }
  811. int gemm_bgradb_lt(
  812. cublasLtHandle_t ltHandle,
  813. cublasOperation_t transa,
  814. cublasOperation_t transb,
  815. int m,
  816. int n,
  817. int k,
  818. const float *alpha, /* host pointer */
  819. at::Half* A,
  820. int lda,
  821. at::Half* B,
  822. int ldb,
  823. const float *beta, /* host pointer */
  824. at::Half* C,
  825. int ldc,
  826. void *workspace,
  827. size_t workspaceSize,
  828. cudaStream_t stream,
  829. bool use_bias,
  830. const void* bgrad) {
  831. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  832. cublasLtMatmulDescOpaque_t operationDesc = {};
  833. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  834. cublasLtMatmulPreferenceOpaque_t preference = {};
  835. int returnedResults = 0;
  836. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  837. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  838. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  839. // for details about defaults; here we just set the transforms for
  840. // A and B.
  841. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  842. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  843. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  844. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  845. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  846. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  847. if (use_bias) {
  848. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  849. if (status != CUBLAS_STATUS_SUCCESS) {
  850. goto CLEANUP;
  851. }
  852. epilogue = CUBLASLT_EPILOGUE_BGRADB;
  853. }
  854. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  855. if (status != CUBLAS_STATUS_SUCCESS) {
  856. goto CLEANUP;
  857. }
  858. // Create matrix descriptors. Not setting any extra attributes.
  859. status = cublasLtMatrixLayoutInit(
  860. &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  861. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  862. status = cublasLtMatrixLayoutInit(
  863. &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  864. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  865. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  866. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  867. // Create preference handle; In general, extra attributes can be
  868. // used here to disable tensor ops or to make sure algo selected
  869. // will work with badly aligned A, B, C. However, for simplicity
  870. // here we assume A,B,C are always well aligned (e.g., directly
  871. // come from cudaMalloc)
  872. status = cublasLtMatmulPreferenceInit(&preference);
  873. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  874. status = cublasLtMatmulPreferenceSetAttribute(
  875. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  876. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  877. // We just need the best available heuristic to try and run matmul.
  878. // There is no guarantee that this will work. For example, if A is
  879. // badly aligned, you can request more (e.g. 32) algos and try to
  880. // run them one by one until something works.
  881. status = cublasLtMatmulAlgoGetHeuristic(
  882. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  883. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  884. if (returnedResults == 0) {
  885. status = CUBLAS_STATUS_NOT_SUPPORTED;
  886. goto CLEANUP;
  887. }
  888. status = cublasLtMatmul(ltHandle,
  889. &operationDesc,
  890. alpha,
  891. A,
  892. &Adesc,
  893. B,
  894. &Bdesc,
  895. beta,
  896. C,
  897. &Cdesc,
  898. C,
  899. &Cdesc,
  900. //&heuristicResult.algo,
  901. NULL,
  902. workspace,
  903. workspaceSize,
  904. stream);
  905. CLEANUP:
  906. // Descriptors are no longer needed as all GPU work was already
  907. // enqueued.
  908. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  909. }
  910. int gemm_bgradb_lt(
  911. cublasLtHandle_t ltHandle,
  912. cublasOperation_t transa,
  913. cublasOperation_t transb,
  914. int m,
  915. int n,
  916. int k,
  917. const float *alpha, /* host pointer */
  918. at::BFloat16* A,
  919. int lda,
  920. at::BFloat16* B,
  921. int ldb,
  922. const float *beta, /* host pointer */
  923. at::BFloat16* C,
  924. int ldc,
  925. void *workspace,
  926. size_t workspaceSize,
  927. cudaStream_t stream,
  928. bool use_bias,
  929. const void* bgrad) {
  930. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  931. cublasLtMatmulDescOpaque_t operationDesc = {};
  932. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  933. cublasLtMatmulPreferenceOpaque_t preference = {};
  934. int returnedResults = 0;
  935. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  936. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  937. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  938. // for details about defaults; here we just set the transforms for
  939. // A and B.
  940. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  941. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  942. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  943. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  944. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  945. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  946. if (use_bias) {
  947. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  948. if (status != CUBLAS_STATUS_SUCCESS) {
  949. goto CLEANUP;
  950. }
  951. epilogue = CUBLASLT_EPILOGUE_BGRADB;
  952. }
  953. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  954. if (status != CUBLAS_STATUS_SUCCESS) {
  955. goto CLEANUP;
  956. }
  957. // Create matrix descriptors. Not setting any extra attributes.
  958. status = cublasLtMatrixLayoutInit(
  959. &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  960. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  961. status = cublasLtMatrixLayoutInit(
  962. &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  963. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  964. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
  965. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  966. // Create preference handle; In general, extra attributes can be
  967. // used here to disable tensor ops or to make sure algo selected
  968. // will work with badly aligned A, B, C. However, for simplicity
  969. // here we assume A,B,C are always well aligned (e.g., directly
  970. // come from cudaMalloc)
  971. status = cublasLtMatmulPreferenceInit(&preference);
  972. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  973. status = cublasLtMatmulPreferenceSetAttribute(
  974. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  975. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  976. // We just need the best available heuristic to try and run matmul.
  977. // There is no guarantee that this will work. For example, if A is
  978. // badly aligned, you can request more (e.g. 32) algos and try to
  979. // run them one by one until something works.
  980. status = cublasLtMatmulAlgoGetHeuristic(
  981. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  982. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  983. if (returnedResults == 0) {
  984. status = CUBLAS_STATUS_NOT_SUPPORTED;
  985. goto CLEANUP;
  986. }
  987. status = cublasLtMatmul(ltHandle,
  988. &operationDesc,
  989. alpha,
  990. A,
  991. &Adesc,
  992. B,
  993. &Bdesc,
  994. beta,
  995. C,
  996. &Cdesc,
  997. C,
  998. &Cdesc,
  999. //&heuristicResult.algo,
  1000. NULL,
  1001. workspace,
  1002. workspaceSize,
  1003. stream);
  1004. CLEANUP:
  1005. // Descriptors are no longer needed as all GPU work was already
  1006. // enqueued.
  1007. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  1008. }
  1009. int gemm_bgradb_lt(
  1010. cublasLtHandle_t ltHandle,
  1011. cublasOperation_t transa,
  1012. cublasOperation_t transb,
  1013. int m,
  1014. int n,
  1015. int k,
  1016. const float *alpha, /* host pointer */
  1017. double* A,
  1018. int lda,
  1019. double* B,
  1020. int ldb,
  1021. const float *beta, /* host pointer */
  1022. double* C,
  1023. int ldc,
  1024. void *workspace,
  1025. size_t workspaceSize,
  1026. cudaStream_t stream,
  1027. bool use_bias,
  1028. const void* bgrad) {
  1029. return 1;
  1030. }
  1031. int gemm_bgradb_lt(
  1032. cublasLtHandle_t ltHandle,
  1033. cublasOperation_t transa,
  1034. cublasOperation_t transb,
  1035. int m,
  1036. int n,
  1037. int k,
  1038. const float *alpha, /* host pointer */
  1039. float *A,
  1040. int lda,
  1041. float *B,
  1042. int ldb,
  1043. const float *beta, /* host pointer */
  1044. float *C,
  1045. int ldc,
  1046. void *workspace,
  1047. size_t workspaceSize,
  1048. cudaStream_t stream,
  1049. bool use_bias,
  1050. const void* bgrad) {
  1051. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  1052. cublasLtMatmulDescOpaque_t operationDesc = {};
  1053. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  1054. cublasLtMatmulPreferenceOpaque_t preference = {};
  1055. int returnedResults = 0;
  1056. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  1057. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  1058. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  1059. // for details about defaults; here we just set the transforms for
  1060. // A and B.
  1061. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  1062. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1063. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  1064. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1065. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  1066. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1067. if (use_bias) {
  1068. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  1069. if (status != CUBLAS_STATUS_SUCCESS) {
  1070. goto CLEANUP;
  1071. }
  1072. epilogue = CUBLASLT_EPILOGUE_BGRADB;
  1073. }
  1074. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  1075. if (status != CUBLAS_STATUS_SUCCESS) {
  1076. goto CLEANUP;
  1077. }
  1078. // Create matrix descriptors. Not setting any extra attributes.
  1079. status = cublasLtMatrixLayoutInit(
  1080. &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  1081. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1082. status = cublasLtMatrixLayoutInit(
  1083. &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  1084. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1085. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  1086. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1087. // Create preference handle; In general, extra attributes can be
  1088. // used here to disable tensor ops or to make sure algo selected
  1089. // will work with badly aligned A, B, C. However, for simplicity
  1090. // here we assume A,B,C are always well aligned (e.g., directly
  1091. // come from cudaMalloc)
  1092. status = cublasLtMatmulPreferenceInit(&preference);
  1093. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1094. status = cublasLtMatmulPreferenceSetAttribute(
  1095. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  1096. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1097. // We just need the best available heuristic to try and run matmul.
  1098. // There is no guarantee that this will work. For example, if A is
  1099. // badly aligned, you can request more (e.g. 32) algos and try to
  1100. // run them one by one until something works.
  1101. status = cublasLtMatmulAlgoGetHeuristic(
  1102. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  1103. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1104. if (returnedResults == 0) {
  1105. status = CUBLAS_STATUS_NOT_SUPPORTED;
  1106. goto CLEANUP;
  1107. }
  1108. status = cublasLtMatmul(ltHandle,
  1109. &operationDesc,
  1110. alpha,
  1111. A,
  1112. &Adesc,
  1113. B,
  1114. &Bdesc,
  1115. beta,
  1116. C,
  1117. &Cdesc,
  1118. C,
  1119. &Cdesc,
  1120. &heuristicResult.algo,
  1121. workspace,
  1122. workspaceSize,
  1123. stream);
  1124. CLEANUP:
  1125. // Descriptors are no longer needed as all GPU work was already
  1126. // enqueued.
  1127. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  1128. }
  1129. int gemm_dgelu_bgradb_lt(
  1130. cublasLtHandle_t ltHandle,
  1131. cublasOperation_t transa,
  1132. cublasOperation_t transb,
  1133. int m,
  1134. int n,
  1135. int k,
  1136. const float *alpha, /* host pointer */
  1137. at::Half* A,
  1138. int lda,
  1139. at::Half* B,
  1140. int ldb,
  1141. const float *beta, /* host pointer */
  1142. at::Half* C,
  1143. int64_t ldc,
  1144. void *workspace,
  1145. size_t workspaceSize,
  1146. cudaStream_t stream,
  1147. const void *gelu_in,
  1148. const void *bgrad) {
  1149. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  1150. cublasLtMatmulDescOpaque_t operationDesc = {};
  1151. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  1152. cublasLtMatmulPreferenceOpaque_t preference = {};
  1153. int returnedResults = 0;
  1154. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  1155. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
  1156. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  1157. // for details about defaults; here we just set the transforms for
  1158. // A and B.
  1159. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  1160. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1161. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  1162. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1163. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  1164. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1165. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  1166. if (status != CUBLAS_STATUS_SUCCESS) {
  1167. goto CLEANUP;
  1168. }
  1169. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  1170. if (status != CUBLAS_STATUS_SUCCESS) {
  1171. goto CLEANUP;
  1172. }
  1173. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  1174. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  1175. if (status != CUBLAS_STATUS_SUCCESS) {
  1176. goto CLEANUP;
  1177. }
  1178. // Create matrix descriptors. Not setting any extra attributes.
  1179. status = cublasLtMatrixLayoutInit(
  1180. &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  1181. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1182. status = cublasLtMatrixLayoutInit(
  1183. &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  1184. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1185. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  1186. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1187. // Create preference handle; In general, extra attributes can be
  1188. // used here to disable tensor ops or to make sure algo selected
  1189. // will work with badly aligned A, B, C. However, for simplicity
  1190. // here we assume A,B,C are always well aligned (e.g., directly
  1191. // come from cudaMalloc)
  1192. status = cublasLtMatmulPreferenceInit(&preference);
  1193. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1194. status = cublasLtMatmulPreferenceSetAttribute(
  1195. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  1196. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1197. // We just need the best available heuristic to try and run matmul.
  1198. // There is no guarantee that this will work. For example, if A is
  1199. // badly aligned, you can request more (e.g. 32) algos and try to
  1200. // run them one by one until something works.
  1201. status = cublasLtMatmulAlgoGetHeuristic(
  1202. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  1203. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1204. if (returnedResults == 0) {
  1205. status = CUBLAS_STATUS_NOT_SUPPORTED;
  1206. goto CLEANUP;
  1207. }
  1208. status = cublasLtMatmul(ltHandle,
  1209. &operationDesc,
  1210. alpha,
  1211. A,
  1212. &Adesc,
  1213. B,
  1214. &Bdesc,
  1215. beta,
  1216. C,
  1217. &Cdesc,
  1218. C,
  1219. &Cdesc,
  1220. //&heuristicResult.algo,
  1221. NULL,
  1222. workspace,
  1223. workspaceSize,
  1224. stream);
  1225. CLEANUP:
  1226. // Descriptors are no longer needed as all GPU work was already
  1227. // enqueued.
  1228. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  1229. }
  1230. int gemm_dgelu_bgradb_lt(
  1231. cublasLtHandle_t ltHandle,
  1232. cublasOperation_t transa,
  1233. cublasOperation_t transb,
  1234. int m,
  1235. int n,
  1236. int k,
  1237. const float *alpha, /* host pointer */
  1238. at::BFloat16* A,
  1239. int lda,
  1240. at::BFloat16* B,
  1241. int ldb,
  1242. const float *beta, /* host pointer */
  1243. at::BFloat16* C,
  1244. int64_t ldc,
  1245. void *workspace,
  1246. size_t workspaceSize,
  1247. cudaStream_t stream,
  1248. const void *gelu_in,
  1249. const void *bgrad) {
  1250. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  1251. cublasLtMatmulDescOpaque_t operationDesc = {};
  1252. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  1253. cublasLtMatmulPreferenceOpaque_t preference = {};
  1254. int returnedResults = 0;
  1255. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  1256. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
  1257. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  1258. // for details about defaults; here we just set the transforms for
  1259. // A and B.
  1260. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  1261. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1262. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  1263. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1264. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  1265. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1266. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  1267. if (status != CUBLAS_STATUS_SUCCESS) {
  1268. goto CLEANUP;
  1269. }
  1270. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  1271. if (status != CUBLAS_STATUS_SUCCESS) {
  1272. goto CLEANUP;
  1273. }
  1274. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  1275. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  1276. if (status != CUBLAS_STATUS_SUCCESS) {
  1277. goto CLEANUP;
  1278. }
  1279. // Create matrix descriptors. Not setting any extra attributes.
  1280. status = cublasLtMatrixLayoutInit(
  1281. &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  1282. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1283. status = cublasLtMatrixLayoutInit(
  1284. &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  1285. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1286. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
  1287. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1288. // Create preference handle; In general, extra attributes can be
  1289. // used here to disable tensor ops or to make sure algo selected
  1290. // will work with badly aligned A, B, C. However, for simplicity
  1291. // here we assume A,B,C are always well aligned (e.g., directly
  1292. // come from cudaMalloc)
  1293. status = cublasLtMatmulPreferenceInit(&preference);
  1294. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1295. status = cublasLtMatmulPreferenceSetAttribute(
  1296. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  1297. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1298. // We just need the best available heuristic to try and run matmul.
  1299. // There is no guarantee that this will work. For example, if A is
  1300. // badly aligned, you can request more (e.g. 32) algos and try to
  1301. // run them one by one until something works.
  1302. status = cublasLtMatmulAlgoGetHeuristic(
  1303. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  1304. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1305. if (returnedResults == 0) {
  1306. status = CUBLAS_STATUS_NOT_SUPPORTED;
  1307. goto CLEANUP;
  1308. }
  1309. status = cublasLtMatmul(ltHandle,
  1310. &operationDesc,
  1311. alpha,
  1312. A,
  1313. &Adesc,
  1314. B,
  1315. &Bdesc,
  1316. beta,
  1317. C,
  1318. &Cdesc,
  1319. C,
  1320. &Cdesc,
  1321. //&heuristicResult.algo,
  1322. NULL,
  1323. workspace,
  1324. workspaceSize,
  1325. stream);
  1326. CLEANUP:
  1327. // Descriptors are no longer needed as all GPU work was already
  1328. // enqueued.
  1329. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  1330. }
  1331. int gemm_dgelu_bgradb_lt(
  1332. cublasLtHandle_t ltHandle,
  1333. cublasOperation_t transa,
  1334. cublasOperation_t transb,
  1335. int m,
  1336. int n,
  1337. int k,
  1338. const float *alpha, /* host pointer */
  1339. double *A,
  1340. int lda,
  1341. double *B,
  1342. int ldb,
  1343. const float *beta, /* host pointer */
  1344. double *C,
  1345. int ldc,
  1346. void *workspace,
  1347. size_t workspaceSize,
  1348. cudaStream_t stream,
  1349. const void *gelu_in,
  1350. const void *bgrad) {
  1351. return 1;
  1352. }
  1353. int gemm_dgelu_bgradb_lt(
  1354. cublasLtHandle_t ltHandle,
  1355. cublasOperation_t transa,
  1356. cublasOperation_t transb,
  1357. int m,
  1358. int n,
  1359. int k,
  1360. const float *alpha, /* host pointer */
  1361. float *A,
  1362. int lda,
  1363. float *B,
  1364. int ldb,
  1365. const float *beta, /* host pointer */
  1366. float *C,
  1367. int64_t ldc,
  1368. void *workspace,
  1369. size_t workspaceSize,
  1370. cudaStream_t stream,
  1371. const void *gelu_in,
  1372. const void *bgrad) {
  1373. cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  1374. cublasLtMatmulDescOpaque_t operationDesc = {};
  1375. cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  1376. cublasLtMatmulPreferenceOpaque_t preference = {};
  1377. int returnedResults = 0;
  1378. cublasLtMatmulHeuristicResult_t heuristicResult = {};
  1379. cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
  1380. // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  1381. // for details about defaults; here we just set the transforms for
  1382. // A and B.
  1383. status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  1384. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1385. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  1386. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1387. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  1388. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1389. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  1390. if (status != CUBLAS_STATUS_SUCCESS) {
  1391. goto CLEANUP;
  1392. }
  1393. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  1394. if (status != CUBLAS_STATUS_SUCCESS) {
  1395. goto CLEANUP;
  1396. }
  1397. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  1398. status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  1399. if (status != CUBLAS_STATUS_SUCCESS) {
  1400. goto CLEANUP;
  1401. }
  1402. // Create matrix descriptors. Not setting any extra attributes.
  1403. status = cublasLtMatrixLayoutInit(
  1404. &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  1405. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1406. status = cublasLtMatrixLayoutInit(
  1407. &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  1408. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1409. status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  1410. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1411. // Create preference handle; In general, extra attributes can be
  1412. // used here to disable tensor ops or to make sure algo selected
  1413. // will work with badly aligned A, B, C. However, for simplicity
  1414. // here we assume A,B,C are always well aligned (e.g., directly
  1415. // come from cudaMalloc)
  1416. status = cublasLtMatmulPreferenceInit(&preference);
  1417. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1418. status = cublasLtMatmulPreferenceSetAttribute(
  1419. &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  1420. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1421. // We just need the best available heuristic to try and run matmul.
  1422. // There is no guarantee that this will work. For example, if A is
  1423. // badly aligned, you can request more (e.g. 32) algos and try to
  1424. // run them one by one until something works.
  1425. status = cublasLtMatmulAlgoGetHeuristic(
  1426. ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  1427. if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  1428. if (returnedResults == 0) {
  1429. status = CUBLAS_STATUS_NOT_SUPPORTED;
  1430. goto CLEANUP;
  1431. }
  1432. status = cublasLtMatmul(ltHandle,
  1433. &operationDesc,
  1434. alpha,
  1435. A,
  1436. &Adesc,
  1437. B,
  1438. &Bdesc,
  1439. beta,
  1440. C,
  1441. &Cdesc,
  1442. C,
  1443. &Cdesc,
  1444. //&heuristicResult.algo,
  1445. NULL,
  1446. workspace,
  1447. workspaceSize,
  1448. stream);
  1449. CLEANUP:
  1450. // Descriptors are no longer needed as all GPU work was already
  1451. // enqueued.
  1452. return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
  1453. }
  1454. #endif
  1455. template <typename T>
  1456. int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
  1457. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  1458. // Get the stream from cublas handle to reuse for biasReLU kernel.
  1459. cudaStream_t stream;
  1460. cublasGetStream(handle, &stream);
  1461. const float alpha = 1.0;
  1462. const float beta_zero = 0.0;
  1463. const float beta_one = 1.0;
  1464. int status = 1;
  1465. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  1466. status = gemm_bias_lt(
  1467. (cublasLtHandle_t)handle,
  1468. CUBLAS_OP_T,
  1469. CUBLAS_OP_N,
  1470. out_features,
  1471. batch_size,
  1472. in_features,
  1473. &alpha, /* host pointer */
  1474. weight,
  1475. in_features,
  1476. input.data_ptr<T>(),
  1477. in_features,
  1478. &beta_zero, /* host pointer */
  1479. output.data_ptr<T>(),
  1480. out_features,
  1481. lt_workspace,
  1482. 1 << 22,
  1483. stream,
  1484. true,
  1485. static_cast<const void*>(bias.data_ptr<T>()));
  1486. #endif
  1487. if (status != 0){
  1488. output.copy_(bias);
  1489. status = gemm_bias(
  1490. handle,
  1491. CUBLAS_OP_T,
  1492. CUBLAS_OP_N,
  1493. out_features,
  1494. batch_size,
  1495. in_features,
  1496. &alpha,
  1497. weight,
  1498. in_features,
  1499. input.data_ptr<T>(),
  1500. in_features,
  1501. &beta_one,
  1502. output.data_ptr<T>(),
  1503. out_features);
  1504. }
  1505. return status;
  1506. }
  1507. template <typename T>
  1508. int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace) {
  1509. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  1510. // Get the stream from cublas handle to reuse for biasReLU kernel.
  1511. cudaStream_t stream;
  1512. cublasGetStream(handle, &stream);
  1513. const float alpha = 1.0;
  1514. const float beta_zero = 0.0;
  1515. int status = 1;
  1516. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  1517. status = gemm_bgradb_lt(
  1518. (cublasLtHandle_t)handle,
  1519. CUBLAS_OP_N,
  1520. CUBLAS_OP_T,
  1521. in_features,
  1522. out_features,
  1523. batch_size,
  1524. &alpha, /* host pointer */
  1525. input,
  1526. in_features,
  1527. d_output,
  1528. out_features,
  1529. &beta_zero, /* host pointer */
  1530. d_weight,
  1531. in_features,
  1532. lt_workspace,
  1533. 1 << 22,
  1534. stream,
  1535. true,
  1536. static_cast<const void*>(d_bias));
  1537. #endif
  1538. if (status != 0){
  1539. status = gemm_bias(
  1540. handle,
  1541. CUBLAS_OP_N,
  1542. CUBLAS_OP_T,
  1543. in_features,
  1544. out_features,
  1545. batch_size,
  1546. &alpha,
  1547. input,
  1548. in_features,
  1549. d_output,
  1550. out_features,
  1551. &beta_zero,
  1552. d_weight,
  1553. in_features);
  1554. }
  1555. status = gemm_bias(
  1556. handle,
  1557. CUBLAS_OP_N,
  1558. CUBLAS_OP_N,
  1559. in_features,
  1560. batch_size,
  1561. out_features,
  1562. &alpha,
  1563. weight,
  1564. in_features,
  1565. d_output,
  1566. out_features,
  1567. &beta_zero,
  1568. d_input,
  1569. in_features);
  1570. return status;
  1571. }
  1572. template <typename T>
  1573. int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) {
  1574. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  1575. // Get the stream from cublas handle to reuse for biasReLU kernel.
  1576. cudaStream_t stream;
  1577. cublasGetStream(handle, &stream);
  1578. const float alpha = 1.0;
  1579. const float beta_zero = 0.0;
  1580. int status = 1;
  1581. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  1582. status = gemm_bias_gelu_lt(
  1583. (cublasLtHandle_t)handle,
  1584. CUBLAS_OP_T,
  1585. CUBLAS_OP_N,
  1586. hidden_features,
  1587. batch_size,
  1588. in_features,
  1589. &alpha, /* host pointer */
  1590. weight1,
  1591. in_features,
  1592. input,
  1593. in_features,
  1594. &beta_zero, /* host pointer */
  1595. output1,
  1596. hidden_features,
  1597. lt_workspace,
  1598. 1 << 22,
  1599. stream,
  1600. true,
  1601. static_cast<const void*>(gelu_in),
  1602. static_cast<const void*>(bias1));
  1603. status = gemm_bias_lt(
  1604. (cublasLtHandle_t)handle,
  1605. CUBLAS_OP_T,
  1606. CUBLAS_OP_N,
  1607. out_features,
  1608. batch_size,
  1609. hidden_features,
  1610. &alpha, /* host pointer */
  1611. weight2,
  1612. hidden_features,
  1613. output1,
  1614. hidden_features,
  1615. &beta_zero, /* host pointer */
  1616. output2,
  1617. out_features,
  1618. lt_workspace,
  1619. 1 << 22,
  1620. stream,
  1621. true,
  1622. static_cast<const void*>(bias2));
  1623. return status;
  1624. #else
  1625. return 1;
  1626. #endif
  1627. }
  1628. template <typename T>
  1629. int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace) {
  1630. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  1631. // Get the stream from cublas handle to reuse for biasReLU kernel.
  1632. cudaStream_t stream;
  1633. cublasGetStream(handle, &stream);
  1634. const float alpha = 1.0;
  1635. const float beta_zero = 0.0;
  1636. int status = 1;
  1637. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  1638. //wgrad for first gemm
  1639. status = gemm_bgradb_lt(
  1640. (cublasLtHandle_t)handle,
  1641. CUBLAS_OP_N,
  1642. CUBLAS_OP_T,
  1643. hidden_features,
  1644. out_features,
  1645. batch_size,
  1646. &alpha, /* host pointer */
  1647. output1,
  1648. hidden_features,
  1649. d_output2,
  1650. out_features,
  1651. &beta_zero, /* host pointer */
  1652. d_weight2,
  1653. hidden_features,
  1654. lt_workspace,
  1655. 1 << 22,
  1656. stream,
  1657. true,
  1658. static_cast<const void*>(d_bias2));
  1659. //dgrad for second GEMM
  1660. status = gemm_dgelu_bgradb_lt(
  1661. (cublasLtHandle_t)handle,
  1662. CUBLAS_OP_N,
  1663. CUBLAS_OP_N,
  1664. hidden_features,
  1665. batch_size,
  1666. out_features,
  1667. &alpha, /* host pointer */
  1668. weight2,
  1669. hidden_features,
  1670. d_output2,
  1671. out_features,
  1672. &beta_zero, /* host pointer */
  1673. d_output1,
  1674. hidden_features,
  1675. lt_workspace,
  1676. 1 << 22,
  1677. stream,
  1678. static_cast<const void*>(gelu_in),
  1679. static_cast<const void*>(d_bias1));
  1680. //wgrad for the first GEMM
  1681. status = gemm_bias(
  1682. handle,
  1683. CUBLAS_OP_N,
  1684. CUBLAS_OP_T,
  1685. in_features,
  1686. hidden_features,
  1687. batch_size,
  1688. &alpha,
  1689. input,
  1690. in_features,
  1691. d_output1,
  1692. hidden_features,
  1693. &beta_zero,
  1694. d_weight1,
  1695. in_features);
  1696. //dgrad for the first GEMM
  1697. status = gemm_bias(
  1698. handle,
  1699. CUBLAS_OP_N,
  1700. CUBLAS_OP_N,
  1701. in_features,
  1702. batch_size,
  1703. hidden_features,
  1704. &alpha,
  1705. weight1,
  1706. in_features,
  1707. d_output1,
  1708. hidden_features,
  1709. &beta_zero,
  1710. d_input,
  1711. in_features);
  1712. #endif
  1713. return status;
  1714. }
// Explicit template instantiations for every dtype exposed by this extension
// (fp16, fp32, fp64, and — appended at the end — bf16), so the definitions
// above are emitted in this translation unit for the host-side bindings.
template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<float>(at::Tensor input, float *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<double>(at::Tensor input, double *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, void *lt_workspace) ;
template int linear_bias_backward_cuda<float>(float *input, float *weight, float *d_output, int in_features, int batch_size, int out_features, float *d_weight, float *d_bias, float *d_input, void *lt_workspace) ;
template int linear_bias_backward_cuda<double>(double *input, double *weight, double *d_output, int in_features, int batch_size, int out_features, double *d_weight, double *d_bias, double *d_input, void *lt_workspace) ;
template int linear_gelu_linear_forward_cuda<at::Half>(at::Half *input, at::Half *weight1, at::Half *bias1, at::Half *weight2, at::Half *bias2, int in_features, int hidden_features, int batch_size, int out_features, at::Half *output1, at::Half *output2, at::Half *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_forward_cuda<float>(float *input, float *weight1, float *bias1, float *weight2, float *bias2, int in_features, int hidden_features, int batch_size, int out_features, float *output1, float *output2, float *gelu_in, void *lt_workspace);
template int linear_gelu_linear_forward_cuda<double>(double *input, double *weight1, double *bias1, double *weight2, double *bias2, int in_features, int hidden_features, int batch_size, int out_features, double *output1, double *output2, double *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<float>(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<double>(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace);
template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, void *lt_workspace) ;
template int linear_gelu_linear_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight1, at::BFloat16 *bias1, at::BFloat16 *weight2, at::BFloat16 *bias2, int in_features, int hidden_features, int batch_size, int out_features, at::BFloat16 *output1, at::BFloat16 *output2, at::BFloat16 *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, void *lt_workspace);