3
0

welford.cu 53 KB


  1. #include <iostream>
  2. #include <ATen/ATen.h>
  3. #include <ATen/AccumulateType.h>
  4. #include <ATen/cuda/CUDAContext.h>
  5. #include <cuda.h>
  6. #include <cuda_runtime.h>
  7. #include <vector>
  8. #include "type_shim.h"
  9. #include "compat.h"
  10. __device__ __forceinline__ int lastpow2(int n)
  11. {
  12. int out = 1 << (31 - __clz(n));
  13. if(n == out)
  14. out >>= 1;
  15. return out;
  16. }
  17. __host__ __forceinline__ int h_next_pow2(unsigned int n) {
  18. n--;
  19. n |= (n >> 1);
  20. n |= (n >> 2);
  21. n |= (n >> 4);
  22. n |= (n >> 8);
  23. n |= (n >> 16);
  24. return ++n;
  25. }
  26. __host__ __forceinline__ int h_last_pow2(unsigned int n) {
  27. n |= (n >> 1);
  28. n |= (n >> 2);
  29. n |= (n >> 4);
  30. n |= (n >> 8);
  31. n |= (n >> 16);
  32. return n - (n >> 1);
  33. }
  34. #define WARP_SIZE 32
  35. template<typename T>
  36. __device__ __forceinline__ T warp_reduce_sum(T val)
  37. {
  38. #pragma unroll
  39. for(int i = WARP_SIZE/2; i > 0; i >>= 1)
  40. val = val + __shfl_down_sync(0xffffffff, val, i);
  41. return val;
  42. }
  43. template<typename T>
  44. __device__ __forceinline__ T reduce_block(T *x, T val)
  45. {
  46. int tid = threadIdx.y*blockDim.x + threadIdx.x;
  47. int blockSize = blockDim.x * blockDim.y;
  48. if (blockSize > 32) {
  49. val = warp_reduce_sum(val);
  50. if (tid % WARP_SIZE == 0)
  51. x[tid/WARP_SIZE] = val;
  52. __syncthreads();
  53. val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
  54. }
  55. if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
  56. return val;
  57. }
  58. #define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
  59. #define ELEMENTS_PER_THREAD 16
  60. #define OPTIMAL_TILE_W 32
  61. #define MAX_H_BLOCK 128
  62. #define MAX_BLOCK_SIZE 512
  63. __host__ int div_ru(int x, int y) {
  64. return h_last_pow2(1 + (x-1)/y);
  65. }
  66. __host__ void flexible_launch_configs(
  67. const int reduction,
  68. const int stride,
  69. dim3 &block,
  70. dim3 &grid,
  71. const bool coop_flag = false) {
  72. int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
  73. int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
  74. MAX_BLOCK_SIZE / block_x);
  75. if (block_x * block_y != MAX_BLOCK_SIZE) {
  76. block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
  77. }
  78. int grid_x = div_ru(stride, block_x);
  79. int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  80. if (coop_flag) {
  81. // it's not worth having a grid reduction if the reduction dimension is not big enough
  82. grid_y = grid_y < 8 ? 1 : grid_y;
  83. }
  84. block.x = block_x;
  85. block.y = block_y;
  86. block.z = 1;
  87. grid.x = grid_x;
  88. grid.y = grid_y;
  89. grid.z = 1;
  90. }
  91. template<typename T, typename C>
  92. __device__ __forceinline__ void welford_merge_element(C& count,
  93. T& mean,
  94. T& m2n,
  95. const C& num_new,
  96. const T& mean_new,
  97. const T& m2n_new) {
  98. T factor = T(1.0) / max(1, (count + num_new));
  99. T delta0 = mean - mean_new;
  100. mean = (mean_new * num_new + mean * count) * factor;
  101. m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
  102. count += num_new;
  103. }
  104. template<typename T>
  105. __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
  106. {
  107. #pragma unroll
  108. for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
  109. auto num_new = __shfl_down_sync(0xffffffff, num, i);
  110. auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
  111. auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
  112. welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
  113. }
  114. }
  115. template <typename T>
  116. __device__ void welford_reduce_mean_m2n(
  117. T* __restrict__ x,
  118. int* __restrict__ count,
  119. T &mean,
  120. T &m2n,
  121. int &num,
  122. int block_size,
  123. int thread_id)
  124. {
  125. int lane = thread_id % WARP_SIZE;
  126. int wid = thread_id / WARP_SIZE;
  127. if (block_size > 32) {
  128. warp_reduce_mean_m2n(mean, m2n, num);
  129. if (lane == 0) {
  130. x[wid*2] = mean;
  131. x[wid*2+1] = m2n;
  132. count[wid] = num;
  133. }
  134. __syncthreads();
  135. if (wid == 0) {
  136. mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
  137. m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
  138. num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
  139. }
  140. }
  141. if (wid==0) warp_reduce_mean_m2n(mean, m2n, num);
  142. return;
  143. }
  144. // return spatial size for NC+ Tensors
  145. __host__ int get_tensor_spatial_size(const at::Tensor& input)
  146. {
  147. auto space_size = input.size(2);
  148. for (int i = 3; i < input.ndimension(); i++) {
  149. space_size *= input.size(i);
  150. }
  151. return space_size;
  152. }
  153. // promote accumulation scalar type. promote half to float.
  154. __host__ at::ScalarType promote_scalartype(const at::Tensor& input)
  155. {
  156. return input.scalar_type() == at::ScalarType::Half ?
  157. at::ScalarType::Float : input.scalar_type();
  158. }
  159. // return single element size, optional accumulation type promotion.
  160. __host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
  161. {
  162. auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
  163. return at::elementSize(scalar_type);
  164. }
  165. template<typename T, typename C>
  166. __device__ __forceinline__ void welford_merge_block_vertical(C& count,
  167. T& mean,
  168. T& m2n,
  169. C* shmem_count,
  170. T* shmem_mean,
  171. T* shmem_m2n) {
  172. // write to shared memory
  173. auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  174. shmem_mean[address_base] = mean;
  175. shmem_m2n[address_base] = m2n;
  176. shmem_count[address_base] = count;
  177. #pragma unroll
  178. for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
  179. __syncthreads();
  180. if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
  181. auto address = address_base + offset * blockDim.x;
  182. // read shared memory back to register for reduction
  183. auto num_new = shmem_count[address];
  184. auto mean_new = shmem_mean[address];
  185. auto m2n_new = shmem_m2n[address];
  186. welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);
  187. // last write is not necessary
  188. shmem_mean[address_base] = mean;
  189. shmem_m2n[address_base] = m2n;
  190. shmem_count[address_base] = count;
  191. }
  192. }
  193. }
  194. template<typename T>
  195. __device__ __forceinline__ void merge_block_vertical(T& sum_dy,
  196. T& sum_dy_xmu,
  197. T* shmem_sum_dy,
  198. T* shmem_sum_dy_xmu) {
  199. // write to shared memory
  200. auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  201. shmem_sum_dy[address_base] = sum_dy;
  202. shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
  203. #pragma unroll
  204. for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
  205. __syncthreads();
  206. if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
  207. auto address = address_base + offset * blockDim.x;
  208. sum_dy += shmem_sum_dy[address];
  209. sum_dy_xmu += shmem_sum_dy_xmu[address];
  210. // last write is not necessary
  211. shmem_sum_dy[address_base] = sum_dy;
  212. shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
  213. }
  214. }
  215. }
  216. // welford kernel calculating mean/biased_variance/unbiased_variance
  217. template <typename scalar_t, typename accscalar_t, typename outscalar_t>
  218. __global__ void welford_kernel(
  219. const scalar_t* __restrict__ input,
  220. outscalar_t* __restrict__ out_mean,
  221. outscalar_t* __restrict__ out_var_biased,
  222. const int bs,
  223. const int fs,
  224. const int ss) {
  225. int block_size = blockDim.x * blockDim.y;
  226. int count = 0;
  227. accscalar_t x_mean = accscalar_t(0);
  228. accscalar_t m_2_n = accscalar_t(0);
  229. int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
  230. for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
  231. int input_base = blockIdx.x*ss + batch_id*ss*fs;
  232. // sequential welford
  233. for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
  234. count++;
  235. auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
  236. auto d = x_n - x_mean;
  237. x_mean += d / count;
  238. m_2_n += d * (x_n - x_mean);
  239. }
  240. }
  241. static __shared__ int s_mem[160];
  242. accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
  243. welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
  244. if (thread_id == 0) {
  245. out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
  246. out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
  247. }
  248. }
  249. // elementwise BN kernel
  250. template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
  251. __global__ void batchnorm_forward_kernel(
  252. const scalar_t* __restrict__ input,
  253. const accscalar_t* __restrict__ mean,
  254. const accscalar_t* __restrict__ inv_std,
  255. const layerscalar_t* __restrict__ weight,
  256. const layerscalar_t* __restrict__ shift,
  257. scalar_t* __restrict__ out,
  258. const int ss,
  259. const int bs) {
  260. auto m_c = mean[blockIdx.x];
  261. auto inv_std_c = inv_std[blockIdx.x];
  262. auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
  263. auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
  264. for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
  265. int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
  266. for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
  267. out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
  268. }
  269. }
  270. }
  271. // Backward BN kernel, calculates grad_bias, grad_weight as well as intermediate
  272. // results to calculating grad_input.
  273. // Breaking the grad_input to two step to support sync BN, which requires all
  274. // reduce of the intermediate results across processes.
  275. template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
  276. __global__ void reduce_bn_kernel(
  277. const scalar_t* __restrict__ input,
  278. const scalar_t* __restrict__ grad_output,
  279. const accscalar_t* __restrict__ mean,
  280. const accscalar_t* __restrict__ inv_std,
  281. accscalar_t* __restrict__ sum_dy_o,
  282. accscalar_t* __restrict__ sum_dy_xmu_o,
  283. layerscalar_t* __restrict__ grad_weight,
  284. layerscalar_t* __restrict__ grad_bias,
  285. const int bs,
  286. const int fs,
  287. const int ss) {
  288. static __shared__ int s_mem[64];
  289. //int total_item_num = bs * ss;
  290. int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
  291. auto r_mean = mean[blockIdx.x];
  292. auto factor = inv_std[blockIdx.x];
  293. // Kahan sum
  294. accscalar_t sum_dy = 0.0;
  295. accscalar_t sum_dy_xmu = 0.0;
  296. accscalar_t sum_dy_c = 0.0;
  297. accscalar_t sum_dy_xmu_c = 0.0;
  298. for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
  299. int input_base = blockIdx.x*ss + batch_id*ss*fs;
  300. for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
  301. auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
  302. auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
  303. // calculating sum_dy
  304. auto sum_dy_y = e_grad - sum_dy_c;
  305. auto sum_dy_t = sum_dy + sum_dy_y;
  306. sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
  307. sum_dy = sum_dy_t;
  308. // calculating sum_dy_xmu
  309. auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
  310. auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
  311. sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
  312. sum_dy_xmu = sum_dy_xmu_t;
  313. }
  314. }
  315. sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
  316. __syncthreads();
  317. sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
  318. if (thread_id == 0) {
  319. if (grad_bias != NULL) {
  320. grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
  321. }
  322. if (grad_weight != NULL) {
  323. grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
  324. }
  325. //mean_dy[blockIdx.x] = sum_dy / total_item_num;
  326. //mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
  327. sum_dy_o[blockIdx.x] = sum_dy;
  328. sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
  329. }
  330. }
  331. // elementwise backward BN kernel
  332. template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
  333. __global__ void batchnorm_backward_kernel(
  334. const scalar_t* __restrict__ grad_output,
  335. const scalar_t* __restrict__ input,
  336. const accscalar_t* __restrict__ mean,
  337. const accscalar_t* __restrict__ inv_std,
  338. const layerscalar_t* __restrict__ weight,
  339. const accscalar_t* __restrict__ sum_dy,
  340. const accscalar_t* __restrict__ sum_dy_xmu,
  341. const int* __restrict__ numel,
  342. scalar_t* __restrict__ grad_input,
  343. const int64_t world_size,
  344. const int ss,
  345. const int bs) {
  346. int64_t div = 0;
  347. for (int i = 0; i < world_size; i++) {
  348. div += numel[i];
  349. }
  350. auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
  351. //auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
  352. auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
  353. auto factor_1_c = inv_std[blockIdx.x];
  354. auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
  355. //factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
  356. factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
  357. for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
  358. int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
  359. for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
  360. grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
  361. }
  362. }
  363. }
  364. // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
  365. template
  366. <typename scalar_t,
  367. typename accscalar_t,
  368. typename outscalar_t,
  369. int PARALLEL_LOADS>
  370. __global__ void
  371. welford_kernel_c_last(
  372. const scalar_t* __restrict__ input,
  373. outscalar_t* __restrict__ out_mean,
  374. outscalar_t* __restrict__ out_var_biased,
  375. volatile accscalar_t* staging_data,
  376. int* semaphores,
  377. const int reduction_size,
  378. const int stride) {
  379. // hide latency with concurrency
  380. accscalar_t x_mean[PARALLEL_LOADS];
  381. accscalar_t m_2_n[PARALLEL_LOADS];
  382. int count[PARALLEL_LOADS];
  383. #pragma unroll
  384. for (int i = 0; i < PARALLEL_LOADS; i++) {
  385. x_mean[i] = accscalar_t(0);
  386. m_2_n[i] = accscalar_t(0);
  387. count[i] = accscalar_t(0);
  388. }
  389. // tensor dimension (m,c)
  390. // loop along m dimension
  391. int inner_loop_stride = blockDim.y * gridDim.y;
  392. // offset along m dimension
  393. int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  394. int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  395. int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  396. int address_base = m_offset * stride + c_offset;
  397. int address_increment = inner_loop_stride * stride;
  398. for (int i = 0; i < loop_count; i++) {
  399. accscalar_t x_math[PARALLEL_LOADS];
  400. accscalar_t x_count_inv[PARALLEL_LOADS];
  401. accscalar_t is_valid[PARALLEL_LOADS];
  402. // load multiple data in
  403. #pragma unroll
  404. for (int j = 0; j < PARALLEL_LOADS; j++) {
  405. if (c_offset < stride && m_offset < reduction_size) {
  406. x_math[j] = input[address_base];
  407. count[j]++;
  408. x_count_inv[j] = accscalar_t(1) / count[j];
  409. is_valid[j] = accscalar_t(1);
  410. } else {
  411. x_math[j] = accscalar_t(0);
  412. x_count_inv[j] = accscalar_t(0);
  413. is_valid[j] = accscalar_t(0);
  414. }
  415. m_offset += inner_loop_stride;
  416. address_base += address_increment;
  417. }
  418. // calculate mean/m2n with welford
  419. #pragma unroll
  420. for (int j = 0; j < PARALLEL_LOADS; j++) {
  421. accscalar_t delta0 = x_math[j] - x_mean[j];
  422. x_mean[j] += delta0 * x_count_inv[j];
  423. accscalar_t delta1 = x_math[j] - x_mean[j];
  424. m_2_n[j] += delta0 * delta1 * is_valid[j];
  425. }
  426. }
  427. // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
  428. #pragma unroll
  429. for (int j = 1; j < PARALLEL_LOADS; j++) {
  430. welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  431. }
  432. // release x_mean / m_2_n
  433. auto mean_th = x_mean[0];
  434. auto m2_th = m_2_n[0];
  435. auto count_th = count[0];
  436. // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  437. static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  438. static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  439. static __shared__ int shmem_count[MAX_BLOCK_SIZE];
  440. welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
  441. // grid reduction if needed (coop launch used at the first place)
  442. if (gridDim.y > 1) {
  443. volatile accscalar_t* staging_mean = staging_data;
  444. volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
  445. volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
  446. address_base = c_offset + blockIdx.y * stride;
  447. // write data to staging_data;
  448. if (threadIdx.y == 0 && c_offset < stride) {
  449. staging_mean[address_base] = mean_th;
  450. staging_m2n[address_base] = m2_th;
  451. staging_count[address_base] = count_th;
  452. }
  453. __threadfence();
  454. __syncthreads(); // ensuring writes to staging_ is visible to all blocks
  455. __shared__ bool is_last_block_done;
  456. // mark block done
  457. if (threadIdx.x == 0 && threadIdx.y == 0) {
  458. int old = atomicAdd(&semaphores[blockIdx.x], 1);
  459. is_last_block_done = (old == (gridDim.y-1));
  460. }
  461. __syncthreads();
  462. // check that all data is now available in global memory
  463. if (is_last_block_done) {
  464. count_th = 0;
  465. mean_th = accscalar_t(0.0);
  466. m2_th = accscalar_t(0.0);
  467. for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
  468. address_base = c_offset + y * stride;
  469. int num_new = c_offset < stride ? staging_count[address_base] : 0;
  470. accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
  471. accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
  472. welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
  473. }
  474. welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
  475. if (threadIdx.y == 0 && c_offset < stride) {
  476. out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
  477. out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
  478. }
  479. }
  480. } else {
  481. if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
  482. out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
  483. out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
  484. }
  485. }
  486. }
  487. // parallel welford kernel to further reduce mean / biased_var
  488. // into mean / unbiased_var / inv_std across multiple processes.
  489. template <typename scalar_t>
  490. __global__ void welford_kernel_parallel(
  491. const scalar_t* __restrict__ mean,
  492. const scalar_t* __restrict__ var_biased,
  493. const int* __restrict__ numel,
  494. scalar_t* __restrict__ out_mean,
  495. scalar_t* __restrict__ out_var,
  496. scalar_t* __restrict__ inv_std,
  497. const int world_size,
  498. const int feature_size,
  499. const float eps) {
  500. for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
  501. // load data;
  502. int address = i;
  503. scalar_t x_mean = 0;
  504. scalar_t m_2_n = 0;
  505. int count = 0;
  506. for (int j = 0; j < world_size; j++) {
  507. welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
  508. address += feature_size;
  509. }
  510. out_mean[i] = x_mean;
  511. out_var[i] = m_2_n/ (count - 1);
  512. inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
  513. }
  514. }
  515. // elementwise BN kernel
  516. template <
  517. typename scalar_t,
  518. typename accscalar_t,
  519. typename layerscalar_t,
  520. int PARALLEL_LOADS>
  521. __global__ void batchnorm_forward_c_last_kernel(
  522. const scalar_t* __restrict__ input,
  523. const scalar_t* __restrict__ z,
  524. const accscalar_t* __restrict__ mean,
  525. const accscalar_t* __restrict__ inv_std,
  526. const layerscalar_t* __restrict__ weight,
  527. const layerscalar_t* __restrict__ shift,
  528. scalar_t* __restrict__ out,
  529. const int reduction_size,
  530. const int stride,
  531. const bool fuse_relu) {
  532. // tensor dimension (m,c)
  533. // loop along m dimension
  534. int inner_loop_stride = blockDim.y * gridDim.y;
  535. // offset along m dimension
  536. int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  537. int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  538. auto m_c = mean[c_offset];
  539. auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  540. auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  541. auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
  542. int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  543. int address_base = m_offset * stride + c_offset;
  544. int address_increment = inner_loop_stride * stride;
  545. for (int i = 0; i < loop_count; i++) {
  546. #pragma unroll
  547. for (int j = 0; j < PARALLEL_LOADS; j++) {
  548. if (c_offset < stride && m_offset < reduction_size) {
  549. auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
  550. if (z != NULL) {
  551. tmp += z[address_base];
  552. }
  553. out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
  554. }
  555. m_offset += inner_loop_stride;
  556. address_base += address_increment;
  557. }
  558. }
  559. }
  560. // elementwise BN kernel
  561. template <
  562. typename scalar_t,
  563. typename accscalar_t,
  564. typename layerscalar_t,
  565. int PARALLEL_LOADS>
  566. __global__ void relu_backward_c_last_kernel(
  567. const scalar_t* __restrict__ grad_output,
  568. const scalar_t* __restrict__ input,
  569. const scalar_t* __restrict__ z,
  570. const accscalar_t* __restrict__ mean,
  571. const accscalar_t* __restrict__ inv_std,
  572. const layerscalar_t* __restrict__ weight,
  573. const layerscalar_t* __restrict__ shift,
  574. scalar_t* __restrict__ out,
  575. const int reduction_size,
  576. const int stride) {
  577. // tensor dimension (m,c)
  578. // loop along m dimension
  579. int inner_loop_stride = blockDim.y * gridDim.y;
  580. // offset along m dimension
  581. int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  582. int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  583. auto m_c = mean[c_offset];
  584. auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  585. auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  586. auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
  587. int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  588. int address_base = m_offset * stride + c_offset;
  589. int address_increment = inner_loop_stride * stride;
  590. for (int i = 0; i < loop_count; i++) {
  591. #pragma unroll
  592. for (int j = 0; j < PARALLEL_LOADS; j++) {
  593. if (c_offset < stride && m_offset < reduction_size) {
  594. auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
  595. if (z != NULL) {
  596. tmp += z[address_base];
  597. }
  598. out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
  599. }
  600. m_offset += inner_loop_stride;
  601. address_base += address_increment;
  602. }
  603. }
  604. }
  605. // batchnorm backward kernel for c last tensor
  606. template
  607. <typename scalar_t,
  608. typename accscalar_t,
  609. typename layerscalar_t,
  610. int PARALLEL_LOADS>
  611. __global__ void reduce_bn_c_last_kernel(
  612. const scalar_t* __restrict__ input,
  613. const scalar_t* __restrict__ grad_output,
  614. const accscalar_t* __restrict__ mean,
  615. const accscalar_t* __restrict__ inv_std,
  616. accscalar_t* __restrict__ sum_dy_o,
  617. accscalar_t* __restrict__ sum_dy_xmu_o,
  618. layerscalar_t* __restrict__ grad_weight,
  619. layerscalar_t* __restrict__ grad_bias,
  620. volatile accscalar_t* staging_data,
  621. int* semaphores,
  622. const int reduction_size,
  623. const int stride) {
  624. // hide latency with concurrency
  625. accscalar_t sum_dy[PARALLEL_LOADS];
  626. accscalar_t sum_dy_xmu[PARALLEL_LOADS];
  627. #pragma unroll
  628. for (int i = 0; i < PARALLEL_LOADS; i++) {
  629. sum_dy[i] = accscalar_t(0);
  630. sum_dy_xmu[i] = accscalar_t(0);
  631. }
  632. // tensor dimension (m,c)
  633. // loop along m dimension
  634. int inner_loop_stride = blockDim.y * gridDim.y;
  635. // offset along m dimension
  636. int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  637. int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  638. int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  639. int address_base = m_offset * stride + c_offset;
  640. int address_increment = inner_loop_stride * stride;
  641. auto r_mean = mean[c_offset];
  642. auto factor = inv_std[c_offset];
  643. for (int i = 0; i < loop_count; i++) {
  644. accscalar_t x_input[PARALLEL_LOADS];
  645. accscalar_t x_grad_output[PARALLEL_LOADS];
  646. // load multiple data in
  647. #pragma unroll
  648. for (int j = 0; j < PARALLEL_LOADS; j++) {
  649. if (c_offset < stride && m_offset < reduction_size) {
  650. x_input[j] = input[address_base];
  651. x_grad_output[j] = grad_output[address_base];
  652. } else {
  653. x_input[j] = accscalar_t(0);
  654. x_grad_output[j] = accscalar_t(0);
  655. }
  656. m_offset += inner_loop_stride;
  657. address_base += address_increment;
  658. }
  659. // calculate sum_dy / sum_dy_xmu
  660. #pragma unroll
  661. for (int j = 0; j < PARALLEL_LOADS; j++) {
  662. sum_dy[j] += x_grad_output[j];
  663. sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
  664. }
  665. }
  666. // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
  667. #pragma unroll
  668. for (int j = 1; j < PARALLEL_LOADS; j++) {
  669. sum_dy[0] += sum_dy[j];
  670. sum_dy_xmu[0] += sum_dy_xmu[j];
  671. }
  672. // release array of registers
  673. auto sum_dy_th = sum_dy[0];
  674. auto sum_dy_xmu_th = sum_dy_xmu[0];
  675. // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  676. static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  677. static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
  678. merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
  679. // grid reduction if needed (coop launch used at the first place)
  680. if (gridDim.y > 1) {
  681. volatile accscalar_t* staging_sum_dy = staging_data;
  682. volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
  683. address_base = c_offset + blockIdx.y * stride;
  684. // write data to staging_data;
  685. if (threadIdx.y == 0 && c_offset < stride) {
  686. staging_sum_dy[address_base] = sum_dy_th;
  687. staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
  688. }
  689. __threadfence();
  690. __syncthreads(); // ensuring writes to staging_ is visible to all blocks
  691. __shared__ bool is_last_block_done;
  692. // mark block done
  693. if (threadIdx.x == 0 && threadIdx.y == 0) {
  694. int old = atomicAdd(&semaphores[blockIdx.x], 1);
  695. is_last_block_done = (old == (gridDim.y-1));
  696. }
  697. __syncthreads();
  698. // check that all data is now available in global memory
  699. if (is_last_block_done) {
  700. sum_dy_th = accscalar_t(0.0);
  701. sum_dy_xmu_th = accscalar_t(0.0);
  702. for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
  703. address_base = c_offset + y * stride;
  704. sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
  705. sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
  706. }
  707. merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
  708. if (threadIdx.y == 0 && c_offset < stride) {
  709. if (grad_bias != NULL) {
  710. grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
  711. }
  712. if (grad_weight != NULL) {
  713. grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
  714. }
  715. //mean_dy[c_offset] = sum_dy_th / reduction_size;
  716. //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
  717. sum_dy_o[c_offset] = sum_dy_th;
  718. sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
  719. }
  720. }
  721. } else {
  722. if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
  723. if (grad_bias != NULL) {
  724. grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
  725. }
  726. if (grad_weight != NULL) {
  727. grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
  728. }
  729. //mean_dy[c_offset] = sum_dy_th / reduction_size;
  730. //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
  731. sum_dy_o[c_offset] = sum_dy_th;
  732. sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
  733. }
  734. }
  735. }
  // Elementwise BN backward kernel for channels-last ("c_last") tensors.
  //
  // Views grad_output / input / grad_input as a row-major (m, c) matrix with
  // m = reduction_size rows and c = stride channels, and writes
  //   grad_input = (dy - sum_dy / N - (x - mean) * inv_std^2 * sum_dy_xmu / N)
  //                * w * inv_std
  // where N = numel[0] + ... + numel[world_size - 1] and w = weight[c]
  // (or 1 when weight is NULL).
  //
  // Launch layout: the x dimension (blockIdx.x, threadIdx.x) selects the
  // channel. The y dimension (blockIdx.y, threadIdx.y) selects the first row.
  // Each thread then visits rows spaced blockDim.y * gridDim.y apart,
  // PARALLEL_LOADS of them per outer iteration. No shared memory, no barriers.
  template <
      typename scalar_t,
      typename accscalar_t,
      typename layerscalar_t,
      int PARALLEL_LOADS>
  __global__ void batchnorm_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int reduction_size,
      const int stride) {
    // offset along c (channel) dimension
    const int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
    // offset along m (row) dimension
    int m_offset = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads past the last channel or the last row have nothing to write.
    // Return before the per-channel loads below. Without this guard, a thread
    // with c_offset >= stride reads mean / inv_std / weight / sum_dy /
    // sum_dy_xmu out of bounds whenever the grid over-covers the channels.
    // An early return is safe here because the kernel has no __syncthreads().
    if (c_offset >= stride || m_offset >= reduction_size) {
      return;
    }

    // N: total element count summed over all world_size entries of numel.
    int64_t div = 0;
    for (int64_t i = 0; i < world_size; i++) {
      div += numel[i];
    }

    // rows advanced per inner step
    const int inner_loop_stride = blockDim.y * gridDim.y;

    auto m_c = mean[c_offset];
    auto m_dy_c = sum_dy[c_offset] / div;
    auto factor_1_c = inv_std[c_offset];
    auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
    factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;

    const int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
    // 64-bit element addressing: m_offset * stride can exceed INT_MAX when
    // the tensor holds more than 2^31 - 1 elements.
    int64_t address_base = static_cast<int64_t>(m_offset) * stride + c_offset;
    const int64_t address_increment = static_cast<int64_t>(inner_loop_stride) * stride;
    for (int i = 0; i < loop_count; i++) {
      #pragma unroll
      for (int j = 0; j < PARALLEL_LOADS; j++) {
        // c_offset < stride is guaranteed by the early return above.
        if (m_offset < reduction_size) {
          grad_input[address_base] = static_cast<scalar_t>(
              (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
               (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
              * factor_2_c);
        }
        m_offset += inner_loop_stride;
        address_base += address_increment;
      }
    }
  }
  // Per-feature mean and biased variance of an (N, C, *) tensor, computed by
  // welford_kernel.
  //
  // The function uses three extents:
  //   - batch size: input.size(0)
  //   - feature count: input.size(1)
  //   - per-sample spatial extent: get_tensor_spatial_size(input)
  // Returns {out_mean, out_var_biased}. Each has shape {feature_size} and
  // dtype promote_scalartype(input) (presumably float for half input —
  // confirm against its definition).
  //
  // Launch: grid = feature_size blocks. block_y is sized from the batch and
  // block_x from the spatial extent, with block_x * block_y at most
  // MAX_BLOCK_SIZE.
  std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
    const auto batch_size = input.size(0);
    const auto feature_size = input.size(1);
    auto space_size = get_tensor_spatial_size(input);
    auto scalar_type = promote_scalartype(input);
    at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
    at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));

    // h_last_pow2(0) == 0, so an empty batch would make block_y 0. The
    // MAX_BLOCK_SIZE / block_y below would then be a host integer division by
    // zero. Clamp to >= 1, mirroring the clamp block_x already has.
    int block_y = max(1, min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32)));
    int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
    const dim3 block(block_x, block_y);
    const dim3 grid(feature_size);
    auto stream = at::cuda::getCurrentCUDAStream();
    {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            out_mean.DATA_PTR<accscalar_t>(),
            out_var_biased.DATA_PTR<accscalar_t>(),
            batch_size,
            feature_size,
            space_size);
      );
    }
    return {out_mean, out_var_biased};
  }
  // Elementwise BN forward for an (N, C, *) tensor. batchnorm_forward_kernel
  // writes the result into a freshly allocated tensor shaped like `input`.
  //
  // mean and inv_std are read as the accumulate type of input's dtype.
  // weight and shift are optional; NULL is passed when absent.
  //   - If input is half and weight is float, weight and shift are read as the
  //     accumulate type.
  //   - Otherwise weight must share input's dtype (TORCH_CHECK), and weight and
  //     shift are read as input's dtype.
  //
  // Launch:
  //   - grid = (feature_size, batch groups, spatial groups)
  //   - block = (block_x, block_y), with block_x >= 32 and block_x * block_y
  //     at most MAX_BLOCK_SIZE
  at::Tensor batchnorm_forward_CUDA(
      const at::Tensor input,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight,
      const at::optional<at::Tensor> shift) {
    const auto batch_size = input.size(0);
    const auto feature_size = input.size(1);
    at::Tensor output = at::empty_like(input);
    auto space_size = get_tensor_spatial_size(input);

    // Every launch dimension derives from the power-of-two floor of the
    // extent it covers.
    const int space_pow2 = h_last_pow2(space_size);
    const int batch_pow2 = h_last_pow2(batch_size);
    const int block_x = max(32, min(MAX_BLOCK_SIZE, space_pow2 / 4));
    const int block_y = max(1, min(MAX_BLOCK_SIZE / block_x, batch_pow2 / 4));
    const int grid_z = max(1, min(65535, space_pow2 / 4 / block_x));
    const int batch_group_size = max(1, min(65535, batch_pow2 / block_y));
    const dim3 block(block_x, block_y);
    const dim3 grid(feature_size, batch_group_size, grid_z);
    auto stream = at::cuda::getCurrentCUDAStream();

    const bool mixed_precision =
        input.scalar_type() == at::ScalarType::Half &&
        weight.has_value() &&
        weight.value().scalar_type() == at::ScalarType::Float;

    if (!mixed_precision) {
      // Same-dtype path: parameters share input's element type.
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
            output.DATA_PTR<scalar_t_0>(),
            space_size,
            batch_size);
      );
    } else {
      // Mixed-precision path: half activations, accumulate-type parameters.
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
            output.DATA_PTR<scalar_t_0>(),
            space_size,
            batch_size);
      );
    }
    return output;
  }
  // Backward reduction for BN on an (N, C, *) tensor, computed by
  // reduce_bn_kernel.
  //
  // Returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}:
  //   - sum_dy and sum_dy_xmu have shape {feature_size} and mean's options.
  //   - With weight: grad_weight and grad_bias have shape {feature_size} and
  //     weight's options.
  //   - Without weight: they are {0}-shaped placeholders, and NULL is passed to
  //     the kernel in their place.
  //
  // weight may be float while input is half (mixed precision). Any other
  // combination requires weight to share input's dtype (TORCH_CHECK).
  //
  // Launch: grid = feature_size blocks. block_y is sized from the batch and
  // block_x from the spatial extent, with block_x * block_y at most
  // MAX_BLOCK_SIZE.
  std::vector<at::Tensor> reduce_bn_CUDA(
      const at::Tensor grad_output,
      const at::Tensor input,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight)
  {
    const auto batch_size = input.size(0);
    const auto feature_size = input.size(1);
    at::Tensor sum_dy = at::empty({feature_size}, mean.options());
    at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
    at::Tensor grad_weight;
    at::Tensor grad_bias;
    if (weight.has_value()) {
      grad_weight = at::empty({feature_size}, weight.value().options());
      grad_bias = at::empty({feature_size}, weight.value().options());
    } else {
      grad_weight = at::empty({0}, mean.options());
      grad_bias = at::empty({0}, mean.options());
    }
    auto space_size = get_tensor_spatial_size(input);
    // h_last_pow2(0) == 0, so an empty batch would make block_y 0. The
    // MAX_BLOCK_SIZE / block_y below would then be a host integer division by
    // zero. Clamp to >= 1, mirroring the clamp block_x already has.
    int block_y = max(1, min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32)));
    int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
    const dim3 block(block_x, block_y);
    const dim3 grid(feature_size);
    auto stream = at::cuda::getCurrentCUDAStream();
    if (input.scalar_type() == at::ScalarType::Half
        && weight.has_value() &&
        weight.value().scalar_type() == at::ScalarType::Float) {
      // Mixed precision: half activations, float (accumulate-type) parameters.
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            grad_output.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
            weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
            batch_size,
            feature_size,
            space_size);
      );
    } else {
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            grad_output.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
            weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
            batch_size,
            feature_size,
            space_size);
      );
    }
    return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
  }
  // Elementwise BN backward (grad_input) for an (N, C, *) tensor.
  // batchnorm_backward_kernel writes into a freshly allocated tensor shaped
  // like `input`.
  //
  // Inputs:
  //   - sum_dy and sum_dy_xmu are the per-channel reductions.
  //   - count is read as int, and count.numel() is forwarded as the kernel's
  //     world_size argument.
  //   - mean, inv_std, sum_dy and sum_dy_xmu are read as the accumulate type.
  //   - weight is read as the accumulate type when input is half and weight is
  //     float. Otherwise it must share input's dtype (TORCH_CHECK).
  //
  // Launch: grid = (feature_size, batch groups, spatial groups),
  // block = (block_x, block_y), with block_x >= 32.
  at::Tensor batchnorm_backward_CUDA(
      const at::Tensor grad_output,
      const at::Tensor input,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight,
      const at::Tensor sum_dy,
      const at::Tensor sum_dy_xmu,
      const at::Tensor count) {
    const auto batch_size = input.size(0);
    const auto feature_size = input.size(1);
    at::Tensor result = at::empty_like(input);
    auto space_size = get_tensor_spatial_size(input);

    // Launch geometry from the power-of-two floors of the spatial and batch
    // extents.
    const int space_floor = h_last_pow2(space_size);
    const int batch_floor = h_last_pow2(batch_size);
    const int threads_x = max(32, min(MAX_BLOCK_SIZE, space_floor / 4));
    const int threads_y = max(1, min(MAX_BLOCK_SIZE / threads_x, batch_floor / 4));
    const int spatial_groups = max(1, min(65535, space_floor / 4 / threads_x));
    const int batch_groups = max(1, min(65535, batch_floor / threads_y));
    const dim3 block(threads_x, threads_y);
    const dim3 grid(feature_size, batch_groups, spatial_groups);
    auto stream = at::cuda::getCurrentCUDAStream();

    const bool mixed_precision =
        input.scalar_type() == at::ScalarType::Half &&
        weight.has_value() &&
        weight.value().scalar_type() == at::ScalarType::Float;

    if (!mixed_precision) {
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            count.DATA_PTR<int>(),
            result.DATA_PTR<scalar_t_0>(),
            count.numel(),
            space_size,
            batch_size);
      );
    } else {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            count.DATA_PTR<int>(),
            result.DATA_PTR<scalar_t_0>(),
            count.numel(),
            space_size,
            batch_size);
      );
    }
    return result;
  }
  // Combines per-node statistics into per-feature out_mean, out_var and
  // inv_std via welford_kernel_parallel.
  //
  // Inputs:
  //   - mean_feature_nodes is indexed with dim 0 = world_size (presumably one
  //     row per rank) and dim 1 = feature_size.
  //   - var_biased is expected to share that layout.
  //   - numel is read as int.
  //   - All three inputs are made contiguous before the launch, and eps is
  //     forwarded to the kernel.
  //
  // Returns {out_mean, out_var, inv_std}, each of shape {feature_size} with
  // var_biased's options.
  //
  // Launch: 1D, min(h_last_pow2(feature_size), MAX_BLOCK_SIZE) threads per
  // block and max(1, feature_size / threads) blocks.
  // NOTE(review): the block count rounds down (e.g. 100 features -> one block
  // of 64 threads). Full coverage presumably relies on the kernel looping over
  // features; confirm in welford_kernel_parallel.
  std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                                const at::Tensor var_biased,
                                                const at::Tensor numel,
                                                const float eps) {
    const auto world_size = mean_feature_nodes.size(0);
    const auto feature_size = mean_feature_nodes.size(1);

    // Outputs: all three share var_biased's options.
    at::Tensor out_var = at::empty({feature_size}, var_biased.options());
    at::Tensor inv_std = at::empty_like(out_var);
    at::Tensor out_mean = at::empty_like(out_var);

    // Dense views of the inputs for the kernel's flat indexing.
    at::Tensor means_contig = mean_feature_nodes.contiguous();
    at::Tensor vars_contig = var_biased.contiguous();
    at::Tensor counts_contig = numel.contiguous();

    // TODO(jie): tile this for memory coalescing!
    const int threads = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
    const int blocks = std::max<int>(1, feature_size / threads);
    auto stream = at::cuda::getCurrentCUDAStream();
    {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
        welford_kernel_parallel<scalar_t_0><<<blocks, threads, 0, stream>>>(
            means_contig.DATA_PTR<scalar_t_0>(),
            vars_contig.DATA_PTR<scalar_t_0>(),
            counts_contig.DATA_PTR<int>(),
            out_mean.DATA_PTR<scalar_t_0>(),
            out_var.DATA_PTR<scalar_t_0>(),
            inv_std.DATA_PTR<scalar_t_0>(),
            world_size,
            feature_size,
            eps);
      );
    }
    return {out_mean, out_var, inv_std};
  }
  // Per-channel mean and biased variance of a channels-last tensor, computed
  // by welford_kernel_c_last. The last dimension is the channel dimension, and
  // all leading dimensions are flattened into reduction_size rows.
  //
  // Returns {out_mean, out_var_biased}. Each has shape {C} and dtype
  // promote_scalartype(input).
  //
  // When flexible_launch_configs splits rows across several blocks
  // (grid.y > 1), two buffers are allocated, presumably for the kernel's
  // cross-block combine:
  //   - a staging buffer of 4 * C * grid.y elements (output dtype)
  //   - grid.x zero-initialised int semaphores
  // Otherwise nullptr is passed for both.
  std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
    const auto stride = input.size(input.ndimension() - 1);
    const auto reduction_size = input.numel() / stride;
    const auto out_options = input.options().dtype(promote_scalartype(input));
    at::Tensor out_var_biased = at::empty({stride}, out_options);
    at::Tensor out_mean = at::empty({stride}, out_options);

    dim3 block;
    dim3 grid;
    flexible_launch_configs(reduction_size, stride, block, grid, true);

    const bool needs_staging = grid.y > 1;
    at::Tensor staging_data;
    at::Tensor semaphores;
    if (needs_staging) {
      staging_data = at::empty({4 * stride * grid.y}, out_options);
      semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
    }

    auto stream = at::cuda::getCurrentCUDAStream();
    {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        accscalar_t* staging_data_ptr = needs_staging ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
        int* semaphores_ptr = needs_staging ? semaphores.DATA_PTR<int>() : nullptr;
        welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            out_mean.DATA_PTR<accscalar_t>(),
            out_var_biased.DATA_PTR<accscalar_t>(),
            staging_data_ptr,
            semaphores_ptr,
            reduction_size,
            stride);
      );
    }
    return {out_mean, out_var_biased};
  }
  // Elementwise BN forward for a channels-last tensor (last dimension =
  // channels, leading dimensions flattened into reduction_size rows).
  // batchnorm_forward_c_last_kernel writes into a freshly allocated tensor
  // shaped like `input`.
  //
  // Inputs:
  //   - z is optional and read as input's dtype.
  //   - fuse_relu is forwarded to the kernel unchanged.
  //   - weight and shift are read as the accumulate type when input is half
  //     and weight is float. Otherwise weight must share input's dtype
  //     (TORCH_CHECK).
  //   - Absent optionals are passed as NULL.
  at::Tensor batchnorm_forward_c_last_CUDA(
      const at::Tensor input,
      const at::optional<at::Tensor> z,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight,
      const at::optional<at::Tensor> shift,
      const bool fuse_relu) {
    const auto stride = input.size(input.ndimension() - 1);
    const auto reduction_size = input.numel() / stride;
    at::Tensor output = at::empty_like(input);

    dim3 block;
    dim3 grid;
    flexible_launch_configs(reduction_size, stride, block, grid);
    auto stream = at::cuda::getCurrentCUDAStream();

    const bool mixed_precision =
        input.scalar_type() == at::ScalarType::Half &&
        weight.has_value() &&
        weight.value().scalar_type() == at::ScalarType::Float;

    if (!mixed_precision) {
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
            output.DATA_PTR<scalar_t_0>(),
            reduction_size,
            stride,
            fuse_relu);
      );
    } else {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
            output.DATA_PTR<scalar_t_0>(),
            reduction_size,
            stride,
            fuse_relu);
      );
    }
    return output;
  }
  // Backward reduction for BN on a channels-last tensor (last dimension =
  // channels), computed by reduce_bn_c_last_kernel.
  //
  // Returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}:
  //   - sum_dy and sum_dy_xmu have shape {C} and mean's options.
  //   - With weight: grad_weight and grad_bias have shape {C} and weight's
  //     options.
  //   - Without weight: they are {0}-shaped placeholders, and NULL is passed
  //     to the kernel.
  //
  // When the launch config uses grid.y > 1, the function allocates a staging
  // buffer of 2 * C * grid.y elements (mean's options) and grid.x
  // zero-initialised int semaphores. Otherwise nullptr is passed for both.
  std::vector<at::Tensor> reduce_bn_c_last_CUDA(
      const at::Tensor grad_output,
      const at::Tensor input,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight) {
    const auto stride = input.size(input.ndimension() - 1);
    const auto reduction_size = input.numel() / stride;
    at::Tensor sum_dy = at::empty({stride}, mean.options());
    at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

    const bool has_weight = weight.has_value();
    // An uninitialized at::Tensor cannot be returned, so a weightless call
    // gets {0}-shaped placeholders instead.
    at::Tensor grad_weight = has_weight ? at::empty({stride}, weight.value().options())
                                        : at::empty({0}, mean.options());
    at::Tensor grad_bias = has_weight ? at::empty({stride}, weight.value().options())
                                      : at::empty({0}, mean.options());

    dim3 block;
    dim3 grid;
    flexible_launch_configs(reduction_size, stride, block, grid, true);

    const bool needs_staging = grid.y > 1;
    at::Tensor staging_data;
    at::Tensor semaphores;
    if (needs_staging) {
      staging_data = at::empty({2 * stride * grid.y}, mean.options());
      semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
    }
    auto stream = at::cuda::getCurrentCUDAStream();

    const bool mixed_precision =
        input.scalar_type() == at::ScalarType::Half &&
        has_weight &&
        weight.value().scalar_type() == at::ScalarType::Float;

    if (!mixed_precision) {
      if (has_weight) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        accscalar_t* staging_data_ptr = needs_staging ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
        int* semaphores_ptr = needs_staging ? semaphores.DATA_PTR<int>() : nullptr;
        reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            grad_output.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            has_weight ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
            has_weight ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
            staging_data_ptr,
            semaphores_ptr,
            reduction_size,
            stride);
      );
    } else {
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        accscalar_t* staging_data_ptr = needs_staging ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
        int* semaphores_ptr = needs_staging ? semaphores.DATA_PTR<int>() : nullptr;
        reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            grad_output.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            has_weight ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
            has_weight ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
            staging_data_ptr,
            semaphores_ptr,
            reduction_size,
            stride);
      );
    }
    return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
  }
  // Elementwise BN backward (grad_input) for a channels-last tensor (last
  // dimension = channels, leading dimensions flattened into reduction_size
  // rows). batchnorm_backward_c_last_kernel writes into a freshly allocated
  // tensor shaped like `input`.
  //
  // Inputs:
  //   - sum_dy and sum_dy_xmu are per-channel reductions read as the
  //     accumulate type.
  //   - count is read as int, and count.numel() is forwarded as world_size.
  //   - weight is read as the accumulate type when input is half and weight is
  //     float. Otherwise it must share input's dtype (TORCH_CHECK).
  //
  // The dispatch label is "batchnorm_backward" to match the non-c_last
  // batchnorm_backward_CUDA. It was previously a copy-pasted
  // "batchnorm_forward", which mislabels the macro's diagnostics
  // (presumably the unsupported-dtype error).
  at::Tensor batchnorm_backward_c_last_CUDA(
      const at::Tensor grad_output,
      const at::Tensor input,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight,
      const at::Tensor sum_dy,
      const at::Tensor sum_dy_xmu,
      const at::Tensor count) {
    const auto stride = input.size(input.ndimension()-1);
    const auto reduction_size = input.numel() / stride;
    at::Tensor grad_input = at::empty_like(input);
    dim3 block;
    dim3 grid;
    flexible_launch_configs(reduction_size, stride, block, grid);
    auto stream = at::cuda::getCurrentCUDAStream();
    if (input.scalar_type() == at::ScalarType::Half
        && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
      // Mixed precision: half activations, float (accumulate-type) weight.
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            count.DATA_PTR<int>(),
            grad_input.DATA_PTR<scalar_t_0>(),
            count.numel(),
            reduction_size,
            stride);
      );
    } else {
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
            sum_dy.DATA_PTR<accscalar_t>(),
            sum_dy_xmu.DATA_PTR<accscalar_t>(),
            count.DATA_PTR<int>(),
            grad_input.DATA_PTR<scalar_t_0>(),
            count.numel(),
            reduction_size,
            stride);
      );
    }
    return grad_input;
  }
  // ReLU backward step for the channels-last BN path, computed by
  // relu_backward_c_last_kernel into a freshly allocated tensor shaped like
  // `input`. The kernel receives grad_output, input, the optional z, and the
  // BN parameters (mean, inv_std, weight, shift). Presumably it recomputes the
  // forward output to mask grad_output; confirm in the kernel.
  //
  // Inputs:
  //   - z is read as input's dtype.
  //   - weight and shift are read as the accumulate type when input is half
  //     and weight is float. Otherwise weight must share input's dtype
  //     (TORCH_CHECK).
  //   - Absent optionals are passed as NULL.
  //
  // The dispatch label is "relu_backward". It was previously a copy-pasted
  // "batchnorm_forward", which mislabels the macro's diagnostics
  // (presumably the unsupported-dtype error).
  at::Tensor relu_backward_c_last_CUDA(
      const at::Tensor grad_output,
      const at::Tensor input,
      const at::optional<at::Tensor> z,
      const at::Tensor mean,
      const at::Tensor inv_std,
      const at::optional<at::Tensor> weight,
      const at::optional<at::Tensor> shift) {
    const auto stride = input.size(input.ndimension()-1);
    const auto reduction_size = input.numel() / stride;
    at::Tensor out = at::empty_like(input);
    dim3 block;
    dim3 grid;
    flexible_launch_configs(reduction_size, stride, block, grid);
    auto stream = at::cuda::getCurrentCUDAStream();
    if (input.scalar_type() == at::ScalarType::Half
        && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
      // Mixed precision: half activations, float (accumulate-type) parameters.
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
            out.DATA_PTR<scalar_t_0>(),
            reduction_size,
            stride);
      );
    } else {
      if (weight.has_value()) {
        TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
            "input.scalar_type() is not supported with weight.scalar_type()");
      }
      using namespace at;
      DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
            grad_output.DATA_PTR<scalar_t_0>(),
            input.DATA_PTR<scalar_t_0>(),
            z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
            mean.DATA_PTR<accscalar_t>(),
            inv_std.DATA_PTR<accscalar_t>(),
            weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
            shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
            out.DATA_PTR<scalar_t_0>(),
            reduction_size,
            stride);
      );
    }
    return out;
  }