scaled_upper_triang_masked_softmax.h

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

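// Butterfly reduction across one warp: after log2(WARP_SIZE) XOR-shuffle steps every
// lane holds the reduction (Add or Max) of its WARP_BATCH partial values.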
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from PyTorch's native ATen softmax) with the following additional features:
 * 1) input scaling
 * 2) implicit upper-triangular (causal) masking
 */
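// Grid/thread mapping (see the dispatch functions below): blockIdx.x is the query
// position within the sequence, so a row at that position keeps local_seq = blockIdx.x + 1
// elements; everything beyond is treated as -infinity before the softmax and written
// back as zero. Each warp handles up to WARP_BATCH rows at the same sequence position
// but from different attention batches, and threadIdx.x strides across a row in chunks
// of ELEMENTS_PER_LDG_STG elements.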
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the values warp_size and batches_per_warp
    // computed in dispatch_scaled_upper_triang_masked_softmax_forward.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    long int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    long int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    src += thread_offset;
    dst += thread_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if ((element_index + element) < batch_element_count) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < local_seq) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < local_seq) {
                        out[element] = elements[i][it + element] / sum[i];
                    } else {
                        out[element] = 0;
                    }
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }
        }
    }
}

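// Backward pass: with y = softmax(scale * x) and upstream gradient dy, the kernel below
// computes dx = scale * y * (dy - sum_j(dy_j * y_j)). grad_reg holds the product dy * y
// and sum[] its row-wise reduction.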
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the values warp_size and batches_per_warp
    // computed in dispatch_scaled_upper_triang_masked_softmax_backward.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    long int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    long int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                    }
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
}

} // end of anonymous namespace

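// Host-side dispatchers: the kernel instantiation is selected from log2_ceil(softmax_elements)
// and the launch geometry derived from it. Each block runs 128 threads split into warps of
// warp_size lanes; grid.x covers the sequence positions and grid.y the attention batches,
// so attn_batches must be divisible by batches_per_block.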
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside
        // scaled_upper_triang_masked_softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside
        // scaled_upper_triang_masked_softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 12: // 4096
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 13: // 8192
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 14: // 16384
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside
        // scaled_upper_triang_masked_softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside
        // scaled_upper_triang_masked_softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 12: // 4096
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 13: // 8192
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 14: // 16384
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
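
/*
 * Usage sketch (illustrative only; the buffer names, shapes, and the half-precision
 * instantiation below are assumptions, not something this header mandates). For
 * attention scores laid out as [attn_batches, seq_len, seq_len] with contiguous rows:
 *
 *   const int seq_len      = 1024;            // softmax_elements
 *   const int attn_batches = batches * heads; // must be divisible by batches_per_block
 *   dispatch_scaled_upper_triang_masked_softmax_forward<c10::Half, c10::Half, float>(
 *       attn_probs,    // output_t *dst
 *       attn_scores,   // const input_t *src
 *       scale,         // e.g. 1/sqrt(head_dim), converted to input_t
 *       seq_len,       // softmax_elements
 *       seq_len,       // softmax_elements_stride (row stride)
 *       attn_batches);
 */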