3
0

layer_norm_cuda_kernel.cu 39 KB


  1. #include "ATen/ATen.h"
  2. #include "ATen/AccumulateType.h"
  3. #include "ATen/cuda/CUDAContext.h"
  4. #include "ATen/cuda/DeviceUtils.cuh"
  5. #include <cuda.h>
  6. #include <cuda_runtime.h>
  7. #include "type_shim.h"
  8. #include "static_switch.h"
  // Welford single-sample update: folds `curr` into the running triple
  // (mu, sigma2, count). `sigma2` accumulates the sum of squared deviations
  // from the mean, not the variance itself; callers divide by the element
  // count once all samples have been merged.
  template<typename U> __device__
  void cuWelfordOnlineSum(
    const U curr,
    U& mu,
    U& sigma2,
    U& count)
  {
    count = count + U(1);
    const U diff_before = curr - mu;   // deviation from the old mean
    mu = mu + diff_before / count;
    const U diff_after = curr - mu;    // deviation from the updated mean
    sigma2 = sigma2 + diff_before * diff_after;
  }
  // Chan-style parallel merge: combines a partial result (muB, sigma2B, countB)
  // into the running (mu, sigma2, count). Both sigma2 values are sums of
  // squared deviations. When the merged count is not positive, mu and sigma2
  // are reset to zero (count keeps the summed value).
  template<typename U> __device__
  void cuChanOnlineSum(
    const U muB,
    const U sigma2B,
    const U countB,
    U& mu,
    U& sigma2,
    U& count)
  {
    const U mean_gap = muB - mu;
    const U weight_a = count;
    const U weight_b = countB;
    count = count + countB;
    const U total = count;
    if (!(total > U(0))) {
      mu = U(0);
      sigma2 = U(0);
      return;
    }
    const U frac_a = weight_a / total;
    const U frac_b = weight_b / total;
    mu = frac_a*mu + frac_b*muB;
    sigma2 = sigma2 + sigma2B + mean_gap * mean_gap * frac_a * frac_b * total;
  }
  // RMS accumulation step: adds curr^2 to the running sum of squares.
  template<typename U> __device__
  void cuRMSOnlineSum(
    const U curr,
    U& sigma2)
  {
    sigma2 += curr * curr;
  }
  // RMS merge step: partial sums of squares combine by plain addition.
  template<typename U> __device__
  void cuChanRMSOnlineSum(
    const U sigma2B,
    U& sigma2)
  {
    sigma2 += sigma2B;
  }
  // Computes the statistics of row i1 of a contiguous n1 x n2 matrix `vals`:
  //   LayerNorm (rms_only == false): mu = row mean, sigma2 = M2 / n2 (biased variance)
  //   RMSNorm   (rms_only == true):  mu = 0,        sigma2 = mean of squares
  // Each thread accumulates a strided subset of the row (Welford / sum of squares),
  // then results are merged across the warp with shuffles and across warps through
  // `buf`. On return every thread of the block holds the final mu and sigma2.
  // If i1 >= n1, mu and sigma2 are left at 0.
  //
  // Assumptions:
  // 1) blockDim.x == warpSize == 32 (the shuffle reduction hard-codes 32 lanes).
  // 2) Tensor is contiguous.
  // 3) When blockDim.y > 1: blockDim.y is a power of two (tree reduction halves it)
  //    and `buf` points to shared memory with room for blockDim.y + blockDim.y/2
  //    values of type U.
  // 4) All threads of the block call this with the same i1, because it contains
  //    __syncthreads() when blockDim.y > 1.
  template<typename T, typename U> __device__
  void cuWelfordMuSigma2(
    const T* __restrict__ vals,
    const int n1,
    const int n2,
    const int i1,
    U& mu,
    U& sigma2,
    U* buf,
    bool rms_only)
  {
    U count = U(0);
    mu = U(0);
    sigma2 = U(0);
    if (i1 < n1) {
      // Each thread first accumulates a strided subset of the row on its own.
      const int numx = blockDim.x * blockDim.y;
      const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
      // 64-bit row offset: i1*n2 overflows int for inputs with > INT_MAX elements.
      const T* lvals = vals + static_cast<size_t>(i1) * static_cast<size_t>(n2);
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          U curr = static_cast<U>(lvals[l+k]);
          if (!rms_only) {
            cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
          } else {
            cuRMSOnlineSum<U>(curr, sigma2);
          }
        }
      }
      // Tail elements that do not fill a complete group of 4.
      for (; l < n2; ++l) {
        U curr = static_cast<U>(lvals[l]);
        if (!rms_only) {
          cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
        } else {
          cuRMSOnlineSum<U>(curr, sigma2);
        }
      }
      // intra-warp reductions: after 5 rounds of (lane + 2^l) & 31 every lane
      // holds the warp-wide result.
      for (int l = 0; l <= 4; ++l) {
        int srcLaneB = (threadIdx.x+(1<<l))&31;
        U sigma2B = WARP_SHFL(sigma2, srcLaneB);
        if (!rms_only) {
          U muB = WARP_SHFL(mu, srcLaneB);
          U countB = WARP_SHFL(count, srcLaneB);
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        } else {
          cuChanRMSOnlineSum<U>(sigma2B, sigma2);
        }
      }
      // inter-warp reductions
      if (blockDim.y > 1) {
        // ubuf: (mu, sigma2) pairs for up to blockDim.y/2 warps; ibuf: their counts.
        U* ubuf = (U*)buf;
        U* ibuf = (U*)(ubuf + blockDim.y);
        for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
          // upper half of warps write to shared
          if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
            const int wrt_y = threadIdx.y - offset;
            if (!rms_only) {
              ubuf[2*wrt_y] = mu;
              ibuf[wrt_y] = count;
            }
            ubuf[2*wrt_y+1] = sigma2;
          }
          __syncthreads();
          // lower half merges
          if (threadIdx.x == 0 && threadIdx.y < offset) {
            U sigma2B = ubuf[2*threadIdx.y+1];
            if (!rms_only) {
              U muB = ubuf[2*threadIdx.y];
              U countB = ibuf[threadIdx.y];
              cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
            } else {
              cuChanRMSOnlineSum<U>(sigma2B,sigma2);
            }
          }
          __syncthreads();
        }
        // Only thread (0, 0) holds the block-wide result; broadcast it.
        if (threadIdx.x == 0 && threadIdx.y == 0) {
          if (!rms_only) {
            ubuf[0] = mu;
          }
          ubuf[1] = sigma2;
        }
        __syncthreads();
        if (!rms_only) {
          mu = ubuf[0];
        }
        sigma2 = ubuf[1]/U(n2);
        // don't care about final value of count, we know count == n2
      } else {
        if (!rms_only) {
          mu = WARP_SHFL(mu, 0);
        }
        sigma2 = WARP_SHFL(sigma2/U(n2), 0);
      }
    }
  }
  // Specialization of cuWelfordMuSigma2 for at::Half input with float statistics.
  // Same contract as the generic version (mu = mean or 0, sigma2 = M2/n2 or mean of
  // squares; every thread of the block receives the result), but the row is read two
  // halves at a time through __half2 loads once the pointer is 4-byte aligned.
  //
  // Assumptions:
  // 1) blockDim.x == warpSize == 32
  // 2) Tensor is contiguous
  // 3) When blockDim.y > 1: blockDim.y is a power of two and `buf` points to shared
  //    memory with room for blockDim.y + blockDim.y/2 floats.
  // 4) All threads of the block call this with the same i1 (contains __syncthreads()).
  template<> __device__
  void cuWelfordMuSigma2(
    const at::Half* __restrict__ vals,
    const int n1,
    const int n2,
    const int i1,
    float& mu,
    float& sigma2,
    float* buf,
    bool rms_only)
  {
    float count = 0.0f;
    mu = float(0);
    sigma2 = float(0);
    if (i1 < n1) {
      // Each thread first accumulates a strided subset of the row on its own.
      const int numx = blockDim.x * blockDim.y;
      const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
      // 64-bit row offset: i1*n2 overflows int for inputs with > INT_MAX elements.
      const at::Half* lvals = vals + static_cast<size_t>(i1) * static_cast<size_t>(n2);
      int l = 8*thrx;
      if ((((size_t)lvals)&3) != 0) {
        // Row starts on a 2-byte (not 4-byte) boundary: thread 0 consumes element 0
        // and every thread shifts its start by one so the __half2 loads are aligned.
        // Guard n2 > 0 so an empty row never reads lvals[0].
        if (thrx == 0 && n2 > 0) {
          float curr = static_cast<float>(lvals[0]);
          if (!rms_only) {
            cuWelfordOnlineSum(curr,mu,sigma2,count);
          } else {
            cuRMSOnlineSum(curr, sigma2);
          }
        }
        ++l;
      }
      // at this point, lvals[l] are 32 bit aligned for all threads.
      for (; l+7 < n2; l+=8*numx) {
        for (int k = 0; k < 8; k+=2) {
          float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
          if (!rms_only) {
            cuWelfordOnlineSum(curr.x,mu,sigma2,count);
            cuWelfordOnlineSum(curr.y,mu,sigma2,count);
          } else {
            cuRMSOnlineSum(curr.x, sigma2);
            cuRMSOnlineSum(curr.y, sigma2);
          }
        }
      }
      // Tail elements that do not fill a complete group of 8.
      for (; l < n2; ++l) {
        float curr = static_cast<float>(lvals[l]);
        if (!rms_only) {
          cuWelfordOnlineSum(curr,mu,sigma2,count);
        } else {
          cuRMSOnlineSum(curr, sigma2);
        }
      }
      // intra-warp reductions: after 5 rounds every lane holds the warp-wide result.
      for (int l = 0; l <= 4; ++l) {
        int srcLaneB = (threadIdx.x+(1<<l))&31;
        float sigma2B = WARP_SHFL(sigma2, srcLaneB);
        if (!rms_only) {
          float muB = WARP_SHFL(mu, srcLaneB);
          float countB = WARP_SHFL(count, srcLaneB);
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        } else {
          cuChanRMSOnlineSum(sigma2B, sigma2);
        }
      }
      // inter-warp reductions
      if (blockDim.y > 1) {
        // ubuf: (mu, sigma2) pairs for up to blockDim.y/2 warps; ibuf: their counts.
        float* ubuf = (float*)buf;
        float* ibuf = (float*)(ubuf + blockDim.y);
        for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
          // upper half of warps write to shared
          if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
            const int wrt_y = threadIdx.y - offset;
            ubuf[2*wrt_y+1] = sigma2;
            if (!rms_only) {
              ubuf[2*wrt_y] = mu;
              ibuf[wrt_y] = count;
            }
          }
          __syncthreads();
          // lower half merges
          if (threadIdx.x == 0 && threadIdx.y < offset) {
            float sigma2B = ubuf[2*threadIdx.y+1];
            if (!rms_only) {
              float muB = ubuf[2*threadIdx.y];
              float countB = ibuf[threadIdx.y];
              cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
            } else {
              cuChanRMSOnlineSum(sigma2B, sigma2);
            }
          }
          __syncthreads();
        }
        // Only thread (0, 0) holds the block-wide result; broadcast it.
        if (threadIdx.x == 0 && threadIdx.y == 0) {
          if (!rms_only) {
            ubuf[0] = mu;
          }
          ubuf[1] = sigma2;
        }
        __syncthreads();
        if (!rms_only) {
          mu = ubuf[0];
        }
        sigma2 = ubuf[1]/float(n2);
        // don't care about final value of count, we know count == n2
      } else {
        if (!rms_only) {
          mu = WARP_SHFL(mu, 0);
        }
        sigma2 = WARP_SHFL(sigma2/float(n2), 0);
      }
    }
  }
  // Reciprocal square root helpers (no __host__/__device__ qualifier, so host-only).
  // NOTE(review): device call sites such as rsqrt(sigma2 + epsilon) presumably resolve
  // to the non-template CUDA math overloads rather than these templates (a non-template
  // wins an exact-match tie in overload resolution) — confirm.
  template<typename U> U rsqrt(U v) {
    const U one = U(1);
    return one / sqrt(v);
  }
  // float: defer to rsqrtf.
  template<> float rsqrt(float v) {
    return rsqrtf(v);
  }
  // double: the unqualified call is intended to reach the math library's rsqrt(double).
  // NOTE(review): if that non-template overload were not visible here, this call would
  // select this very specialization and recurse forever — confirm it is in scope.
  template<> double rsqrt(double v) {
    return rsqrt(v);
  }
  namespace {
  // Typed accessor for the kernel's dynamic (extern __shared__) shared-memory buffer,
  // whose byte size is the third launch-configuration argument. The primary template
  // is declared but never defined, so asking for an element type without a
  // specialization below fails at compile time.
  // Context: https://github.com/NVIDIA/apex/issues/246
  template <typename T>
  struct SharedMemory;

  template <>
  struct SharedMemory<float> {
    __device__ float* getPointer() {
      extern __shared__ float s_float[];
      return s_float;
    }
  };

  template <>
  struct SharedMemory<double> {
    __device__ double* getPointer() {
      extern __shared__ double s_double[];
      return s_double;
    }
  };
  }  // namespace
  // Normalizes each row i1 of a contiguous n1 x n2 input `vals` into `output_vals`:
  //   LayerNorm (rms_only == false): out = (x - mean) * invvar [* gamma] [+ beta]
  //   RMSNorm   (rms_only == true):  out =  x         * invvar [* gamma]  (beta unused)
  // where invvar = rsqrt(sigma2 + epsilon) and sigma2 comes from cuWelfordMuSigma2
  // (variance for LayerNorm, mean of squares for RMSNorm). gamma and beta are each
  // optional (NULL = not applied). Thread (0, 0) stores mean[i1] (LayerNorm only; mean
  // may be NULL for RMSNorm) and invvar[i1].
  //
  // Launch assumptions: blockDim.x == warpSize; rows are visited as blockIdx.y,
  // blockIdx.y + gridDim.y, ...; dynamic shared memory of at least
  // blockDim.y + blockDim.y/2 elements of U when blockDim.y > 1.
  template<typename T, typename U, typename V> __device__
  void cuApplyLayerNorm_(
    V* __restrict__ output_vals,
    U* __restrict__ mean,
    U* __restrict__ invvar,
    const T* __restrict__ vals,
    const int n1,
    const int n2,
    const U epsilon,
    const V* __restrict__ gamma,
    const V* __restrict__ beta,
    bool rms_only
    )
  {
    // Apply gamma / beta independently so that a LayerNorm with only one of them
    // still gets it applied. beta never applies to RMSNorm.
    const bool has_gamma = gamma != NULL;
    const bool has_beta = beta != NULL && !rms_only;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      U mu,sigma2;
      cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only);
      // 64-bit row offset: i1*n2 would wrap in 32-bit arithmetic for large inputs.
      const size_t row_off = static_cast<size_t>(i1) * static_cast<size_t>(n2);
      const T* lvals = vals + row_off;
      V* ovals = output_vals + row_off;
      U c_invvar = rsqrt(sigma2 + epsilon);
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        const U centered = rms_only ? curr : (curr - mu);
        // Normalize in U, then apply the affine transform in V.
        V out = static_cast<V>(c_invvar * centered);
        if (has_gamma) {
          out = gamma[i] * out;
        }
        if (has_beta) {
          out = out + beta[i];
        }
        ovals[i] = out;
      }
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        if (!rms_only) {
          mean[i1] = mu;
        }
        invvar[i1] = c_invvar;
      }
      // The next row's cuWelfordMuSigma2 reuses the shared buffer.
      __syncthreads();
    }
  }
  // LayerNorm forward kernel: forwards every argument to cuApplyLayerNorm_ with
  // rms_only = false. Expects blockDim.x == warpSize, contiguous tensors, and the
  // dynamic shared memory required by cuApplyLayerNorm_.
  template<typename T, typename U, typename V=T> __global__
  void cuApplyLayerNorm(
    V* __restrict__ output_vals,
    U* __restrict__ mean,
    U* __restrict__ invvar,
    const T* __restrict__ vals,
    const int n1,
    const int n2,
    const U epsilon,
    const V* __restrict__ gamma,
    const V* __restrict__ beta
    )
  {
    cuApplyLayerNorm_<T, U, V>(
        output_vals, mean, invvar, vals,
        n1, n2, epsilon,
        gamma, beta,
        /*rms_only=*/false);
  }
  // RMSNorm forward kernel: forwards to cuApplyLayerNorm_ in rms_only mode, so no
  // mean array and no beta are passed.
  template<typename T, typename U, typename V=T> __global__
  void cuApplyRMSNorm(
    V* __restrict__ output_vals,
    U* __restrict__ invvar,
    const T* __restrict__ vals,
    const int n1,
    const int n2,
    const U epsilon,
    const V* __restrict__ gamma)
  {
    U* const no_mean = NULL;
    const V* const no_beta = NULL;
    cuApplyLayerNorm_<T, U, V>(
        output_vals, no_mean, invvar, vals,
        n1, n2, epsilon,
        gamma, no_beta,
        /*rms_only=*/true);
  }
  // Pushes curr_gamma away from zero: values with |curr_gamma| < V(eps) are replaced
  // by +V(eps) (non-negative inputs) or -V(eps) (negative inputs); all other values,
  // including NaN, are returned unchanged.
  template<typename V> __device__
  V clamp_by_magnitude(V curr_gamma, double eps)
  {
    const V floor_mag = V(eps);
    if (curr_gamma < 0) {
      return (curr_gamma > -floor_mag) ? -floor_mag : curr_gamma;
    }
    return (curr_gamma < floor_mag) ? floor_mag : curr_gamma;
  }
  // Stages blockDim.y consecutive columns (i2 = i2_off + k, k in [0, blockDim.y)) of
  // row i1 = i1_block + thr_load_row_off into shared scratch, OVERWRITING slot
  // w = thr_load_row_off*row_stride + thr_load_col_off + k with
  //   warp_buf1[w] = dout                 (LayerNorm only)
  //   warp_buf2[w] = dout * x_hat
  // Slots whose row is >= i1_end or whose column is >= n2 are written as 0.
  // x_hat = (x - mean) * invvar (LayerNorm) or x * invvar (RMSNorm). In MemoryEfficient
  // mode x_hat is instead rebuilt as (y - beta) / gamma (or y / gamma for RMSNorm) with
  // gamma clamped away from zero by eps.
  // NOTE(review): that mode presumably passes the forward output y as
  // input_or_output — confirm with the host caller.
  template<typename T, typename U, typename V, bool MemoryEfficient> __device__
  void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input_or_output,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    const V* __restrict__ gamma,
    const V* __restrict__ beta,
    const double eps,
    bool rms_only
    )
  {
    int i1 = i1_block+thr_load_row_off;
    if (i1 < i1_end) {
      // 64-bit row offset: i1*n2 overflows int for inputs with > INT_MAX elements.
      const size_t row_off = static_cast<size_t>(i1) * static_cast<size_t>(n2);
      for (int k = 0; k < blockDim.y; ++k) {
        int i2 = i2_off + k;
        int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
        if (i2<n2) {
          const size_t load_idx = row_off + i2;
          U c_h = static_cast<U>(input_or_output[load_idx]);
          U curr_dout = static_cast<U>(dout[load_idx]);
          if (!rms_only) {
            warp_buf1[write_idx] = curr_dout;
            if (MemoryEfficient) {
              U curr_beta = static_cast<U>(beta[i2]);
              warp_buf2[write_idx] = curr_dout * (c_h - curr_beta) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));
            } else {
              warp_buf2[write_idx] = curr_dout * (c_h - mean[i1]) * invvar[i1];
            }
          } else {
            if (MemoryEfficient) {
              warp_buf2[write_idx] = curr_dout * (c_h) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));
            } else {
              warp_buf2[write_idx] = curr_dout * (c_h) * invvar[i1];
            }
          }
        } else {
          if (!rms_only) {
            warp_buf1[write_idx] = U(0);
          }
          warp_buf2[write_idx] = U(0);
        }
      }
    } else {
      for (int k = 0; k < blockDim.y; ++k) {
        int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
        if (!rms_only) {
          warp_buf1[write_idx] = U(0);
        }
        warp_buf2[write_idx] = U(0);
      }
    }
  }
  // Accumulating counterpart of cuLoadWriteStridedInputs: for the same
  // (row, column) slots it ADDS
  //   warp_buf1[w] += dout                 (LayerNorm only)
  //   warp_buf2[w] += dout * x_hat
  // and leaves slots whose row is >= i1_end or whose column is >= n2 untouched.
  // x_hat is formed exactly as in cuLoadWriteStridedInputs (including the
  // MemoryEfficient reconstruction from input_or_output, beta and clamped gamma).
  template<typename T, typename U, typename V, bool MemoryEfficient> __device__
  void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input_or_output,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    const V* __restrict__ gamma,
    const V* __restrict__ beta,
    const double eps,
    bool rms_only
    )
  {
    int i1 = i1_block+thr_load_row_off;
    if (i1 < i1_end) {
      // 64-bit row offset: i1*n2 overflows int for inputs with > INT_MAX elements.
      const size_t row_off = static_cast<size_t>(i1) * static_cast<size_t>(n2);
      for (int k = 0; k < blockDim.y; ++k) {
        int i2 = i2_off + k;
        int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
        if (i2<n2) {
          const size_t load_idx = row_off + i2;
          U c_h = static_cast<U>(input_or_output[load_idx]);
          U curr_dout = static_cast<U>(dout[load_idx]);
          if (!rms_only) {
            warp_buf1[write_idx] += curr_dout;
            if (MemoryEfficient) {
              // beta is only needed for the memory-efficient reconstruction.
              U curr_beta = static_cast<U>(beta[i2]);
              warp_buf2[write_idx] += curr_dout * (c_h - curr_beta) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));
            } else {
              warp_buf2[write_idx] += curr_dout * (c_h - mean[i1]) * invvar[i1];
            }
          } else {
            if (MemoryEfficient) {
              warp_buf2[write_idx] += curr_dout * (c_h) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));
            } else {
              warp_buf2[write_idx] += curr_dout * (c_h) * invvar[i1];
            }
          }
        }
      }
    }
  }
  // First stage of the gamma/beta gradient. Block (blockIdx.x, blockIdx.y) covers
  // columns [blockIdx.x*blockDim.x, (blockIdx.x+1)*blockDim.x) and a contiguous slab
  // of rows selected by blockIdx.y, and writes the slab's column sums
  //   part_grad_beta [blockIdx.y*n2 + i2] = sum dout            (LayerNorm only)
  //   part_grad_gamma[blockIdx.y*n2 + i2] = sum dout * x_hat
  // Rows are staged blockDim.y*blockDim.y at a time via cuLoad*StridedInputs.
  //
  // Launch assumptions implied by the indexing: blockDim.x is a power of two and a
  // multiple of blockDim.y; blockDim.y is a power of two >= 2 (the final step adds rows
  // 0 and 1); dynamic shared memory holds 2*blockDim.y*blockDim.y*(blockDim.x+1)
  // elements of U. `epsilon` is not used by this kernel.
  template<typename T, typename U, typename V, bool MemoryEfficient> __global__
  void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input_or_output,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* __restrict__ gamma,
    const V* __restrict__ beta,
    U* part_grad_gamma,
    U* part_grad_beta,
    const double eps,
    bool rms_only)
  {
    // Rows staged per load step: each of the blockDim.y warps stages blockDim.y rows.
    const int rows_per_step = blockDim.y * blockDim.y;
    const int numsegs_n1 = (n1 + rows_per_step - 1) / rows_per_step;
    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
    const int i1_beg = blockIdx.y * segs_per_block * rows_per_step;
    const int i1_next = (blockIdx.y + 1) * segs_per_block * rows_per_step;
    const int i1_end = (i1_next < n1) ? i1_next : n1;

    // One padding element per staged row (presumably to avoid shared-memory bank
    // conflicts on the column-wise reads below).
    const int row_stride = blockDim.x + 1;
    const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
    const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;

    SharedMemory<U> shared;
    U* warp_buf1 = shared.getPointer();
    U* warp_buf2 = warp_buf1 + rows_per_step * row_stride;

    // Stage the first rows_per_step rows, then accumulate the rest of the slab on top.
    // Issuing many independent strided loads keeps more of them in flight.
    cuLoadWriteStridedInputs<T, U, V, MemoryEfficient>(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only);
    for (int i1_block = i1_beg + rows_per_step; i1_block < i1_end; i1_block += rows_per_step) {
      cuLoadAddStridedInputs<T, U, V, MemoryEfficient>(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only);
    }
    __syncthreads();

    // Warp threadIdx.y folds staged rows threadIdx.y, threadIdx.y + blockDim.y, ...
    // into row threadIdx.y (no other warp reads that row).
    U acc1 = U(0);
    U acc2 = U(0);
    for (int k = 0; k < blockDim.y; ++k) {
      const int idx = (threadIdx.y + k * blockDim.y) * row_stride + threadIdx.x;
      if (!rms_only) {
        acc1 += warp_buf1[idx];
      }
      acc2 += warp_buf2[idx];
    }
    const int own_idx = threadIdx.y * row_stride + threadIdx.x;
    if (!rms_only) {
      warp_buf1[own_idx] = acc1;
    }
    warp_buf2[own_idx] = acc2;
    __syncthreads();

    // Tree-reduce rows [0, blockDim.y) down to rows 0 and 1.
    for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
      if (threadIdx.y < offset) {
        const int dst = threadIdx.y * row_stride + threadIdx.x;
        const int src = (threadIdx.y + offset) * row_stride + threadIdx.x;
        if (!rms_only) {
          warp_buf1[dst] += warp_buf1[src];
        }
        warp_buf2[dst] += warp_buf2[src];
      }
      __syncthreads();
    }

    // Warp 0 adds the last two rows and writes this slab's partial sums.
    const int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.y == 0 && i2 < n2) {
      const int idx_row0 = threadIdx.x;
      const int idx_row1 = row_stride + threadIdx.x;
      const int out_idx = blockIdx.y * n2 + i2;
      if (!rms_only) {
        part_grad_beta[out_idx] = warp_buf1[idx_row0] + warp_buf1[idx_row1];
      }
      part_grad_gamma[out_idx] = warp_buf2[idx_row0] + warp_buf2[idx_row1];
    }
  }
  // Second stage of the gamma/beta gradient: sums the part_size partial rows written by
  // cuComputePartGradGammaBeta (each n2 wide) into grad_gamma[i2] and, for LayerNorm,
  // grad_beta[i2]. Column i2 = blockIdx.x*blockDim.x + threadIdx.x; the blockDim.y warps
  // each sum a contiguous chunk of partial rows, and the chunks are then tree-reduced
  // through shared memory.
  //
  // Launch assumptions: blockDim.y is a power of two; dynamic shared memory holds
  // blockDim.x*blockDim.y elements of U. `n1` is not used.
  template<typename U, typename V> __global__
  void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta,
    bool rms_only)
  {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    const int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    // Out-of-range columns must still reach every __syncthreads() below (a barrier
    // inside divergent control flow is undefined behavior); they only skip their
    // global-memory reads and writes and carry zeros through the reduction.
    const bool valid_col = i2 < n2;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    if (valid_col) {
      // each warp does sequential reductions until reduced part_size is num_warps
      const int num_warp_reductions = part_size / blockDim.y;
      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
      for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
        if (!rms_only) {
          sum_beta += part_grad_beta_ptr[warp_offset*n2];
        }
      }
      // Rows left over when part_size is not a multiple of blockDim.y: there are
      // fewer than blockDim.y of them, so each warp takes at most one.
      const int tail_row = num_warp_reductions * blockDim.y + threadIdx.y;
      if (tail_row < part_size) {
        sum_gamma += part_grad_gamma[tail_row * n2 + i2];
        if (!rms_only) {
          sum_beta += part_grad_beta[tail_row * n2 + i2];
        }
      }
    }
    // inter-warp reductions; gamma partials in buf[0, nbsize3), beta after them.
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        if (!rms_only) {
          buf[write_idx+nbsize3] = sum_beta;
        }
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        if (!rms_only) {
          sum_beta += buf[read_idx+nbsize3];
        }
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (valid_col && threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      if (!rms_only) {
        grad_beta[i2] = sum_beta;
      }
    }
  }
  // Input gradient for LayerNorm / RMSNorm, one row i1 at a time (rows blockIdx.y,
  // blockIdx.y + gridDim.y, ...). Per row the whole block reduces
  //   sum_loss1 = sum_j dout_j * gamma_j                 (LayerNorm only)
  //   sum_loss2 = sum_j dout_j * gamma_j * x_hat_j
  // (gamma_j treated as 1 when gamma == NULL) and then writes
  //   grad_input_j = invvar / n2 * (n2 * dout_j * gamma_j - sum_loss1 - x_hat_j * sum_loss2)
  // x_hat_j = (x_j - mean) * invvar (LayerNorm) or x_j * invvar (RMSNorm). In
  // MemoryEfficient mode x_hat_j is rebuilt from input_or_output, beta and gamma
  // (gamma clamped away from zero by eps before dividing).
  // NOTE(review): that mode presumably receives the forward output y as
  // input_or_output — confirm with the host caller.
  // `mean` is only read by the non-memory-efficient LayerNorm path.
  //
  // Launch assumptions: blockDim.x is a power of two <= warpSize (XOR-shuffle
  // reduction); blockDim.y is a power of two; when blockDim.y > 1, dynamic shared
  // memory holds blockDim.x*blockDim.y elements of U. `epsilon` is not used.
  template<typename T, typename U, typename V, bool MemoryEfficient> __global__
  void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input_or_output,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    const V* beta,
    T* grad_input,
    const double eps,
    bool rms_only)
  {
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
      U sum_loss1 = U(0);
      U sum_loss2 = U(0);
      // 64-bit row offset: i1*n2 would wrap in 32-bit arithmetic for large inputs.
      const size_t row_off = static_cast<size_t>(i1) * static_cast<size_t>(n2);
      const T* k_h = input_or_output + row_off;
      const V* k_dout = dout + row_off;
      const U c_invvar = invvar[i1];
      // Only the non-memory-efficient LayerNorm path uses the mean.
      const U c_mean = (!MemoryEfficient && !rms_only) ? mean[i1] : U(0);

      // Adds element idx's contribution to sum_loss1 / sum_loss2.
      auto accumulate = [&](int idx) {
        const U c_h = static_cast<U>(k_h[idx]);
        const U c_loss = static_cast<U>(k_dout[idx]);
        if (gamma != NULL) {
          if (!rms_only) {
            sum_loss1 += c_loss * gamma[idx];
            if (MemoryEfficient) {
              sum_loss2 += c_loss * (c_h - beta[idx]);
            } else {
              sum_loss2 += c_loss * gamma[idx] * (c_h - c_mean) * c_invvar;
            }
          } else {
            if (MemoryEfficient) {
              sum_loss2 += c_loss * c_h;
            } else {
              sum_loss2 += c_loss * gamma[idx] * (c_h) * c_invvar;
            }
          }
        } else {
          if (!rms_only) {
            sum_loss1 += c_loss;
            if (MemoryEfficient) {
              sum_loss2 += c_loss * c_h;
            } else {
              sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
            }
          } else {
            if (MemoryEfficient) {
              sum_loss2 += c_loss * c_h;
            } else {
              sum_loss2 += c_loss * (c_h) * c_invvar;
            }
          }
        }
      };
      // Groups of 4 consecutive elements per thread, then the leftover tail.
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          accumulate(l+k);
        }
      }
      for (; l < n2; ++l) {
        accumulate(l);
      }
      // intra-warp reductions
      for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
        if (!rms_only) {
          sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
        }
        sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
      }
      // inter-warp reductions
      if (blockDim.y > 1) {
        SharedMemory<U> shared;
        U* buf = shared.getPointer();
        for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
          // upper half of warps write to shared
          if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
            const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
            if (!rms_only) {
              buf[2*wrt_i] = sum_loss1;
            }
            buf[2*wrt_i+1] = sum_loss2;
          }
          __syncthreads();
          // lower half merges
          if (threadIdx.y < offset) {
            const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
            if (!rms_only) {
              sum_loss1 += buf[2*read_i];
            }
            sum_loss2 += buf[2*read_i+1];
          }
          __syncthreads();
        }
        // broadcast warp 0's totals to the other warps
        if (threadIdx.y == 0) {
          if (!rms_only) {
            buf[2*threadIdx.x] = sum_loss1;
          }
          buf[2*threadIdx.x+1] = sum_loss2;
        }
        __syncthreads();
        if (threadIdx.y !=0) {
          if (!rms_only) {
            sum_loss1 = buf[2*threadIdx.x];
          }
          sum_loss2 = buf[2*threadIdx.x+1];
        }
      }
      // all threads now have the two sums over l
      U fH = (U)n2;
      U term1 = (U(1) / fH) * c_invvar;
      T* k_grad_input = grad_input + row_off;
      if (gamma != NULL) {
        for (int l = thrx; l < n2; l+=numx) {
          const U c_h = static_cast<U>(k_h[l]);
          const U c_loss = static_cast<U>(k_dout[l]);
          // Only the memory-efficient path divides by gamma, so only it needs the
          // clamp; elsewhere use gamma as-is so this term agrees with the unclamped
          // gamma already folded into sum_loss1 / sum_loss2.
          const U k_gamma = MemoryEfficient
              ? static_cast<U>(clamp_by_magnitude(gamma[l], eps))
              : static_cast<U>(gamma[l]);
          U f_grad_input = fH * c_loss * k_gamma;
          if (!rms_only) {
            f_grad_input -= sum_loss1;
            if (MemoryEfficient) {
              const U k_beta = beta[l];
              f_grad_input -= (c_h - k_beta) / k_gamma * sum_loss2;
            } else {
              f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
            }
          } else {
            if (MemoryEfficient) {
              f_grad_input -= c_h / k_gamma * sum_loss2;
            } else {
              f_grad_input -= c_h * c_invvar * sum_loss2;
            }
          }
          f_grad_input *= term1;
          k_grad_input[l] = static_cast<T>(f_grad_input);
        }
      } else {
        for (int l = thrx; l < n2; l+=numx) {
          const U c_h = static_cast<U>(k_h[l]);
          const U c_loss = static_cast<U>(k_dout[l]);
          U f_grad_input = fH * c_loss;
          if (!rms_only) {
            f_grad_input -= sum_loss1;
            if (MemoryEfficient) {
              f_grad_input -= c_h * sum_loss2;
            } else {
              f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
            }
          } else {
            if (MemoryEfficient) {
              f_grad_input -= c_h * sum_loss2;
            } else {
              f_grad_input -= c_h * c_invvar * sum_loss2;
            }
          }
          f_grad_input *= term1;
          k_grad_input[l] = static_cast<T>(f_grad_input);
        }
      }
      // prevent race where buf is written again before reads are done
      __syncthreads();
    }
  }
  // Launches cuApplyLayerNorm on the current CUDA stream with 32x4 thread blocks and
  // min(n1, maxGridSize[1]) blocks along y (the kernel loops over the remaining rows
  // with stride gridDim.y). Dynamic shared memory covers cuWelfordMuSigma2's
  // inter-warp reduction: blockDim.y values for the (mu, sigma2) pairs plus
  // blockDim.y/2 count values, all of type U.
  template<typename T, typename U, typename V=T>
  void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
  {
    // No rows: nothing to do. Launching would use gridDim.y == 0, which CUDA rejects
    // as an invalid configuration and leaves as a pending error for a later check.
    if (n1 <= 0) {
      return;
    }
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32,4,1);
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared =
        threads.y > 1 ?
        threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
        0;
    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
        output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
  }
  // Launches cuApplyRMSNorm on the current CUDA stream with 32x4 thread blocks and
  // min(n1, maxGridSize[1]) blocks along y (the kernel loops over the remaining rows
  // with stride gridDim.y). Dynamic shared memory covers cuWelfordMuSigma2's
  // inter-warp reduction scratch (blockDim.y + blockDim.y/2 values of type U).
  template<typename T, typename U, typename V=T>
  void HostApplyRMSNorm(
    V* output,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma)
  {
    // No rows: nothing to do. Launching would use gridDim.y == 0, which CUDA rejects
    // as an invalid configuration and leaves as a pending error for a later check.
    if (n1 <= 0) {
      return;
    }
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32,4,1);
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared =
        threads.y > 1 ?
        threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
        0;
    cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
        output, invvar, input, n1, n2, U(epsilon), gamma);
  }
  // Host entry point for the LayerNorm forward pass. Dispatches on the input and
  // output dtypes and runs HostApplyLayerNorm with accumulation type
  // at::acc_type<scalar_t_in, true>. gamma and beta may be NULL (then not applied).
  // normalized_shape is accepted for interface compatibility but not read here.
  void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
  #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
  #else
    at::IntList normalized_shape,
  #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
  {
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      const scalar_t_out* gamma_ptr = (gamma == NULL) ? NULL : gamma->DATA_PTR<scalar_t_out>();
      const scalar_t_out* beta_ptr = (beta == NULL) ? NULL : beta->DATA_PTR<scalar_t_out>();
      HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
        output->DATA_PTR<scalar_t_out>(),
        mean->DATA_PTR<accscalar_t>(),
        invvar->DATA_PTR<accscalar_t>(),
        input->DATA_PTR<scalar_t_in>(),
        n1, n2,
        epsilon,
        gamma_ptr,
        beta_ptr);
    )
  }
  // Host entry point for the RMSNorm forward pass. Dispatches on the input and output
  // dtypes and runs HostApplyRMSNorm with accumulation type
  // at::acc_type<scalar_t_in, true>. gamma may be NULL (then not applied).
  // normalized_shape is accepted for interface compatibility but not read here.
  void cuda_rms_norm(
    at::Tensor* output,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
  #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
  #else
    at::IntList normalized_shape,
  #endif
    at::Tensor* gamma,
    double epsilon)
  {
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      const scalar_t_out* gamma_ptr = (gamma == NULL) ? NULL : gamma->DATA_PTR<scalar_t_out>();
      HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(
        output->DATA_PTR<scalar_t_out>(),
        invvar->DATA_PTR<accscalar_t>(),
        input->DATA_PTR<scalar_t_in>(),
        n1, n2,
        epsilon,
        gamma_ptr);
    )
  }
  // Launches the LayerNorm backward kernels on the current CUDA stream.
  //
  // Template parameters (as used below):
  //   T - element type of input_or_output and grad_input
  //   U - accumulation type of mean/invvar and of the partial-gradient scratch buffers
  //   V - element type of dout, gamma, beta, grad_gamma and grad_beta
  //
  // Work performed:
  //   * If gamma and beta are both non-NULL (and n2 > 0): cuComputePartGradGammaBeta
  //     writes part_size partial sums per column into two {part_size, n2} scratch
  //     tensors, then cuComputeGradGammaBeta reduces them into grad_gamma/grad_beta.
  //   * If n1 > 0: cuComputeGradInput writes grad_input.
  //
  // memory_efficient is forwarded as the MemoryEfficient template flag of both
  // gradient kernels. NOTE(review): input_or_output presumably holds the forward
  // output rather than the input when memory_efficient is true -- confirm in the
  // kernels. The trailing `false` passed to the kernels is the flag that
  // HostRMSNormGradient sets to `true`.
  //
  // Launches are asynchronous and no launch errors are checked here.
  template<typename T, typename U=float, typename V=T>
  void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta,
    bool memory_efficient
    )
  {
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // A grid dimension of 0 is an invalid launch configuration, so the
    // gamma/beta kernels are skipped when there are no columns (n2 == 0);
    // grad_gamma/grad_beta are empty in that case anyway.
    if (gamma != NULL && beta != NULL && n2 > 0) {
      // compute grad_gamma(j) and grad_beta(j)
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // The scratch buffers are accessed through DATA_PTR<U>(), so their dtype
      // must match U: float for half/bfloat16 inputs, otherwise the input's own
      // dtype.
      const auto part_grad_dtype =
        (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input_or_output->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{
        auto kernel = &cuComputePartGradGammaBeta<T, U, V, MemoryEfficient>;
        kernel<<<blocks2, threads2, nshared2, stream>>>(
          dout,
          input_or_output->DATA_PTR<T>(),
          n1,n2,
          mean,
          invvar,
          U(epsilon),
          gamma,
          beta,
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>(),
          epsilon,
          false);
      });
      // Reduce the part_size partial rows into the final per-column gradients.
      const dim3 threads3(32,8,1);
      const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1,n2,
        grad_gamma,
        grad_beta,
        false);
    }
    // compute grad_input; with no rows there is nothing to compute and
    // blocks1.y would be 0 (an invalid launch configuration).
    if (n1 <= 0) {
      return;
    }
    // Rows are spread over grid.y, capped at the device's maxGridSize[1].
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
      threads1.y > 1 ?
      threads1.y*threads1.x*sizeof(U) :
      0;
    BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {
      auto kernel = cuComputeGradInput<T, U, V, MemoryEfficient>;
      kernel<<<blocks1, threads1, nshared, stream>>>(
        dout,
        input_or_output->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        gamma,
        beta,
        grad_input,
        epsilon,
        false);
    });
  }
  // Launches the RMSNorm backward kernels on the current CUDA stream.
  //
  // Reuses the LayerNorm backward kernels with their trailing flag set to `true`
  // (HostLayerNormGradient passes `false`; presumably the rms_only switch --
  // confirm in the kernels). Arguments only the LayerNorm path needs are filled
  // with placeholders marked "unused" below: invvar stands in for mean, gamma for
  // beta, and part_grad_gamma / grad_gamma for the beta buffers.
  //
  // Template parameters (as used below):
  //   T - element type of input_or_output and grad_input
  //   U - accumulation type of invvar and of the partial-gradient scratch buffer
  //   V - element type of dout, gamma and grad_gamma
  //
  // Work performed:
  //   * If gamma is non-NULL (and n2 > 0): partial column sums go into a
  //     {part_size, n2} scratch tensor, which is then reduced into grad_gamma.
  //   * If n1 > 0: grad_input is written.
  //
  // memory_efficient selects the MemoryEfficient template instantiation.
  // Launches are asynchronous and no launch errors are checked here.
  template<typename T, typename U=float, typename V=T>
  void HostRMSNormGradient(
    const V* dout,
    const U* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
    const V* gamma,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    bool memory_efficient)
  {
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // A grid dimension of 0 is an invalid launch configuration, so the gamma
    // kernels are skipped when there are no columns (grad_gamma is empty then).
    if (gamma != NULL && n2 > 0) {
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // The scratch buffer is accessed through DATA_PTR<U>(), so its dtype must
      // match U: float for half/bfloat16 inputs, otherwise the input's own dtype
      // (float or double -- cuda_rms_norm_gradient also dispatches double).
      const auto part_grad_dtype =
        (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input_or_output->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype));
      BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{
        auto kernel = &cuComputePartGradGammaBeta<T, U, V, MemoryEfficient>;
        kernel<<<blocks2, threads2, nshared2, stream>>>(
          dout,
          input_or_output->DATA_PTR<T>(),
          n1,n2,
          invvar, /* unused */
          invvar,
          U(epsilon),
          gamma,
          gamma, /* unused */
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_gamma.DATA_PTR<U>(), /* unused */
          epsilon,
          true);
      });
      // Reduce the part_size partial rows into the final per-column gradient.
      const dim3 threads3(32,8,1);
      const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_gamma.DATA_PTR<U>(), /* unused */
        part_size,
        n1,n2,
        grad_gamma,
        grad_gamma, /* unused */
        true);
    }
    // compute grad_input; with no rows there is nothing to compute and
    // blocks1.y would be 0 (an invalid launch configuration).
    if (n1 <= 0) {
      return;
    }
    // Rows are spread over grid.y, capped at the device's maxGridSize[1].
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
      threads1.y > 1 ?
      threads1.y*threads1.x*sizeof(U) :
      0;
    BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {
      auto kernel = cuComputeGradInput<T, U, V, MemoryEfficient>;
      kernel<<<blocks1, threads1, nshared, stream>>>(
        dout,
        input_or_output->DATA_PTR<T>(),
        n1,n2,
        invvar, /* unused */
        invvar,
        U(epsilon),
        gamma,
        gamma, /* unused */
        grad_input,
        epsilon,
        true);
    });
  }
  // Backward LayerNorm entry point.
  //
  // Dispatches on (input_or_output dtype, gamma dtype -- or the input dtype when
  // gamma is NULL) and forwards raw device pointers to HostLayerNormGradient.
  // Only fp32/fp16/bf16 inputs are dispatched here, so accscalar_t is always
  // float. mean may be NULL. gamma is optional; when it is given, beta,
  // grad_gamma and grad_beta must be given too, because their data pointers are
  // taken whenever gamma is non-NULL. normalized_shape is not read by this
  // function.
  void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
  #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
  #else
    at::IntList normalized_shape,
  #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta,
    bool memory_efficient)
  {
    using namespace at;
    // Reject inconsistent optional arguments instead of dereferencing NULL below.
    TORCH_CHECK(gamma == NULL || (beta != NULL && grad_gamma != NULL && grad_beta != NULL),
                "cuda_layer_norm_gradient: beta, grad_gamma and grad_beta must be provided when gamma is provided");
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostLayerNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        mean != NULL ? mean->DATA_PTR<accscalar_t>() : NULL,
        invvar->DATA_PTR<accscalar_t>(),
        input_or_output,
        n1,n2,
        // Pass NULL for gamma, beta, grad_gamma and grad_beta when no gamma
        // tensor is given (validated above when it is).
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL,
        memory_efficient);
    )
  }
  // Backward RMSNorm entry point.
  //
  // Dispatches on (input_or_output dtype, gamma dtype -- or the input dtype when
  // gamma is NULL) and forwards raw device pointers to HostRMSNormGradient.
  // Unlike cuda_layer_norm_gradient, this path also dispatches double, so
  // accscalar_t is float for fp32/fp16/bf16 inputs and double for double inputs.
  // gamma is optional; when it is NULL, HostRMSNormGradient skips the weight
  // gradient. When it is given, grad_gamma must be given too. normalized_shape
  // is not read by this function.
  void cuda_rms_norm_gradient(
    at::Tensor* dout,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
  #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
  #else
    at::IntList normalized_shape,
  #endif
    at::Tensor* gamma,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    bool memory_efficient)
  {
    using namespace at;
    // grad_gamma's data pointer is taken whenever gamma is given; fail loudly
    // instead of dereferencing NULL.
    TORCH_CHECK(gamma == NULL || grad_gamma != NULL,
                "cuda_rms_norm_gradient: grad_gamma must be provided when gamma is provided");
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostRMSNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        invvar->DATA_PTR<accscalar_t>(),
        input_or_output,
        n1,n2,
        // Pass NULL for gamma and grad_gamma when no gamma tensor is given.
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
        memory_efficient);
    )
  }