3
0

fused_rotary_positional_embedding.h 16 KB


  1. /* coding=utf-8
  2. * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #pragma once
  17. #include <ATen/ATen.h>
  18. #include <ATen/cuda/CUDAContext.h>
  19. #include <c10/macros/Macros.h>
  20. #include <cuda_runtime.h>
  21. #include <torch/extension.h>
  // Rotary positional embedding (RoPE) device code.
//
// Launch layout used by every kernel below (set up by the dispatch_* helpers):
//   gridDim  = (sequence positions, batch): blockIdx.x = s_id, blockIdx.y = b_id
//   blockDim = (C10_WARP_SIZE, 4 or 8): threadIdx.x strides over the head
//              dimension, threadIdx.y strides over heads.
// No shared memory and no block-level synchronization are used.
//
// Only the first d2 elements of each head are rotated; elements [d2, d) are
// copied through unchanged. With rot_half = d2 / 2, element i is paired with
// i + rot_half (when i + rot_half < d2) or i + rot_half - d2 (otherwise).
// NOTE(review): standard RoPE uses an even d2; confirm the caller enforces it.
// The backward kernels are written as the exact adjoint of the forward kernels
// for any d2, so gradients stay consistent even if an odd d2 slips through.
//
// Element offsets are computed in 64-bit, so tensors whose offsets exceed
// INT_MAX (e.g. s_id * stride_s) are addressed correctly. Arithmetic runs in
// RopeAccType<scalar_t> and is rounded to scalar_t once, on the final store.
namespace {

// Arithmetic type for the rotary math. It is float for float, half and
// bfloat16, so reduced-precision inputs are never multiplied by a cos/sin that
// has already been rounded to 16 bits. It is double for double inputs.
template <typename T>
struct RopeAccType {
  using type = float;
};
template <>
struct RopeAccType<double> {
  using type = double;
};

// Copies the un-rotated tail (elements [d2, d)) of every head from src to dst.
// Offsets are in elements.
template <typename scalar_t>
__device__ __forceinline__ void fused_rope_copy_tail(
    const scalar_t *src, scalar_t *dst, const int64_t offset_block,
    const int64_t offset_block_dst, const int h, const int d, const int d2,
    const int stride_h, const int stride_d, const int o_stride_h,
    const int o_stride_d) {
  if (d <= d2) return;
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
    const int64_t offset_head =
        offset_block + static_cast<int64_t>(h_id) * stride_h;
    const int64_t offset_head_dst =
        offset_block_dst + static_cast<int64_t>(h_id) * o_stride_h;
    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
      dst[offset_head_dst + static_cast<int64_t>(d_id) * o_stride_d] =
          src[offset_head + static_cast<int64_t>(d_id) * stride_d];
    }
  }
}

// Forward rotation of the h heads starting at src[offset_block].
// For i in [0, d2), with angle = freqs[s_id * d2 + i] (s_id = blockIdx.x):
//   dst[i] = src[i] * cos(angle) + rot[i] * sin(angle)
//   rot[i] = -src[i + rot_half]        if i + rot_half < d2
//          =  src[i + rot_half - d2]   otherwise
// Elements [d2, d) are copied unchanged. freqs is read as a contiguous
// (positions, d2) float table. It is presumably a (s, 1, 1, d2) tensor; confirm
// against the caller.
template <typename scalar_t>
__device__ void fused_rope_block_forward(
    const scalar_t *src, const float *freqs, scalar_t *dst,
    const int64_t offset_block, const int64_t offset_block_dst, const int h,
    const int d, const int d2, const int stride_h, const int stride_d,
    const int o_stride_h, const int o_stride_d) {
  using acc_t = typename RopeAccType<scalar_t>::type;
  const int s_id = blockIdx.x;
  const int rot_half = d2 / 2;
  const float *freqs_row = freqs + static_cast<int64_t>(s_id) * d2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    float v_cos, v_sin;
    sincosf(freqs_row[d_id], &v_sin, &v_cos);
    const bool first_half = d_id + rot_half < d2;
    const int partner_d = first_half ? d_id + rot_half : d_id + rot_half - d2;
    const int64_t partner_off =
        static_cast<int64_t>(partner_d - d_id) * stride_d;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const acc_t v_src = static_cast<acc_t>(src[offset_src]);
      const acc_t v_partner =
          static_cast<acc_t>(src[offset_src + partner_off]);
      const acc_t v_src_rotate = first_half ? -v_partner : v_partner;
      dst[offset_dst] = static_cast<scalar_t>(
          v_src * static_cast<acc_t>(v_cos) +
          v_src_rotate * static_cast<acc_t>(v_sin));
    }
  }
  fused_rope_copy_tail(src, dst, offset_block, offset_block_dst, h, d, d2,
                       stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward (adjoint) of fused_rope_block_forward for the same freqs.
// src holds the gradient w.r.t. the forward output; dst receives the gradient
// w.r.t. the forward input. Forward element q uses src[k] in its rotate term
// when k == (q + rot_half) mod d2, i.e. q == (k - rot_half) mod d2. So, with
// q as that partner and angle_j = freqs[s_id * d2 + j]:
//   dst[k] = src[k] * cos(angle_k) + src[q] * (k >= rot_half ? -1 : +1) * sin(angle_q)
// Elements [d2, d) are copied unchanged.
template <typename scalar_t>
__device__ void fused_rope_block_backward(
    const scalar_t *src, const float *freqs, scalar_t *dst,
    const int64_t offset_block, const int64_t offset_block_dst, const int h,
    const int d, const int d2, const int stride_h, const int stride_d,
    const int o_stride_h, const int o_stride_d) {
  using acc_t = typename RopeAccType<scalar_t>::type;
  const int s_id = blockIdx.x;
  const int rot_half = d2 / 2;
  const float *freqs_row = freqs + static_cast<int64_t>(s_id) * d2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const bool second_half = d_id >= rot_half;
    // (d_id - rot_half) mod d2; equals d_id +/- d2/2 when d2 is even.
    const int partner_d =
        second_half ? d_id - rot_half : d_id - rot_half + d2;
    const float sin_partner = sinf(freqs_row[partner_d]);
    const acc_t v_cos = static_cast<acc_t>(cosf(freqs_row[d_id]));
    const acc_t v_sin =
        static_cast<acc_t>(second_half ? -sin_partner : sin_partner);
    const int64_t partner_off =
        static_cast<int64_t>(partner_d - d_id) * stride_d;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const acc_t v_src = static_cast<acc_t>(src[offset_src]);
      const acc_t v_src_rotate =
          static_cast<acc_t>(src[offset_src + partner_off]);
      dst[offset_dst] =
          static_cast<scalar_t>(v_src * v_cos + v_src_rotate * v_sin);
    }
  }
  fused_rope_copy_tail(src, dst, offset_block, offset_block_dst, h, d, d2,
                       stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward RoPE over an (s, b, h, d)-strided tensor; one block per (s_id, b_id).
template <typename scalar_t>
__global__ void fused_rope_forward(const int h, const int d, const int d2,
                                   const int stride_s, const int stride_b,
                                   const int stride_h, const int stride_d,
                                   const int o_stride_s, const int o_stride_b,
                                   const int o_stride_h, const int o_stride_d,
                                   const scalar_t* src, const float* freqs,
                                   scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
                           d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward RoPE over an (s, b, h, d)-strided gradient; one block per (s_id, b_id).
template <typename scalar_t>
__global__ void fused_rope_backward(const int h, const int d, const int d2,
                                    const int stride_s, const int stride_b,
                                    const int stride_h, const int stride_d,
                                    const int o_stride_s, const int o_stride_b,
                                    const int o_stride_h, const int o_stride_d,
                                    const scalar_t* src, const float* freqs,
                                    scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
                            d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward RoPE using precomputed cos/sin tables. Both are read at
// [s_id * d2 + d_id], i.e. as contiguous (positions, d2) arrays.
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_forward(
    const int h, const int d, const int d2, const int stride_s,
    const int stride_b, const int stride_h, const int stride_d,
    const int o_stride_s, const int o_stride_b, const int o_stride_h,
    const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos,
    const scalar_t_1* sin, scalar_t_0* dst) {
  using acc_t = typename RopeAccType<scalar_t_0>::type;
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  const int rot_half = d2 / 2;
  const int64_t row = static_cast<int64_t>(s_id) * d2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const acc_t v_cos = static_cast<acc_t>(cos[row + d_id]);
    const acc_t v_sin = static_cast<acc_t>(sin[row + d_id]);
    const bool first_half = d_id + rot_half < d2;
    const int partner_d = first_half ? d_id + rot_half : d_id + rot_half - d2;
    const int64_t partner_off =
        static_cast<int64_t>(partner_d - d_id) * stride_d;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const acc_t v_src = static_cast<acc_t>(src[offset_src]);
      const acc_t v_partner =
          static_cast<acc_t>(src[offset_src + partner_off]);
      const acc_t v_src_rotate = first_half ? -v_partner : v_partner;
      dst[offset_dst] =
          static_cast<scalar_t_0>(v_src * v_cos + v_src_rotate * v_sin);
    }
  }
  fused_rope_copy_tail(src, dst, offset_block, offset_block_dst, h, d, d2,
                       stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward (adjoint) of fused_rope_cached_forward for the same cos/sin tables.
// Uses the same partner rule as fused_rope_block_backward.
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_backward(
    const int h, const int d, const int d2, const int stride_s,
    const int stride_b, const int stride_h, const int stride_d,
    const int o_stride_s, const int o_stride_b, const int o_stride_h,
    const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos,
    const scalar_t_1* sin, scalar_t_0* dst) {
  using acc_t = typename RopeAccType<scalar_t_0>::type;
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  const int rot_half = d2 / 2;
  const int64_t row = static_cast<int64_t>(s_id) * d2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const bool second_half = d_id >= rot_half;
    // (d_id - rot_half) mod d2; equals d_id +/- d2/2 when d2 is even.
    const int partner_d =
        second_half ? d_id - rot_half : d_id - rot_half + d2;
    const acc_t sin_partner = static_cast<acc_t>(sin[row + partner_d]);
    const acc_t v_cos = static_cast<acc_t>(cos[row + d_id]);
    const acc_t v_sin = second_half ? -sin_partner : sin_partner;
    const int64_t partner_off =
        static_cast<int64_t>(partner_d - d_id) * stride_d;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const acc_t v_src = static_cast<acc_t>(src[offset_src]);
      const acc_t v_src_rotate =
          static_cast<acc_t>(src[offset_src + partner_off]);
      dst[offset_dst] =
          static_cast<scalar_t_0>(v_src * v_cos + v_src_rotate * v_sin);
    }
  }
  fused_rope_copy_tail(src, dst, offset_block, offset_block_dst, h, d, d2,
                       stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward RoPE over a packed (t, h, d)-strided tensor. Sequence b_id occupies
// tokens [cu_seqlens[b_id], cu_seqlens[b_id + 1]), and blockIdx.x is the
// position inside that sequence. Blocks past the sequence end exit early.
template <typename scalar_t>
__global__ void fused_rope_thd_forward(
    const int h, const int d, const int d2, const int stride_t,
    const int stride_h, const int stride_d, const int o_stride_t,
    const int o_stride_h, const int o_stride_d, const scalar_t* src,
    const int* cu_seqlens, const float* freqs, scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int t_id = s_id + cu_seqlens[b_id];
  if (t_id >= cu_seqlens[b_id + 1]) return;
  const int64_t offset_block = static_cast<int64_t>(t_id) * stride_t;
  const int64_t offset_block_dst = static_cast<int64_t>(t_id) * o_stride_t;
  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
                           d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward of fused_rope_thd_forward; same packed layout and early exit.
template <typename scalar_t>
__global__ void fused_rope_thd_backward(
    const int h, const int d, const int d2, const int stride_t,
    const int stride_h, const int stride_d, const int o_stride_t,
    const int o_stride_h, const int o_stride_d, const scalar_t* src,
    const int* cu_seqlens, const float* freqs, scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int t_id = s_id + cu_seqlens[b_id];
  if (t_id >= cu_seqlens[b_id + 1]) return;
  const int64_t offset_block = static_cast<int64_t>(t_id) * stride_t;
  const int64_t offset_block_dst = static_cast<int64_t>(t_id) * o_stride_t;
  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
                            d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
}  // end of anonymous namespace
  // Launches fused_rope_forward on PyTorch's current CUDA stream.
// Grid is (s, b): one block per (sequence position, batch) pair. Block is
// (C10_WARP_SIZE, 4 or 8 warps): lanes stride over the head dimension and
// warps stride over the h heads. All strides are in elements. freqs is read
// per position as d2 float angles. Launch errors are reported through
// C10_CUDA_KERNEL_LAUNCH_CHECK.
template <typename scalar_t>
void dispatch_fused_rope_forward(const int s, const int b, const int h,
                                 const int d, const int d2, const int stride_s,
                                 const int stride_b, const int stride_h,
                                 const int stride_d, const int o_stride_s,
                                 const int o_stride_b, const int o_stride_h,
                                 const int o_stride_d, const scalar_t* input,
                                 const float* freqs, scalar_t* output) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). An empty input has nothing to rotate, so return early.
  if (s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, input, freqs, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
  // Launches fused_rope_backward on PyTorch's current CUDA stream.
// output_grads is the gradient w.r.t. the forward output; input_grads receives
// the gradient w.r.t. the forward input. The launch layout matches
// dispatch_fused_rope_forward: grid (s, b) and block (C10_WARP_SIZE, 4 or 8).
template <typename scalar_t>
void dispatch_fused_rope_backward(const int s, const int b, const int h,
                                  const int d, const int d2, const int stride_s,
                                  const int stride_b, const int stride_h,
                                  const int stride_d, const int o_stride_s,
                                  const int o_stride_b, const int o_stride_h,
                                  const int o_stride_d,
                                  const scalar_t* output_grads,
                                  const float* freqs, scalar_t* input_grads) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). An empty gradient has nothing to propagate, so return early.
  if (s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, output_grads, freqs, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
  // Launches fused_rope_cached_forward on PyTorch's current CUDA stream.
// cos and sin are precomputed tables read at [position * d2 + i]. Their
// element type may differ from the input's. Grid is (s, b); block is
// (C10_WARP_SIZE, 4 or 8 warps).
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_forward(
    const int s, const int b, const int h, const int d, const int d2,
    const int stride_s, const int stride_b, const int stride_h,
    const int stride_d, const int o_stride_s, const int o_stride_b,
    const int o_stride_h, const int o_stride_d, const scalar_t_0* input,
    const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* output) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). An empty input has nothing to rotate, so return early.
  if (s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_cached_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, input, cos, sin, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
  // Launches fused_rope_cached_backward on PyTorch's current CUDA stream.
// output_grads is the gradient w.r.t. the forward output; input_grads receives
// the gradient w.r.t. the forward input. cos and sin are the same tables used
// in the forward pass. Grid is (s, b); block is (C10_WARP_SIZE, 4 or 8 warps).
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_backward(
    const int s, const int b, const int h, const int d, const int d2,
    const int stride_s, const int stride_b, const int stride_h,
    const int stride_d, const int o_stride_s, const int o_stride_b,
    const int o_stride_h, const int o_stride_d, const scalar_t_0* output_grads,
    const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* input_grads) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). An empty gradient has nothing to propagate, so return early.
  if (s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_cached_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
  // Launches fused_rope_thd_forward on PyTorch's current CUDA stream for a
// packed (t, h, d)-strided input. cu_seqlens is a device array of b + 1 token
// offsets; sequence i spans tokens [cu_seqlens[i], cu_seqlens[i + 1]). The grid
// is (max_s, b), so positions >= max_s are never processed. max_s must
// therefore be at least the longest sequence length.
template <typename scalar_t>
void dispatch_fused_rope_thd_forward(const int max_s, const int b, const int h,
                                     const int d, const int d2,
                                     const int stride_t, const int stride_h,
                                     const int stride_d, const int o_stride_t,
                                     const int o_stride_h, const int o_stride_d,
                                     const scalar_t* input,
                                     const int* cu_seqlens, const float* freqs,
                                     scalar_t* output) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). With no sequences, or only empty ones, there is nothing to do.
  if (max_s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(max_s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_thd_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
      o_stride_d, input, cu_seqlens, freqs, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
  // Launches fused_rope_thd_backward on PyTorch's current CUDA stream for a
// packed (t, h, d)-strided gradient. cu_seqlens is a device array of b + 1
// token offsets, and max_s must be at least the longest sequence length
// because the grid is (max_s, b). input_grads receives the gradient w.r.t.
// the forward input.
template <typename scalar_t>
void dispatch_fused_rope_thd_backward(
    const int max_s, const int b, const int h, const int d, const int d2,
    const int stride_t, const int stride_h, const int stride_d,
    const int o_stride_t, const int o_stride_h, const int o_stride_d,
    const scalar_t* output_grads, const int* cu_seqlens, const float* freqs,
    scalar_t* input_grads) {
  // CUDA rejects a zero-sized grid dimension ("invalid configuration
  // argument"). With no sequences, or only empty ones, there is nothing to do.
  if (max_s == 0 || b == 0) return;
  auto stream = at::cuda::getCurrentCUDAStream();
  const int warps_per_block = h < 16 ? 4 : 8;
  const dim3 blocks(max_s, b);
  const dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_thd_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
      o_stride_d, output_grads, cu_seqlens, freqs, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}