fused_rotary_positional_embedding_cuda.cu

/* coding=utf-8
 * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {
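
// Host-side launchers for the fused rotary positional embedding (RoPE) CUDA
// kernels declared in fused_rotary_positional_embedding.h. Each launcher reads
// sizes and strides from the incoming tensors, allocates the result tensor,
// and dispatches to the templated kernel launcher for the tensor's dtype.
//
// When transpose_output is true, the result is allocated with (b, s, h, d)
// memory layout and viewed back as (s, b, h, d) via transpose(0, 1), so
// callers that want batch-major memory can skip a separate transpose copy.
//
// Illustrative Python-side sketch (assumes these functions are exposed through
// a separate binding file, which is not part of this translation unit):
//   input = torch.randn(s, b, h, d, device="cuda", dtype=torch.float16)
//   freqs = torch.rand(s, 1, 1, d2, device="cuda", dtype=torch.float32)
//   output = fwd_cuda(input, freqs, False)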
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
                       const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_forward",
      dispatch_fused_rope_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          freqs.data_ptr<float>(), output.data_ptr<scalar_t_0>()););

  return output;
}
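
// Backward pass: mirrors fwd_cuda's size/stride bookkeeping, consuming the
// incoming output gradients and producing gradients w.r.t. the input.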
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &freqs,
                       const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
      dispatch_fused_rope_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
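
// Two-level type dispatch for the "cached" variants below: TYPE1 is the
// activation dtype (float, half, or bfloat16) and TYPE2 is the dtype of the
// precomputed cos/sin tables, which may either match the activation dtype or
// be kept in float32. Any other combination fails with a TORCH_CHECK error.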
#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)                    \
  switch (TYPE1) {                                                            \
    case at::ScalarType::Float: {                                             \
      using scalar_t_0 = float;                                               \
      switch (TYPE2) {                                                        \
        case at::ScalarType::Float: {                                         \
          using scalar_t_1 = float;                                           \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        default:                                                              \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),  \
                      "' with '", toString(TYPE2), "'");                      \
      }                                                                       \
      break;                                                                  \
    }                                                                         \
    case at::ScalarType::Half: {                                              \
      using scalar_t_0 = at::Half;                                            \
      switch (TYPE2) {                                                        \
        case at::ScalarType::Float: {                                         \
          using scalar_t_1 = float;                                           \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        case at::ScalarType::Half: {                                          \
          using scalar_t_1 = at::Half;                                        \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        default:                                                              \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),  \
                      "' with '", toString(TYPE2), "'");                      \
      }                                                                       \
      break;                                                                  \
    }                                                                         \
    case at::ScalarType::BFloat16: {                                          \
      using scalar_t_0 = at::BFloat16;                                        \
      switch (TYPE2) {                                                        \
        case at::ScalarType::Float: {                                         \
          using scalar_t_1 = float;                                           \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        case at::ScalarType::BFloat16: {                                      \
          using scalar_t_1 = at::BFloat16;                                    \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        default:                                                              \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),  \
                      "' with '", toString(TYPE2), "'");                      \
      }                                                                       \
      break;                                                                  \
    }                                                                         \
    default:                                                                  \
      TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1),      \
                  "' with '", toString(TYPE2), "'");                          \
  }
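
// "Cached" variants: same as fwd_cuda/bwd_cuda, but the rotation is supplied
// as precomputed cos/sin tables of shape (s, 1, 1, d2) rather than raw
// frequencies, with the table dtype handled by DISPATCH_FUSED_ROPE_TYPES.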
torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = cos.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      input.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_forward",
      dispatch_fused_rope_cached_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cos.data_ptr<scalar_t_1>(), sin.data_ptr<scalar_t_1>(),
          output.data_ptr<scalar_t_0>()););

  return output;
}
torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = cos.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);

  DISPATCH_FUSED_ROPE_TYPES(
      output_grads.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_backward",
      dispatch_fused_rope_cached_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),
          sin.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
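
// THD variants: variable-length sequences packed along a single token
// dimension t (the total number of tokens across the batch), delimited by
// cu_seqlens, an int32 tensor of b + 1 cumulative sequence lengths; freqs
// covers the longest sequence, max_s.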
torch::Tensor fwd_thd_cuda(const torch::Tensor &input,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // input sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = input.size(0);
  const int h = input.size(1);
  const int d = input.size(2);
  // input strides
  const int stride_t = input.stride(0);
  const int stride_h = input.stride(1);
  const int stride_d = input.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  // output
  auto act_options = input.options().requires_grad(false);
  auto output = torch::empty({t, h, d}, act_options);
  // output strides
  const int o_stride_t = output.stride(0);
  const int o_stride_h = output.stride(1);
  const int o_stride_d = output.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_thd_forward",
      dispatch_fused_rope_thd_forward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          output.data_ptr<scalar_t_0>()););

  return output;
}
torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // output_grads sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = output_grads.size(0);
  const int h = output_grads.size(1);
  const int d = output_grads.size(2);
  // output_grads strides
  const int stride_t = output_grads.stride(0);
  const int stride_h = output_grads.stride(1);
  const int stride_d = output_grads.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);

  auto act_options = output_grads.options().requires_grad(false);
  auto input_grads = torch::empty({t, h, d}, act_options);
  const int o_stride_t = input_grads.stride(0);
  const int o_stride_h = input_grads.stride(1);
  const int o_stride_d = input_grads.stride(2);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_thd_backward",
      dispatch_fused_rope_thd_backward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););

  return input_grads;
}
}  // end namespace fused_rope