123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369 |
- /* coding=utf-8
- * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <cstdint>
#include <cuda_runtime.h>
#include <torch/extension.h>
namespace {

// Launch layout shared by every kernel in this namespace (see the dispatch_*
// helpers below): blockIdx.x is the position within the sequence (it also
// selects the row of `freqs` / `cos` / `sin`), blockIdx.y is the batch entry
// (or the packed-sequence index for the THD kernels), threadIdx.x strides over
// channels and threadIdx.y strides over heads.
//
// For the first d2 channels of every head, with d2_half = d2 / 2:
//   forward:  dst[j] = src[j] * cos[j] + rot[j] * sin[j]
//             rot[j] = -src[j + d2_half]  if j + d2_half < d2
//                    =  src[j - (d2 - d2_half)] otherwise
//   backward: the transpose of the forward map.
// Channels [d2, d) are copied through unchanged.
//
// Element offsets are accumulated in int64_t: products such as
// position * stride_s can exceed INT_MAX for large activations, and 32-bit
// arithmetic would then wrap and address memory out of bounds.

// Copies channels [d2, d) of every head in one (position, batch) row from
// `src` to `dst` unchanged. No-op when the rotary width d2 covers the head.
template <typename scalar_t>
__device__ __forceinline__ void fused_rope_block_copy_tail(
    const scalar_t *src, scalar_t *dst, const int64_t offset_block,
    const int64_t offset_block_dst, const int h, const int d, const int d2,
    const int stride_h, const int stride_d, const int o_stride_h,
    const int o_stride_d) {
  if (d <= d2) {
    return;
  }
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
    const int64_t offset_head =
        offset_block + static_cast<int64_t>(h_id) * stride_h;
    const int64_t offset_head_dst =
        offset_block_dst + static_cast<int64_t>(h_id) * o_stride_h;
    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
      dst[offset_head_dst + static_cast<int64_t>(d_id) * o_stride_d] =
          src[offset_head + static_cast<int64_t>(d_id) * stride_d];
    }
  }
}

// Forward rotation of one (position, batch) row whose first element sits at
// `offset_block` in `src` and `offset_block_dst` in `dst`. The angle for
// channel d_id is freqs[blockIdx.x * d2 + d_id].
template <typename scalar_t>
__device__ void fused_rope_block_forward(
    const scalar_t *src, const float *freqs, scalar_t *dst,
    const int64_t offset_block, const int64_t offset_block_dst, const int h,
    const int d, const int d2, const int stride_h, const int stride_d,
    const int o_stride_h, const int o_stride_d) {
  const int s_id = blockIdx.x;
  const int d2_half = d2 / 2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    float v_cos, v_sin;
    sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
    // Lower channels pair with the negated channel d2_half above them; upper
    // channels pair with the channel below them.
    const bool first_half = d_id + d2_half < d2;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const scalar_t v_src = src[offset_src];
      const scalar_t v_src_rotate =
          first_half ? -src[offset_src + d2_half * stride_d]
                     : src[offset_src + (d2_half - d2) * stride_d];
      dst[offset_dst] =
          v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin;
    }
  }
  fused_rope_block_copy_tail(src, dst, offset_block, offset_block_dst, h, d,
                             d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward (transposed) rotation of one (position, batch) row. For channel j
// the gradient is g[j] * cos(f[j]) + g[p] * (+/-)sin(f[p]), where p is the
// partner channel used by the forward pass.
template <typename scalar_t>
__device__ void fused_rope_block_backward(
    const scalar_t *src, const float *freqs, scalar_t *dst,
    const int64_t offset_block, const int64_t offset_block_dst, const int h,
    const int d, const int d2, const int stride_h, const int stride_d,
    const int o_stride_h, const int o_stride_d) {
  const int s_id = blockIdx.x;
  const int d2_half = d2 / 2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const bool first_half = d_id + d2_half < d2;
    // Signed channel distance from d_id to its partner channel.
    const int partner_shift = first_half ? d2_half : d2_half - d2;
    const scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]);
    const scalar_t v_sin =
        first_half ? sinf(freqs[s_id * d2 + d_id + partner_shift])
                   : -sinf(freqs[s_id * d2 + d_id + partner_shift]);
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const scalar_t v_src = src[offset_src];
      const scalar_t v_src_rotate = src[offset_src + partner_shift * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }
  fused_rope_block_copy_tail(src, dst, offset_block, offset_block_dst, h, d,
                             d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward kernel over a strided 4-D (s, b, h, d) view: grid (s, b).
template <typename scalar_t>
__global__ void fused_rope_forward(const int h, const int d, const int d2,
                                   const int stride_s, const int stride_b,
                                   const int stride_h, const int stride_d,
                                   const int o_stride_s, const int o_stride_b,
                                   const int o_stride_h, const int o_stride_d,
                                   const scalar_t* src, const float* freqs,
                                   scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
                           d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward kernel over a strided 4-D (s, b, h, d) view: grid (s, b).
template <typename scalar_t>
__global__ void fused_rope_backward(const int h, const int d, const int d2,
                                    const int stride_s, const int stride_b,
                                    const int stride_h, const int stride_d,
                                    const int o_stride_s, const int o_stride_b,
                                    const int o_stride_h, const int o_stride_d,
                                    const scalar_t* src, const float* freqs,
                                    scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
                            d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward kernel using precomputed cos/sin tables (row s_id * d2) instead of
// raw angles. The tables are converted to the tensor type scalar_t_0 before use.
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_forward(
    const int h, const int d, const int d2, const int stride_s,
    const int stride_b, const int stride_h, const int stride_d,
    const int o_stride_s, const int o_stride_b, const int o_stride_h,
    const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos,
    const scalar_t_1* sin, scalar_t_0* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  const int d2_half = d2 / 2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const scalar_t_0 v_cos = cos[s_id * d2 + d_id];
    const scalar_t_0 v_sin = sin[s_id * d2 + d_id];
    const bool first_half = d_id + d2_half < d2;
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const scalar_t_0 v_src = src[offset_src];
      const scalar_t_0 v_src_rotate =
          first_half ? -src[offset_src + d2_half * stride_d]
                     : src[offset_src + (d2_half - d2) * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }
  fused_rope_block_copy_tail(src, dst, offset_block, offset_block_dst, h, d,
                             d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward kernel using precomputed cos/sin tables; transpose of
// fused_rope_cached_forward.
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_backward(
    const int h, const int d, const int d2, const int stride_s,
    const int stride_b, const int stride_h, const int stride_d,
    const int o_stride_s, const int o_stride_b, const int o_stride_h,
    const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos,
    const scalar_t_1* sin, scalar_t_0* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int64_t offset_block = static_cast<int64_t>(s_id) * stride_s +
                               static_cast<int64_t>(b_id) * stride_b;
  const int64_t offset_block_dst = static_cast<int64_t>(s_id) * o_stride_s +
                                   static_cast<int64_t>(b_id) * o_stride_b;
  const int d2_half = d2 / 2;
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    const bool first_half = d_id + d2_half < d2;
    // Signed channel distance from d_id to its partner channel.
    const int partner_shift = first_half ? d2_half : d2_half - d2;
    const scalar_t_0 v_cos = cos[s_id * d2 + d_id];
    const scalar_t_0 v_sin = first_half
                                 ? sin[s_id * d2 + d_id + partner_shift]
                                 : -sin[s_id * d2 + d_id + partner_shift];
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      const int64_t offset_src = offset_block +
                                 static_cast<int64_t>(h_id) * stride_h +
                                 static_cast<int64_t>(d_id) * stride_d;
      const int64_t offset_dst = offset_block_dst +
                                 static_cast<int64_t>(h_id) * o_stride_h +
                                 static_cast<int64_t>(d_id) * o_stride_d;
      const scalar_t_0 v_src = src[offset_src];
      const scalar_t_0 v_src_rotate =
          src[offset_src + partner_shift * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }
  fused_rope_block_copy_tail(src, dst, offset_block, offset_block_dst, h, d,
                             d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Forward kernel for packed (THD) inputs: sequence b_id occupies token rows
// [cu_seqlens[b_id], cu_seqlens[b_id + 1]). Blocks past a sequence's end exit.
template <typename scalar_t>
__global__ void fused_rope_thd_forward(
    const int h, const int d, const int d2, const int stride_t,
    const int stride_h, const int stride_d, const int o_stride_t,
    const int o_stride_h, const int o_stride_d, const scalar_t* src,
    const int* cu_seqlens, const float* freqs, scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int t_id = s_id + cu_seqlens[b_id];
  if (t_id >= cu_seqlens[b_id + 1]) {
    return;
  }
  const int64_t offset_block = static_cast<int64_t>(t_id) * stride_t;
  const int64_t offset_block_dst = static_cast<int64_t>(t_id) * o_stride_t;
  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
                           d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

// Backward kernel for packed (THD) inputs; see fused_rope_thd_forward.
template <typename scalar_t>
__global__ void fused_rope_thd_backward(
    const int h, const int d, const int d2, const int stride_t,
    const int stride_h, const int stride_d, const int o_stride_t,
    const int o_stride_h, const int o_stride_d, const scalar_t* src,
    const int* cu_seqlens, const float* freqs, scalar_t* dst) {
  const int s_id = blockIdx.x, b_id = blockIdx.y;
  const int t_id = s_id + cu_seqlens[b_id];
  if (t_id >= cu_seqlens[b_id + 1]) {
    return;
  }
  const int64_t offset_block = static_cast<int64_t>(t_id) * stride_t;
  const int64_t offset_block_dst = static_cast<int64_t>(t_id) * o_stride_t;
  fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
                            d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}

}  // end of anonymous namespace
// Launches fused_rope_forward on the current PyTorch CUDA stream.
// `input` is addressed through (stride_s, stride_b, stride_h, stride_d) and
// `output` through the o_stride_* values; both strides are in elements.
// The first d2 channels of each head are rotated using the angles in `freqs`
// (row s_id * d2), and channels [d2, d) are copied unchanged.
// Grid: (s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when s or b is 0: a zero-sized grid is an invalid
// launch configuration and would make the launch check throw on empty input.
template <typename scalar_t>
void dispatch_fused_rope_forward(const int s, const int b, const int h,
                                 const int d, const int d2, const int stride_s,
                                 const int stride_b, const int stride_h,
                                 const int stride_d, const int o_stride_s,
                                 const int o_stride_b, const int o_stride_h,
                                 const int o_stride_d, const scalar_t* input,
                                 const float* freqs, scalar_t* output) {
  if (s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, input, freqs, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Launches fused_rope_backward on the current PyTorch CUDA stream, writing
// into `input_grads` the gradient of the forward rotation with respect to its
// input, given `output_grads`.
// `output_grads` is addressed through the stride_* values and `input_grads`
// through the o_stride_* values (both in elements).
// Channels [d2, d) pass through unchanged.
// Grid: (s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when s or b is 0, because a zero-sized grid is an
// invalid launch configuration.
template <typename scalar_t>
void dispatch_fused_rope_backward(const int s, const int b, const int h,
                                  const int d, const int d2, const int stride_s,
                                  const int stride_b, const int stride_h,
                                  const int stride_d, const int o_stride_s,
                                  const int o_stride_b, const int o_stride_h,
                                  const int o_stride_d,
                                  const scalar_t* output_grads,
                                  const float* freqs, scalar_t* input_grads) {
  if (s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, output_grads, freqs, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Launches fused_rope_cached_forward on the current PyTorch CUDA stream.
// It has the same contract as dispatch_fused_rope_forward, except that the
// rotation uses the precomputed `cos` / `sin` tables (row s_id * d2) instead
// of raw angles.
// Grid: (s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when s or b is 0, because a zero-sized grid is an
// invalid launch configuration.
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_forward(
    const int s, const int b, const int h, const int d, const int d2,
    const int stride_s, const int stride_b, const int stride_h,
    const int stride_d, const int o_stride_s, const int o_stride_b,
    const int o_stride_h, const int o_stride_d, const scalar_t_0* input,
    const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* output) {
  if (s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_cached_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, input, cos, sin, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Launches fused_rope_cached_backward on the current PyTorch CUDA stream.
// It computes `input_grads` from `output_grads` using the precomputed `cos` /
// `sin` tables (row s_id * d2). Channels [d2, d) pass through unchanged.
// Grid: (s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when s or b is 0, because a zero-sized grid is an
// invalid launch configuration.
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_backward(
    const int s, const int b, const int h, const int d, const int d2,
    const int stride_s, const int stride_b, const int stride_h,
    const int stride_d, const int o_stride_s, const int o_stride_b,
    const int o_stride_h, const int o_stride_d, const scalar_t_0* output_grads,
    const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* input_grads) {
  if (s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_cached_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
      o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Launches fused_rope_thd_forward on the current PyTorch CUDA stream for
// packed sequences. Sequence b_id covers token rows
// [cu_seqlens[b_id], cu_seqlens[b_id + 1]), so `cu_seqlens` must hold b + 1
// entries. max_s must be at least the longest sequence length; blocks past
// a sequence's end exit early.
// Grid: (max_s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when max_s or b is 0, because a zero-sized grid is
// an invalid launch configuration.
template <typename scalar_t>
void dispatch_fused_rope_thd_forward(const int max_s, const int b, const int h,
                                     const int d, const int d2,
                                     const int stride_t, const int stride_h,
                                     const int stride_d, const int o_stride_t,
                                     const int o_stride_h, const int o_stride_d,
                                     const scalar_t* input,
                                     const int* cu_seqlens, const float* freqs,
                                     scalar_t* output) {
  if (max_s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(max_s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_thd_forward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
      o_stride_d, input, cu_seqlens, freqs, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Launches fused_rope_thd_backward on the current PyTorch CUDA stream. It
// computes `input_grads` from `output_grads` for packed sequences delimited by
// `cu_seqlens` (b + 1 entries); max_s must be at least the longest sequence.
// Grid: (max_s, b); block: (C10_WARP_SIZE, 4 or 8) threads.
// Returns without launching when max_s or b is 0, because a zero-sized grid is
// an invalid launch configuration.
template <typename scalar_t>
void dispatch_fused_rope_thd_backward(
    const int max_s, const int b, const int h, const int d, const int d2,
    const int stride_t, const int stride_h, const int stride_d,
    const int o_stride_t, const int o_stride_h, const int o_stride_d,
    const scalar_t* output_grads, const int* cu_seqlens, const float* freqs,
    scalar_t* input_grads) {
  if (max_s == 0 || b == 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  // threadIdx.y strides over heads, so use fewer warps when heads are few.
  const int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(max_s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);
  fused_rope_thd_backward<<<blocks, threads, 0, stream>>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
      o_stride_d, output_grads, cu_seqlens, freqs, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|