fused_rotary_positional_embedding.cpp

/* coding=utf-8
 * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

namespace fused_rope {
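
// Forward declarations of the CUDA launchers; the kernel implementations live
// in the companion CUDA translation unit of this extension.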
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
                       const bool transpose_output);

torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &freqs, const bool transpose_output);

torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output);

torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output);

torch::Tensor fwd_thd_cuda(const torch::Tensor &input,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs);

torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs);
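
// Shape conventions implied by the checks below: `input` / `output_grads` are
// 4D and sequence-first ([s, b, h, d] in the usual sbhd convention), and
// `freqs` is a float32 tensor of shape [s, 1, 1, d2] with d2 <= d.
// `transpose_output` presumably requests a batch-first output layout; its
// exact behavior is defined by the CUDA implementation.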
torch::Tensor fwd(const at::Tensor &input, const at::Tensor &freqs,
                  const bool transpose_output) {
  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(input.size(0) == freqs.size(0),
              "expected input and freqs tensor have the same sequence length");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(input.size(3) >= freqs.size(3),
              "expected the last dim of the input tensor equals or is "
              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  return fwd_cuda(input, freqs, transpose_output);
}

torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &freqs,
                  const bool transpose_output) {
  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(
      output_grads.size(0) == freqs.size(0),
      "expected output_grads and freqs tensor have the same sequence length");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
              "expected the last dim of the output_grads tensor equals or is "
              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  return bwd_cuda(output_grads, freqs, transpose_output);
}
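
// "Cached" variants take precomputed cos/sin tables of shape [s, 1, 1, d2]
// (any dtype, as long as cos and sin match) instead of raw freqs, avoiding a
// sin/cos recomputation on every call.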
torch::Tensor fwd_cached(const at::Tensor &input, const at::Tensor &cos,
                         const at::Tensor &sin, const bool transpose_output) {
  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(input.size(0) == cos.size(0),
              "expected input and cos tensor have the same sequence length");
  TORCH_CHECK(input.size(0) == sin.size(0),
              "expected input and sin tensor have the same sequence length");
  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
              "expected the second and third dims of the cos tensor equal 1");
  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
              "expected the second and third dims of the sin tensor equal 1");
  TORCH_CHECK(cos.size(3) == sin.size(3),
              "expected cos and sin tensor have the same last dim");
  TORCH_CHECK(input.size(3) >= cos.size(3),
              "expected the last dim of the input tensor equals or is "
              "greater than the cos tensor");
  TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
              "expected cos and sin tensor have the same dtype");
  return fwd_cached_cuda(input, cos, sin, transpose_output);
}

torch::Tensor bwd_cached(const torch::Tensor &output_grads,
                         const at::Tensor &cos, const at::Tensor &sin,
                         const bool transpose_output) {
  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(
      output_grads.size(0) == cos.size(0),
      "expected output_grads and cos tensor have the same sequence length");
  TORCH_CHECK(
      output_grads.size(0) == sin.size(0),
      "expected output_grads and sin tensor have the same sequence length");
  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
              "expected the second and third dims of the cos tensor equal 1");
  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
              "expected the second and third dims of the sin tensor equal 1");
  TORCH_CHECK(cos.size(3) == sin.size(3),
              "expected cos and sin tensor have the same last dim");
  TORCH_CHECK(output_grads.size(3) >= cos.size(3),
              "expected the last dim of the output_grads tensor equals or is "
              "greater than the cos tensor");
  TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
              "expected cos and sin tensor have the same dtype");
  return bwd_cached_cuda(output_grads, cos, sin, transpose_output);
}
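
// "thd" variants operate on a packed, variable-length batch: `input` /
// `output_grads` are 3D [t, h, d] with all sequences concatenated along the
// first dim, and `cu_seqlens` holds the cumulative sequence-length offsets
// into that dim (flash-attention style varlen layout).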
torch::Tensor fwd_thd(const torch::Tensor &input,
                      const torch::Tensor &cu_seqlens,
                      const torch::Tensor &freqs) {
  TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(input.size(2) >= freqs.size(3),
              "expected the last dim of the input tensor equals or is "
              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  return fwd_thd_cuda(input, cu_seqlens, freqs);
}

torch::Tensor bwd_thd(const torch::Tensor &output_grads,
                      const torch::Tensor &cu_seqlens,
                      const torch::Tensor &freqs) {
  TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
              "expected the last dim of the output_grads tensor equals or is "
              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  return bwd_thd_cuda(output_grads, cu_seqlens, freqs);
}
}  // end namespace fused_rope

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fused_rope::fwd,
        "Fused Rotary Positional Embedding -- Forward.");
  m.def("backward", &fused_rope::bwd,
        "Fused Rotary Positional Embedding -- Backward.");
  // cache sin/cos
  m.def("forward_cached", &fused_rope::fwd_cached,
        "Fused Rotary Positional Embedding Cached -- Forward.");
  m.def("backward_cached", &fused_rope::bwd_cached,
        "Fused Rotary Positional Embedding Cached -- Backward.");
  // thd
  m.def("forward_thd", &fused_rope::fwd_thd,
        "Fused Rotary Positional Embedding for thd layout -- Forward.");
  m.def("backward_thd", &fused_rope::bwd_thd,
        "Fused Rotary Positional Embedding for thd layout -- Backward.");
}
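
// Usage sketch (illustrative only, not part of the extension). Once built,
// the module is imported from Python under whatever name the build assigns to
// TORCH_EXTENSION_NAME; the name and the freqs construction below are
// assumptions for illustration, only the shapes/dtypes follow from the checks
// in this file:
//
//   import torch
//   import fused_rotary_positional_embedding as fused_rope  # assumed name
//
//   s, b, h, d = 128, 2, 16, 64
//   t = torch.randn(s, b, h, d, dtype=torch.float16, device="cuda")
//   # freqs would normally be position/frequency products; random values
//   # here are only to satisfy the expected [s, 1, 1, d] float32 shape.
//   freqs = torch.randn(s, 1, 1, d, dtype=torch.float32, device="cuda")
//
//   out = fused_rope.forward(t, freqs, False)
//   grads = fused_rope.backward(torch.ones_like(out), freqs, False)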