/* coding=utf-8
 * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {
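
// Host-side launchers for the fused rotary positional embedding (RoPE)
// kernels. Each launcher reads sizes and strides from the incoming tensors,
// allocates the result with the requested memory layout, and dispatches on
// dtype to the kernel entry points declared in the included headers.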

torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
                       const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_forward",
      dispatch_fused_rope_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          freqs.data_ptr<float>(), output.data_ptr<scalar_t_0>()););

  return output;
}
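
// Illustrative use of fwd_cuda from host code (a sketch, not part of this
// file; the tensor sizes are hypothetical, and `freqs` would normally hold
// precomputed RoPE angles rather than random values):
//
//   at::Tensor q = torch::randn(
//       {4, 2, 8, 64}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   at::Tensor freqs = torch::randn(
//       {4, 1, 1, 64}, torch::dtype(torch::kFloat).device(torch::kCUDA));
//   at::Tensor q_rot =
//       fused_rope::fwd_cuda(q, freqs, /*transpose_output=*/false);
//
// Note that freqs must be float32 in this path: the dispatch above always
// reads it via freqs.data_ptr<float>().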

torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &freqs,
                       const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
      dispatch_fused_rope_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
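
// Conceptually, applying RoPE is an orthogonal (rotation) transform, so its
// gradient is the same rotation with negated angles applied to output_grads;
// that is why bwd_cuda mirrors fwd_cuda line for line and only swaps the
// kernel entry point.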

#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)                   \
  switch (TYPE1) {                                                           \
    case at::ScalarType::Float: {                                            \
      using scalar_t_0 = float;                                              \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '",                  \
                      toString(TYPE1), "' with '", toString(TYPE2), "'");    \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::Half: {                                             \
      using scalar_t_0 = at::Half;                                           \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::Half: {                                         \
          using scalar_t_1 = at::Half;                                       \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '",                  \
                      toString(TYPE1), "' with '", toString(TYPE2), "'");    \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::BFloat16: {                                         \
      using scalar_t_0 = at::BFloat16;                                       \
      switch (TYPE2) {                                                       \
        case at::ScalarType::Float: {                                        \
          using scalar_t_1 = float;                                          \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::BFloat16: {                                     \
          using scalar_t_1 = at::BFloat16;                                   \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          TORCH_CHECK(false, #NAME, " not supported for '",                  \
                      toString(TYPE1), "' with '", toString(TYPE2), "'");    \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    default:                                                                 \
      TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),     \
                  "' with '", toString(TYPE2), "'");                         \
  }
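
// For example, a call such as
//   DISPATCH_FUSED_ROPE_TYPES(input.scalar_type(), cos.scalar_type(), ...)
// binds scalar_t_0 to the activation dtype (float / at::Half / at::BFloat16)
// and scalar_t_1 to the cos/sin dtype. Only two pairings are accepted:
// cos/sin in float32, or cos/sin matching the activation dtype; any other
// combination trips the TORCH_CHECK at runtime.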

torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
  // under different memory formats
  const int d2 = cos.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      input.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_forward",
      dispatch_fused_rope_cached_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cos.data_ptr<scalar_t_1>(), sin.data_ptr<scalar_t_1>(),
          output.data_ptr<scalar_t_0>()););

  return output;
}
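
// Illustrative use (a sketch): the cos/sin cache is typically precomputed
// once from the same float32 freqs tensor used by fwd_cuda, e.g.
//
//   at::Tensor cos = freqs.cos();  // (s, 1, 1, d2)
//   at::Tensor sin = freqs.sin();
//   at::Tensor out =
//       fused_rope::fwd_cached_cuda(q, cos, sin, /*transpose_output=*/false);
//
// which replaces the in-kernel cos/sin evaluation of the uncached path with
// two cached reads.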

torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
  // under different memory formats
  const int d2 = cos.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      output_grads.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_backward",
      dispatch_fused_rope_cached_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),
          sin.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
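
// When transpose_output is true, the result tensors above are allocated with
// a (b, s, h, d) memory layout and returned as an (s, b, h, d) view via
// transpose(0, 1); since the kernels write through the o_stride_* values,
// callers that want batch-major memory get it without a separate
// permute/copy.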

torch::Tensor fwd_thd_cuda(const torch::Tensor &input,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // input sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = input.size(0);
  const int h = input.size(1);
  const int d = input.size(2);
  // input strides
  const int stride_t = input.stride(0);
  const int stride_h = input.stride(1);
  const int stride_d = input.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  auto output = torch::empty({t, h, d}, act_options);
  // output strides
  const int o_stride_t = output.stride(0);
  const int o_stride_h = output.stride(1);
  const int o_stride_d = output.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_thd_forward",
      dispatch_fused_rope_thd_forward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          output.data_ptr<scalar_t_0>()););

  return output;
}

torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // output_grads sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = output_grads.size(0);
  const int h = output_grads.size(1);
  const int d = output_grads.size(2);
  // output_grads strides
  const int stride_t = output_grads.stride(0);
  const int stride_h = output_grads.stride(1);
  const int stride_d = output_grads.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  auto input_grads = torch::empty({t, h, d}, act_options);
  const int o_stride_t = input_grads.stride(0);
  const int o_stride_h = input_grads.stride(1);
  const int o_stride_d = input_grads.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_thd_backward",
      dispatch_fused_rope_thd_backward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
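
// Illustrative thd-format use (a sketch with hypothetical sizes): two packed
// sequences of lengths 3 and 5 give t = 8 and cu_seqlens = [0, 3, 8]; freqs
// must cover the longest sequence (max_s >= 5):
//
//   at::Tensor cu_seqlens = torch::tensor(
//       {0, 3, 8}, torch::dtype(torch::kInt).device(torch::kCUDA));
//   at::Tensor packed = torch::randn(
//       {8, 8, 64}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   at::Tensor freqs = torch::randn(
//       {5, 1, 1, 64}, torch::dtype(torch::kFloat).device(torch::kCUDA));
//   at::Tensor out = fused_rope::fwd_thd_cuda(packed, cu_seqlens, freqs);
//
// cu_seqlens must be int32, matching the cu_seqlens.data_ptr<int>() read in
// the dispatch above.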

}  // end namespace fused_rope