3
0

scaled_masked_softmax.h 32 KB


  1. /* coding=utf-8
  2. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #pragma once
  17. #include <assert.h>
  18. #include <cuda_fp16.h>
  19. #include <cfloat>
  20. #include <limits>
  21. #include <stdint.h>
  22. #include <cuda_fp16.h>
  23. #include <c10/macros/Macros.h>
  24. namespace {
  /*
   * copy_vector<Datatype, ELEMENTS_PER_LDG>(dst, src)
   *
   * Copies ELEMENTS_PER_LDG consecutive values of type Datatype from `src` to
   * `dst`. Only the specializations below are defined; any other
   * (Datatype, ELEMENTS_PER_LDG) combination is declared but has no body.
   *
   * The 1-element variants perform a plain scalar assignment. The 4-element
   * variants move all four values with a single wide access (float2 for the
   * 16-bit types, uchar4 for bytes), so both pointers must be suitably aligned
   * for that wide type.
   *
   * The wide casts keep `src` const-qualified instead of silently casting the
   * qualifier away.
   */
  template <typename Datatype, int ELEMENTS_PER_LDG>
  __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

  template <>
  __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) {
      *dst = *src;
  }

  template <>
  __device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) {
      // four 16-bit values moved as one float2-wide access
      *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
  }

  template <>
  __device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) {
      *dst = *src;
  }

  template <>
  __device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) {
      // four 16-bit values moved as one float2-wide access
      *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
  }

  template <>
  __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) {
      *dst = *src;
  }

  template <>
  __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
      // four bytes moved as one 4-byte access; uchar4 has the same size and
      // alignment as the half2 used previously but is an integer byte vector
      *reinterpret_cast<uchar4 *>(dst) = *reinterpret_cast<const uchar4 *>(src);
  }
  /*
   * Returns the smallest non-negative integer k such that (1 << k) >= value.
   * For value <= 1 (including zero and negative inputs) the result is 0.
   *
   * The power of two is evaluated in 64-bit arithmetic: with a plain `int`
   * shift, any value above 2^30 would drive the shift count to 31 and beyond,
   * which overflows and is undefined behaviour. With the widened shift the
   * loop terminates for every int input (result is at most 31).
   */
  int log2_ceil(int value) {
      int log2_value = 0;
      while ((static_cast<int64_t>(1) << log2_value) < static_cast<int64_t>(value)) {
          ++log2_value;
      }
      return log2_value;
  }
  // Binary functor used as a reduction operator: yields the sum of its two
  // operands.
  template <typename T>
  struct Add {
      __device__ __forceinline__ T operator()(T lhs, T rhs) const {
          return lhs + rhs;
      }
  };
  // Binary functor used as a reduction operator: yields the second operand
  // when the first compares strictly less than it, otherwise the first.
  template <typename T>
  struct Max {
      __device__ __forceinline__ T operator()(T lhs, T rhs) const {
          if (lhs < rhs) {
              return rhs;
          }
          return lhs;
      }
  };
  // Thin wrapper around the XOR ("butterfly") warp shuffle.
  //
  //   value    - the calling lane's contribution
  //   laneMask - XOR-ed with the caller's lane id to select the source lane
  //   width    - size of the shuffle group (defaults to warpSize)
  //   mask     - participating-lane mask handed to the *_sync intrinsic
  //
  // Toolkits reporting CUDA_VERSION >= 9000 use __shfl_xor_sync; older ones
  // fall back to the legacy mask-less __shfl_xor, in which case `mask` is unused.
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if CUDA_VERSION < 9000
      return __shfl_xor(value, laneMask, width);
  #else
      return __shfl_xor_sync(mask, value, laneMask, width);
  #endif
  }
  // Butterfly all-reduce over a group of WARP_SIZE lanes.
  //
  // `sum` points to WARP_BATCH per-lane partial values. On return every lane of
  // the group holds, in sum[i], the ReduceOp-combination of the sum[i] values
  // of all lanes in the group. The lane distance is halved each step, starting
  // from WARP_SIZE / 2 and ending at 1.
  template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
  __device__ __forceinline__ void warp_reduce(acc_t* sum) {
      ReduceOp<acc_t> combine;
      #pragma unroll
      for (int lane_distance = WARP_SIZE >> 1; lane_distance >= 1; lane_distance >>= 1) {
          #pragma unroll
          for (int row = 0; row < WARP_BATCH; ++row) {
              // value held by the lane whose id differs in the `lane_distance` bit
              const acc_t partner = WARP_SHFL_XOR_NATIVE(sum[row], lane_distance, WARP_SIZE);
              sum[row] = combine(sum[row], partner);
          }
      }
  }
  /*
   * Scaled softmax forward kernel (extended from the native ATen softmax).
   * Additional feature compared with the plain softmax:
   *   1) every input value is multiplied by `scale` before normalisation
   *
   * Each group of WARP_SIZE threads handles WARP_BATCH rows of `element_count`
   * values; a row is padded (with -infinity) up to 1 << log2_elements.
   *
   *   dst              - output, same layout as src
   *   src              - input rows stored back to back
   *   scale            - multiplier applied to every input value
   *   micro_batch_size - total number of rows
   *   element_count    - number of valid values per row
   *
   * Expected launch shape (as stated by the original author):
   *   blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
   *   gridDim/blockIdx   = (seq_len, attn_heads, batches)
   */
  template <typename input_t, typename output_t, typename acc_t, int log2_elements>
  __global__ void scaled_softmax_warp_forward(
      output_t *dst,
      const input_t *src,
      const acc_t scale,
      int micro_batch_size,
      int element_count)
  {
      // WARP_SIZE and WARP_BATCH must agree with the warp_size and
      // batches_per_warp values computed by the host-side dispatch code.
      constexpr int next_power_of_two = 1 << log2_elements;
      constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
      constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
      constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
      constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

      // Linearise the 3-D block index, then the warp index inside the grid.
      const unsigned int block_linear = blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z);
      const unsigned int warp_linear = blockDim.y * block_linear + threadIdx.y;
      long int first_batch = warp_linear * WARP_BATCH;

      // micro_batch_size need not be a multiple of WARP_BATCH: work out how many
      // rows this warp really owns (may be zero or negative for surplus warps).
      int rows_here = micro_batch_size - first_batch;
      rows_here = (rows_here > WARP_BATCH) ? WARP_BATCH : rows_here;

      // Position of this thread inside its row group.
      const int lane = threadIdx.x;
      long int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * lane;
      src += thread_offset;
      dst += thread_offset;

      // Stage 1: load from global memory, applying the scale.
      acc_t elements[WARP_BATCH][WARP_ITERATIONS];
      input_t staged_in[ELEMENTS_PER_LDG_STG];
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          // rows this warp does not own are treated as empty
          const int valid_count = (row >= rows_here) ? 0 : element_count;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              const int element_index = ELEMENTS_PER_LDG_STG * lane + it * WARP_SIZE;
              if (element_index >= valid_count) {
                  // padding: -infinity contributes exp(...) == 0 to the sum
                  #pragma unroll
                  for (int k = 0; k < ELEMENTS_PER_LDG_STG; ++k) {
                      elements[row][it + k] = -std::numeric_limits<acc_t>::infinity();
                  }
              } else {
                  const int row_offset = row * element_count + it * WARP_SIZE;
                  copy_vector<input_t, ELEMENTS_PER_LDG_STG>(staged_in, src + row_offset);
                  #pragma unroll
                  for (int k = 0; k < ELEMENTS_PER_LDG_STG; ++k) {
                      elements[row][it + k] = static_cast<acc_t>(staged_in[k]) * scale;
                  }
              }
          }
      }

      // Stage 2: per-row maximum (thread-local first, then across the group).
      acc_t max_value[WARP_BATCH];
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          max_value[row] = elements[row][0];
          #pragma unroll
          for (int it = 1; it < WARP_ITERATIONS; ++it) {
              const acc_t candidate = elements[row][it];
              if (!(max_value[row] > candidate)) {
                  max_value[row] = candidate;
              }
          }
      }
      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

      // Stage 3: exponentiate relative to the maximum and accumulate the sum.
      acc_t sum[WARP_BATCH] { 0.0f };
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; ++it) {
              const acc_t shifted = elements[row][it] - max_value[row];
              elements[row][it] = std::exp(shifted);
              sum[row] += elements[row][it];
          }
      }
      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

      // Stage 4: normalise and write back the rows owned by this warp.
      output_t out[ELEMENTS_PER_LDG_STG];
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          if (row >= rows_here)
              break;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              const int element_index = ELEMENTS_PER_LDG_STG * lane + it * WARP_SIZE;
              if (element_index >= element_count) {
                  break;
              }
              #pragma unroll
              for (int k = 0; k < ELEMENTS_PER_LDG_STG; ++k) {
                  out[k] = elements[row][it + k] / sum[row];
              }
              copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + row * element_count + it * WARP_SIZE, out);
          }
      }
  }
  /*
   * Extended softmax (from native aten pytorch) with following additional features
   * 1) input scaling
   * 2) Explicit masking
   *
   * Each group of WARP_SIZE threads handles WARP_BATCH rows of `element_count`
   * values; a row is padded (with -infinity) up to 1 << log2_elements.
   *
   *   dst              - output, same layout as src
   *   src              - input rows stored back to back
   *   mask             - byte mask; a value of 1 marks a position as masked out
   *   scale            - multiplier applied to every unmasked input value
   *   micro_batch_size - total number of rows
   *   element_count    - number of valid values per row
   *   pad_batches      - 1 selects the mask indexing that ignores blockIdx.y
   *                      and blockIdx.z ("gpt2 style"); any other value indexes
   *                      the mask by blockIdx.x and blockIdx.z ("bert style")
   *
   * Masked positions are replaced by the sentinel -10000. A row whose maximum
   * equals that sentinel is treated as fully masked and its output is all zeros.
   *
   * The sentinel and the 0/1 scale factors are expressed as acc_t values rather
   * than double literals, so the comparison against the row maximum is carried
   * out in acc_t and the sentinel is defined in exactly one place.
   */
  template <typename input_t, typename output_t, typename acc_t, int log2_elements>
  __global__ void scaled_masked_softmax_warp_forward(
      output_t *dst,
      const input_t *src,
      const uint8_t *mask,
      const acc_t scale,
      int micro_batch_size,
      int element_count,
      int pad_batches)
  {
      // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
      // warp_size of method warp_softmax_forward_kernel.
      constexpr int next_power_of_two = 1 << log2_elements;
      constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
      constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
      constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
      constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

      // Value written in place of masked inputs; also used to detect rows in
      // which every position is masked. -10000 is exactly representable in float.
      const acc_t masked_fill = static_cast<acc_t>(-10000.0);

      // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
      // gridDim/blockIdx = (seq_len, attn_heads, batches)
      long int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
      long int pad_first_batch = 0;
      if (pad_batches != 1) { // bert style
          pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
      } else { // gpt2 style
          pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
      }

      // micro_batch_size might not be a multiple of WARP_BATCH. Check how
      // many batches have to be computed within this WARP.
      int local_batches = micro_batch_size - first_batch;
      if (local_batches > WARP_BATCH)
          local_batches = WARP_BATCH;

      // there might be multiple batches per warp. compute the index within the batch
      int local_idx = threadIdx.x;
      long int thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
      long int thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
      src += thread_offset_src_dst;
      dst += thread_offset_src_dst;
      mask += thread_offset_mask;

      // load data from global memory
      acc_t elements[WARP_BATCH][WARP_ITERATIONS];
      input_t temp_data[ELEMENTS_PER_LDG_STG];
      uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
      #pragma unroll
      for (int i = 0; i < WARP_BATCH; ++i) {
          // rows this warp does not own are treated as empty
          int batch_element_count = (i >= local_batches) ? 0 : element_count;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
              if (element_index < batch_element_count) {
                  int itr_idx = i * element_count + it * WARP_SIZE;
                  copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                  copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
                  #pragma unroll
                  for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                      if (temp_mask[element] != 1) {
                          elements[i][it + element] = (acc_t)temp_data[element] * scale;
                      } else {
                          elements[i][it + element] = masked_fill;
                      }
                  }
              } else {
                  // padding: -infinity contributes exp(...) == 0 to the sum
                  #pragma unroll
                  for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                      elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                  }
              }
          }
      }

      // compute max_value
      acc_t max_value[WARP_BATCH];
      #pragma unroll
      for (int i = 0; i < WARP_BATCH; ++i) {
          max_value[i] = elements[i][0];
          #pragma unroll
          for (int it = 1; it < WARP_ITERATIONS; ++it) {
              max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
          }
      }
      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

      // compute scale value to account for full mask: a row whose maximum is the
      // sentinel produces an all-zero output instead of a uniform distribution
      acc_t scale_value[WARP_BATCH];
      #pragma unroll
      for (int i = 0; i < WARP_BATCH; ++i) {
          scale_value[i] = (max_value[i] == masked_fill) ? static_cast<acc_t>(0) : static_cast<acc_t>(1);
      }

      acc_t sum[WARP_BATCH] { 0.0f };
      #pragma unroll
      for (int i = 0; i < WARP_BATCH; ++i) {
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; ++it) {
              elements[i][it] = std::exp((elements[i][it] - max_value[i]));
              sum[i] += elements[i][it];
          }
      }
      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

      // store result
      output_t out[ELEMENTS_PER_LDG_STG];
      #pragma unroll
      for (int i = 0; i < WARP_BATCH; ++i) {
          if (i >= local_batches)
              break;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
              if (element_index < element_count) {
                  #pragma unroll
                  for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                      out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                  }
                  copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
              } else {
                  break;
              }
          }
      }
  }
  /*
   * Backward pass of the scaled (masked) softmax.
   *
   * For every row it computes
   *     gradInput = scale * (grad * output - output * sum(grad * output))
   * where the sum runs over the row.
   *
   *   gradInput        - result, same layout as grad
   *   grad             - incoming gradient rows
   *   output           - softmax output saved from the forward pass
   *   scale            - multiplier applied to the result
   *   micro_batch_size - total number of rows
   *   element_count    - number of valid values per row
   *
   * Each group of WARP_SIZE threads handles WARP_BATCH rows. The row index is
   * derived from blockIdx.x and threadIdx.y only, i.e. a one-dimensional grid
   * with blockDim = (WARP_SIZE, WARPS_PER_BLOCK).
   */
  template <typename input_t, typename output_t, typename acc_t, int log2_elements>
  __global__ void scaled_masked_softmax_warp_backward(
      output_t *gradInput,
      input_t *grad,
      const input_t *output,
      acc_t scale,
      int micro_batch_size,
      int element_count)
  {
      // WARP_SIZE and WARP_BATCH must agree with the warp_size and
      // batches_per_warp values computed by the host-side dispatch code.
      constexpr int next_power_of_two = 1 << log2_elements;
      constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
      constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
      constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
      constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

      const unsigned int warp_linear = blockDim.y * blockIdx.x + threadIdx.y;
      long int first_batch = warp_linear * WARP_BATCH;

      // micro_batch_size need not be a multiple of WARP_BATCH: work out how many
      // rows this warp really owns (may be zero or negative for surplus warps).
      int rows_here = micro_batch_size - first_batch;
      rows_here = (rows_here > WARP_BATCH) ? WARP_BATCH : rows_here;

      // Position of this thread inside its row group.
      const int lane = threadIdx.x;

      // First element handled by the current thread.
      long int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * lane;
      grad += thread_offset;
      output += thread_offset;
      gradInput += thread_offset;

      // Stage 1: load. Slots that are never loaded keep their zero initial value.
      acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
      acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
      input_t staged_grad[ELEMENTS_PER_LDG_STG];
      input_t staged_output[ELEMENTS_PER_LDG_STG];
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          // rows this warp does not own are treated as empty
          const int valid_count = (row >= rows_here) ? 0 : element_count;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              const int element_index = ELEMENTS_PER_LDG_STG * lane + it * WARP_SIZE;
              if (element_index >= valid_count) {
                  continue;
              }
              const int row_offset = row * element_count + it * WARP_SIZE;
              copy_vector<input_t, ELEMENTS_PER_LDG_STG>(staged_grad, grad + row_offset);
              copy_vector<input_t, ELEMENTS_PER_LDG_STG>(staged_output, output + row_offset);
              #pragma unroll
              for (int k = 0; k < ELEMENTS_PER_LDG_STG; ++k) {
                  output_reg[row][it + k] = static_cast<acc_t>(staged_output[k]);
                  // grad_reg holds grad * output from here on
                  grad_reg[row][it + k] = static_cast<acc_t>(staged_grad[k]) * output_reg[row][it + k];
              }
          }
      }

      // Stage 2: per-row sum of grad * output (thread-local, then across the group).
      acc_t sum[WARP_BATCH];
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          acc_t running = grad_reg[row][0];
          #pragma unroll
          for (int it = 1; it < WARP_ITERATIONS; ++it) {
              running += grad_reg[row][it];
          }
          sum[row] = running;
      }
      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

      // Stage 3: compute the gradients and write back the rows owned by this warp.
      #pragma unroll
      for (int row = 0; row < WARP_BATCH; ++row) {
          if (row >= rows_here)
              break;
          #pragma unroll
          for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
              const int element_index = ELEMENTS_PER_LDG_STG * lane + it * WARP_SIZE;
              if (element_index >= element_count) {
                  continue;
              }
              output_t out[ELEMENTS_PER_LDG_STG];
              #pragma unroll
              for (int k = 0; k < ELEMENTS_PER_LDG_STG; ++k) {
                  const acc_t centred = grad_reg[row][it + k] - output_reg[row][it + k] * sum[row];
                  out[k] = (output_t)(scale * centred);
              }
              copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + row * element_count + it * WARP_SIZE, out);
          }
      }
  }
  382. } // end of anonymous namespace
  // Returns how many rows a single thread block of the softmax kernels handles
  // for the given key sequence length, using the same warp_size /
  // batches_per_warp rules and the same 128 threads per block as the dispatch
  // functions in this file.
  //
  // Only key_seq_len influences the result; query_seq_len, batches and
  // attn_heads are accepted but not used.
  int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
      constexpr int threads_per_block = 128;

      // row length rounded up to a power of two
      const int padded_len = 1 << log2_ceil(key_seq_len);

      const int warp_size = (padded_len < C10_WARP_SIZE) ? padded_len : C10_WARP_SIZE;
      const int warps_per_block = threads_per_block / warp_size;
      const int batches_per_warp = (padded_len <= 128) ? 2 : 1;

      return warps_per_block * batches_per_warp;
  }
  /*
   * Host-side launcher for scaled_softmax_warp_forward.
   *
   *   dst, src       - device pointers to output / input, batches * attn_heads *
   *                    query_seq_len rows of key_seq_len values each
   *   scale          - multiplier applied to every input value
   *   query_seq_len  - must be a multiple of the number of rows per block
   *   key_seq_len    - row length, 0 <= key_seq_len <= 16384
   *   batches, attn_heads - remaining grid dimensions
   *
   * The kernel is enqueued on the current CUDA stream. Nothing is launched when
   * key_seq_len is zero or there are no rows. A launch failure is reported
   * through TORCH_INTERNAL_ASSERT instead of being silently ignored.
   */
  template<typename input_t, typename output_t, typename acc_t>
  void dispatch_scaled_softmax_forward(
      output_t *dst,
      const input_t *src,
      const input_t scale,
      int query_seq_len,
      int key_seq_len,
      int batches,
      int attn_heads)
  {
      TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 16384 );
      if (key_seq_len == 0) {
          return;
      }

      int log2_elements = log2_ceil(key_seq_len);
      const int next_power_of_two = 1 << log2_elements;
      int batch_count = batches * attn_heads * query_seq_len;

      // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
      int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

      // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
      int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

      // use 128 threads per block to maximize gpu utilization
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / warp_size);
      int batches_per_block = warps_per_block * batches_per_warp;
      TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);

      // An empty problem would produce a zero-sized grid, which is an invalid
      // launch configuration; there is nothing to compute in that case.
      if (batch_count == 0) {
          return;
      }

      dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
      dim3 threads(warp_size, warps_per_block, 1);
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // The row length is a template parameter of the kernel, so each supported
      // power of two needs its own instantiation.
  #define SCALED_SOFTMAX_FORWARD_LAUNCH(LOG2_ELEMENTS)                                      \
      case LOG2_ELEMENTS:                                                                   \
          scaled_softmax_warp_forward<input_t, output_t, acc_t, LOG2_ELEMENTS>              \
              <<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, key_seq_len);  \
          break

      switch (log2_elements) {
          SCALED_SOFTMAX_FORWARD_LAUNCH(0);   // 1
          SCALED_SOFTMAX_FORWARD_LAUNCH(1);   // 2
          SCALED_SOFTMAX_FORWARD_LAUNCH(2);   // 4
          SCALED_SOFTMAX_FORWARD_LAUNCH(3);   // 8
          SCALED_SOFTMAX_FORWARD_LAUNCH(4);   // 16
          SCALED_SOFTMAX_FORWARD_LAUNCH(5);   // 32
          SCALED_SOFTMAX_FORWARD_LAUNCH(6);   // 64
          SCALED_SOFTMAX_FORWARD_LAUNCH(7);   // 128
          SCALED_SOFTMAX_FORWARD_LAUNCH(8);   // 256
          SCALED_SOFTMAX_FORWARD_LAUNCH(9);   // 512
          SCALED_SOFTMAX_FORWARD_LAUNCH(10);  // 1024
          SCALED_SOFTMAX_FORWARD_LAUNCH(11);  // 2048
          SCALED_SOFTMAX_FORWARD_LAUNCH(12);  // 4096
          SCALED_SOFTMAX_FORWARD_LAUNCH(13);  // 8192
          SCALED_SOFTMAX_FORWARD_LAUNCH(14);  // 16384
          default:
              // not reachable while the key_seq_len range assert above holds
              break;
      }
  #undef SCALED_SOFTMAX_FORWARD_LAUNCH

      // Kernel launches do not return a status; configuration errors only show
      // up through cudaGetLastError().
      const cudaError_t launch_status = cudaGetLastError();
      TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
                            "scaled_softmax_warp_forward launch failed: ",
                            cudaGetErrorString(launch_status));
  }
  /*
   * Host-side launcher for scaled_masked_softmax_warp_forward.
   *
   *   dst, src       - device pointers to output / input, batches * attn_heads *
   *                    query_seq_len rows of key_seq_len values each
   *   mask           - device pointer to the byte mask read by the kernel
   *   scale          - multiplier applied to every unmasked input value
   *   query_seq_len  - must be a multiple of the number of rows per block
   *   key_seq_len    - row length, 0 <= key_seq_len <= 4096
   *   batches, attn_heads - remaining grid dimensions
   *   pad_batches    - forwarded to the kernel, where it selects how the mask
   *                    is indexed
   *
   * The kernel is enqueued on the current CUDA stream. Nothing is launched when
   * key_seq_len is zero or there are no rows. A launch failure is reported
   * through TORCH_INTERNAL_ASSERT instead of being silently ignored.
   */
  template<typename input_t, typename output_t, typename acc_t>
  void dispatch_scaled_masked_softmax_forward(
      output_t *dst,
      const input_t *src,
      const uint8_t *mask,
      const input_t scale,
      int query_seq_len,
      int key_seq_len,
      int batches,
      int attn_heads,
      int pad_batches)
  {
      TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
      if (key_seq_len == 0) {
          return;
      }

      int log2_elements = log2_ceil(key_seq_len);
      const int next_power_of_two = 1 << log2_elements;
      int batch_count = batches * attn_heads * query_seq_len;

      // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
      int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

      // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
      int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

      // use 128 threads per block to maximize gpu utilization
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / warp_size);
      int batches_per_block = warps_per_block * batches_per_warp;
      TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);

      // An empty problem would produce a zero-sized grid, which is an invalid
      // launch configuration; there is nothing to compute in that case.
      if (batch_count == 0) {
          return;
      }

      dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
      dim3 threads(warp_size, warps_per_block, 1);
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // The row length is a template parameter of the kernel, so each supported
      // power of two needs its own instantiation.
  #define SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(LOG2_ELEMENTS)                               \
      case LOG2_ELEMENTS:                                                                   \
          scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, LOG2_ELEMENTS>       \
              <<<blocks, threads, 0, stream>>>(                                             \
                  dst, src, mask, scale, batch_count, key_seq_len, pad_batches);            \
          break

      switch (log2_elements) {
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(0);   // 1
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(1);   // 2
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(2);   // 4
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(3);   // 8
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(4);   // 16
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(5);   // 32
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(6);   // 64
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(7);   // 128
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(8);   // 256
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(9);   // 512
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(10);  // 1024
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(11);  // 2048
          SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH(12);  // 4096
          default:
              // not reachable while the key_seq_len range assert above holds
              break;
      }
  #undef SCALED_MASKED_SOFTMAX_FORWARD_LAUNCH

      // Kernel launches do not return a status; configuration errors only show
      // up through cudaGetLastError().
      const cudaError_t launch_status = cudaGetLastError();
      TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
                            "scaled_masked_softmax_warp_forward launch failed: ",
                            cudaGetErrorString(launch_status));
  }
  /*
   * Host-side launcher for scaled_masked_softmax_warp_backward.
   *
   *   grad_input     - device pointer receiving the result
   *   grad           - device pointer to the incoming gradient
   *   output         - device pointer to the softmax output of the forward pass
   *   scale          - multiplier applied to the result
   *   query_seq_len, batches, attn_heads - their product is the number of rows
   *   key_seq_len    - row length, 0 <= key_seq_len <= 4096
   *
   * The kernel is enqueued on the current CUDA stream using a one-dimensional
   * grid. The grid size is rounded up so that rows beyond the last full block
   * are still processed; the kernel itself ignores rows at or past batch_count.
   * Nothing is launched when key_seq_len is zero or there are no rows. A launch
   * failure is reported through TORCH_INTERNAL_ASSERT instead of being silently
   * ignored.
   */
  template<typename input_t, typename output_t, typename acc_t>
  void dispatch_scaled_masked_softmax_backward(
      output_t *grad_input,
      input_t *grad,
      const input_t *output,
      const acc_t scale,
      int query_seq_len,
      int key_seq_len,
      int batches,
      int attn_heads)
  {
      TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
      if (key_seq_len == 0) {
          return;
      }

      int log2_elements = log2_ceil(key_seq_len);
      const int next_power_of_two = 1 << log2_elements;
      int batch_count = batches * attn_heads * query_seq_len;

      // An empty problem would produce a zero-sized grid, which is an invalid
      // launch configuration; there is nothing to compute in that case.
      if (batch_count == 0) {
          return;
      }

      // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
      int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

      // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
      int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

      // use 128 threads per block to maximize gpu utilization
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / warp_size);
      int batches_per_block = warps_per_block * batches_per_warp;

      // Round up: truncating division would leave the trailing
      // batch_count % batches_per_block rows without any block to process them.
      int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
      dim3 threads(warp_size, warps_per_block, 1);
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // The row length is a template parameter of the kernel, so each supported
      // power of two needs its own instantiation.
  #define SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(LOG2_ELEMENTS)                              \
      case LOG2_ELEMENTS:                                                                   \
          scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, LOG2_ELEMENTS>      \
              <<<blocks, threads, 0, stream>>>(                                             \
                  grad_input, grad, output, scale, batch_count, key_seq_len);               \
          break

      switch (log2_elements) {
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(0);   // 1
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(1);   // 2
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(2);   // 4
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(3);   // 8
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(4);   // 16
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(5);   // 32
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(6);   // 64
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(7);   // 128
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(8);   // 256
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(9);   // 512
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(10);  // 1024
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(11);  // 2048
          SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH(12);  // 4096
          default:
              // not reachable while the key_seq_len range assert above holds
              break;
      }
  #undef SCALED_MASKED_SOFTMAX_BACKWARD_LAUNCH

      // Kernel launches do not return a status; configuration errors only show
      // up through cudaGetLastError().
      const cudaError_t launch_status = cudaGetLastError();
      TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
                            "scaled_masked_softmax_warp_backward launch failed: ",
                            cudaGetErrorString(launch_status));
  }
  663. }