3
0

layer_norm_cuda.cpp 14 KB


  #include <torch/extension.h>

  #include <cassert>
  #include <cstdint>
  #include <limits>
  #include <sstream>
  #include <stdexcept>
  #include <vector>

  #include "compat.h"
  5. namespace {
  6. void compute_n1_n2(
  7. at::Tensor input,
  8. #ifdef VERSION_GE_1_1
  9. at::IntArrayRef normalized_shape,
  10. #else
  11. at::IntList normalized_shape,
  12. #endif
  13. int& n1,
  14. int& n2)
  15. {
  16. int idiff = input.ndimension() - normalized_shape.size();
  17. n2 = 1;
  18. for (int i = 0; i < (int)normalized_shape.size(); ++i) {
  19. assert( input.sizes()[i+idiff] == normalized_shape[i] );
  20. n2 *= normalized_shape[i];
  21. }
  22. n1 = 1;
  23. for (int i = 0; i < idiff; ++i) {
  24. n1 *= input.sizes()[i];
  25. }
  26. }
  27. void check_args(
  28. #ifdef VERSION_GE_1_1
  29. at::IntArrayRef normalized_shape,
  30. #else
  31. at::IntList normalized_shape,
  32. #endif
  33. at::Tensor gamma,
  34. at::Tensor beta
  35. )
  36. {
  37. TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  38. TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
  39. }
  40. void check_args(
  41. #ifdef VERSION_GE_1_1
  42. at::IntArrayRef normalized_shape,
  43. #else
  44. at::IntList normalized_shape,
  45. #endif
  46. at::Tensor gamma
  47. )
  48. {
  49. TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  50. }
  51. void check_args(
  52. at::Tensor input,
  53. #ifdef VERSION_GE_1_1
  54. at::IntArrayRef normalized_shape,
  55. #else
  56. at::IntList normalized_shape,
  57. #endif
  58. int& n1,
  59. int& n2
  60. )
  61. {
  62. int64_t normalized_ndim = normalized_shape.size();
  63. if (normalized_ndim < 1) {
  64. std::stringstream ss;
  65. ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
  66. << "containing at least one element, but got normalized_shape="
  67. << normalized_shape;
  68. throw std::runtime_error(ss.str());
  69. }
  70. auto input_shape = input.sizes();
  71. auto input_ndim = input.dim();
  72. if (input_ndim < normalized_ndim ||
  73. !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
  74. std::stringstream ss;
  75. ss << "Given normalized_shape=" << normalized_shape
  76. << ", expected input with shape [*";
  77. for (auto size : normalized_shape) {
  78. ss << ", " << size;
  79. }
  80. ss << "], but got input of size" << input_shape;
  81. throw std::runtime_error(ss.str());
  82. }
  83. compute_n1_n2(input,normalized_shape,n1,n2);
  84. }
  85. void check_args(
  86. at::Tensor input,
  87. #ifdef VERSION_GE_1_1
  88. at::IntArrayRef normalized_shape,
  89. #else
  90. at::IntList normalized_shape,
  91. #endif
  92. at::Tensor gamma,
  93. at::Tensor beta,
  94. int& n1,
  95. int& n2
  96. )
  97. {
  98. check_args(input,normalized_shape,n1,n2);
  99. check_args(normalized_shape,gamma,beta);
  100. }
  101. void check_args(
  102. at::Tensor input,
  103. #ifdef VERSION_GE_1_1
  104. at::IntArrayRef normalized_shape,
  105. #else
  106. at::IntList normalized_shape,
  107. #endif
  108. at::Tensor gamma,
  109. int& n1,
  110. int& n2
  111. )
  112. {
  113. check_args(input,normalized_shape,n1,n2);
  114. check_args(normalized_shape,gamma);
  115. }
  116. }
  117. void cuda_layer_norm(
  118. at::Tensor* output,
  119. at::Tensor* mean,
  120. at::Tensor* invvar,
  121. at::Tensor* input,
  122. int n1,
  123. int n2,
  124. #ifdef VERSION_GE_1_1
  125. at::IntArrayRef normalized_shape,
  126. #else
  127. at::IntList normalized_shape,
  128. #endif
  129. at::Tensor* gamma,
  130. at::Tensor* beta,
  131. double epsilon);
  // Argument guards used by every entry point below. Each raises through
  // TORCH_CHECK with a message that begins with the spelled-out argument
  // expression (#x).
  //
  // NOTE(review): CHECK_INPUT expands to two statements, so it must only be used
  // as a full statement inside a braced scope. The replacement lists are kept
  // exactly as they were because the same macros are defined again, verbatim,
  // further down in this file.

  // Fails unless x.is_cuda().
  #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  // Fails unless x.is_contiguous().
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  // Applies both guards, device check first.
  #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
  135. std::vector<at::Tensor> layer_norm(
  136. at::Tensor input,
  137. #ifdef VERSION_GE_1_1
  138. at::IntArrayRef normalized_shape,
  139. #else
  140. at::IntList normalized_shape,
  141. #endif
  142. double epsilon) {
  143. CHECK_INPUT(input);
  144. int n1,n2;
  145. check_args(input,normalized_shape,n1,n2);
  146. at::Tensor output = at::empty_like(input);
  147. at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  148. at::Tensor invvar = at::empty_like(mean);
  149. cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
  150. normalized_shape,NULL,NULL,epsilon);
  151. return {output, mean, invvar};
  152. }
  153. std::vector<at::Tensor> layer_norm_affine(
  154. at::Tensor input,
  155. #ifdef VERSION_GE_1_1
  156. at::IntArrayRef normalized_shape,
  157. #else
  158. at::IntList normalized_shape,
  159. #endif
  160. at::Tensor gamma,
  161. at::Tensor beta,
  162. double epsilon) {
  163. CHECK_INPUT(input);
  164. CHECK_INPUT(gamma);
  165. CHECK_INPUT(beta);
  166. int n1,n2;
  167. check_args(input,normalized_shape,gamma,beta,n1,n2);
  168. at::Tensor output = at::empty_like(input);
  169. const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
  170. at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
  171. at::Tensor invvar = at::empty_like(mean);
  172. cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
  173. normalized_shape,&gamma,&beta,epsilon);
  174. return {output, mean, invvar};
  175. }
  176. std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
  177. at::Tensor input,
  178. #ifdef VERSION_GE_1_1
  179. at::IntArrayRef normalized_shape,
  180. #else
  181. at::IntList normalized_shape,
  182. #endif
  183. at::Tensor gamma,
  184. at::Tensor beta,
  185. double epsilon) {
  186. CHECK_INPUT(input);
  187. int n1, n2;
  188. check_args(input, normalized_shape, n1, n2);
  189. at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  190. at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  191. at::Tensor invvar = at::empty_like(mean);
  192. cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
  193. normalized_shape, &gamma, &beta, epsilon);
  194. return {output, mean, invvar};
  195. }
  196. void cuda_layer_norm_gradient(
  197. at::Tensor* dout,
  198. at::Tensor* mean,
  199. at::Tensor* invvar,
  200. at::Tensor* input_or_output,
  201. int n1,
  202. int n2,
  203. #ifdef VERSION_GE_1_1
  204. at::IntArrayRef normalized_shape,
  205. #else
  206. at::IntList normalized_shape,
  207. #endif
  208. at::Tensor* gamma,
  209. at::Tensor* beta,
  210. double epsilon,
  211. at::Tensor* grad_input,
  212. at::Tensor* grad_gamma,
  213. at::Tensor* grad_beta,
  214. bool memory_efficient
  215. );
  216. at::Tensor layer_norm_gradient(
  217. at::Tensor dout,
  218. c10::optional<at::Tensor> mean_,
  219. at::Tensor invvar,
  220. at::Tensor input_or_output,
  221. #ifdef VERSION_GE_1_1
  222. at::IntArrayRef normalized_shape,
  223. #else
  224. at::IntList normalized_shape,
  225. #endif
  226. double epsilon,
  227. bool memory_efficient) {
  228. CHECK_INPUT(dout);
  229. CHECK_INPUT(invvar);
  230. CHECK_INPUT(input_or_output);
  231. int n1,n2;
  232. check_args(input_or_output,normalized_shape,n1,n2);
  233. at::Tensor grad_input = at::empty_like(input_or_output);
  234. if (mean_.has_value()) {
  235. cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2,
  236. normalized_shape,NULL,NULL,epsilon,
  237. &grad_input,NULL,NULL,memory_efficient);
  238. } else {
  239. cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2,
  240. normalized_shape,NULL,NULL,epsilon,
  241. &grad_input,NULL,NULL,memory_efficient);
  242. }
  243. return grad_input;
  244. }
  245. std::vector<at::Tensor> layer_norm_gradient_affine(
  246. at::Tensor dout,
  247. c10::optional<at::Tensor> mean_,
  248. at::Tensor invvar,
  249. at::Tensor input_or_output,
  250. #ifdef VERSION_GE_1_1
  251. at::IntArrayRef normalized_shape,
  252. #else
  253. at::IntList normalized_shape,
  254. #endif
  255. at::Tensor gamma,
  256. at::Tensor beta,
  257. double epsilon,
  258. bool memory_efficient) {
  259. CHECK_INPUT(dout);
  260. CHECK_INPUT(invvar);
  261. CHECK_INPUT(input_or_output);
  262. CHECK_INPUT(gamma);
  263. CHECK_INPUT(beta);
  264. int n1,n2;
  265. check_args(input_or_output,normalized_shape,gamma,beta,n1,n2);
  266. at::Tensor grad_input = at::empty_like(input_or_output);
  267. at::Tensor grad_gamma = at::empty_like(gamma);
  268. at::Tensor grad_beta = at::empty_like(beta);
  269. // at::Tensor *mean = mean_.has_value() ? &mean_.value() : NULL;
  270. if (mean_.has_value()) {
  271. cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2,
  272. normalized_shape,&gamma,&beta,epsilon,
  273. &grad_input,&grad_gamma,&grad_beta,memory_efficient);
  274. } else {
  275. cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2,
  276. normalized_shape,&gamma,&beta,epsilon,
  277. &grad_input,&grad_gamma,&grad_beta,memory_efficient);
  278. }
  279. return {grad_input, grad_gamma, grad_beta};
  280. }
  281. void cuda_rms_norm(
  282. at::Tensor* output,
  283. at::Tensor* invvar,
  284. at::Tensor* input,
  285. int n1,
  286. int n2,
  287. #ifdef VERSION_GE_1_1
  288. at::IntArrayRef normalized_shape,
  289. #else
  290. at::IntList normalized_shape,
  291. #endif
  292. at::Tensor* gamma,
  293. double epsilon);
  // Argument guards for the entry points below. Each raises through TORCH_CHECK
  // with a message that begins with the spelled-out argument expression (#x).
  //
  // CHECK_CUDA / CHECK_CONTIGUOUS are only defined here if they are not already
  // defined, so an earlier identical definition is not repeated.
  #ifndef CHECK_CUDA
  #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  #endif
  #ifndef CHECK_CONTIGUOUS
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  #endif
  // CHECK_INPUT is (re)defined as one statement: the two-statement form lets
  // the contiguity check escape an unbraced `if` and cannot precede an `else`.
  #undef CHECK_INPUT
  #define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
  297. std::vector<at::Tensor> rms_norm(
  298. at::Tensor input,
  299. #ifdef VERSION_GE_1_1
  300. at::IntArrayRef normalized_shape,
  301. #else
  302. at::IntList normalized_shape,
  303. #endif
  304. double epsilon) {
  305. CHECK_INPUT(input);
  306. int n1,n2;
  307. check_args(input,normalized_shape,n1,n2);
  308. at::Tensor output = at::empty_like(input);
  309. at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  310. cuda_rms_norm(&output,&invvar,&input,n1,n2,
  311. normalized_shape,NULL,epsilon);
  312. return {output, invvar};
  313. }
  314. std::vector<at::Tensor> rms_norm_affine(
  315. at::Tensor input,
  316. #ifdef VERSION_GE_1_1
  317. at::IntArrayRef normalized_shape,
  318. #else
  319. at::IntList normalized_shape,
  320. #endif
  321. at::Tensor gamma,
  322. double epsilon) {
  323. CHECK_INPUT(input);
  324. CHECK_INPUT(gamma);
  325. int n1,n2;
  326. check_args(input,normalized_shape,gamma,n1,n2);
  327. at::Tensor output = at::empty_like(input);
  328. const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
  329. at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype));
  330. cuda_rms_norm(&output,&invvar,&input,n1,n2,
  331. normalized_shape,&gamma,epsilon);
  332. return {output, invvar};
  333. }
  334. std::vector<at::Tensor> rms_norm_affine_mixed_dtypes(
  335. at::Tensor input,
  336. #ifdef VERSION_GE_1_1
  337. at::IntArrayRef normalized_shape,
  338. #else
  339. at::IntList normalized_shape,
  340. #endif
  341. at::Tensor gamma,
  342. double epsilon) {
  343. CHECK_INPUT(input);
  344. int n1, n2;
  345. check_args(input, normalized_shape, n1, n2);
  346. at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  347. at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  348. cuda_rms_norm(&output,&invvar, &input, n1, n2,
  349. normalized_shape, &gamma,epsilon);
  350. return {output,invvar};
  351. }
  352. void cuda_rms_norm_gradient(
  353. at::Tensor* dout,
  354. at::Tensor* invvar,
  355. at::Tensor* input_or_output,
  356. int n1,
  357. int n2,
  358. #ifdef VERSION_GE_1_1
  359. at::IntArrayRef normalized_shape,
  360. #else
  361. at::IntList normalized_shape,
  362. #endif
  363. at::Tensor* gamma,
  364. double epsilon,
  365. at::Tensor* grad_input,
  366. at::Tensor* grad_gamma,
  367. bool memory_efficient);
  368. at::Tensor rms_norm_gradient(
  369. at::Tensor dout,
  370. at::Tensor invvar,
  371. at::Tensor input_or_output,
  372. #ifdef VERSION_GE_1_1
  373. at::IntArrayRef normalized_shape,
  374. #else
  375. at::IntList normalized_shape,
  376. #endif
  377. double epsilon,
  378. bool memory_efficient) {
  379. CHECK_INPUT(dout);
  380. CHECK_INPUT(invvar);
  381. CHECK_INPUT(input_or_output);
  382. int n1,n2;
  383. check_args(input_or_output,normalized_shape,n1,n2);
  384. at::Tensor grad_input = at::empty_like(input_or_output);
  385. cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2,
  386. normalized_shape,NULL,epsilon,
  387. &grad_input,NULL,memory_efficient);
  388. return grad_input;
  389. }
  390. std::vector<at::Tensor> rms_norm_gradient_affine(
  391. at::Tensor dout,
  392. at::Tensor invvar,
  393. at::Tensor input_or_output,
  394. #ifdef VERSION_GE_1_1
  395. at::IntArrayRef normalized_shape,
  396. #else
  397. at::IntList normalized_shape,
  398. #endif
  399. at::Tensor gamma,
  400. double epsilon,
  401. bool memory_efficient) {
  402. CHECK_INPUT(dout);
  403. CHECK_INPUT(invvar);
  404. CHECK_INPUT(input_or_output);
  405. CHECK_INPUT(gamma);
  406. int n1,n2;
  407. check_args(input_or_output,normalized_shape,gamma,n1,n2);
  408. at::Tensor grad_input = at::empty_like(input_or_output);
  409. at::Tensor grad_gamma = at::empty_like(gamma);
  410. cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2,
  411. normalized_shape,&gamma,epsilon,
  412. &grad_input,&grad_gamma,memory_efficient);
  413. return {grad_input, grad_gamma};
  414. }
  415. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  416. m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  417. m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  418. m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  419. m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
  420. m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
  421. m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)");
  422. m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)");
  423. m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)");
  424. m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)");
  425. m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
  426. }