// multi_tensor_l2norm_kernel.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
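
// Illustration (not part of the original source): with ILP = 4 and x_t = float,
// LT is a 16-byte, 16-byte-aligned chunk, so
//   load_store(r_x, x, 0, i_start);
// behaves like ((float4*)r_x)[0] = ((float4*)x)[i_start]; and typically compiles
// to a single 128-bit load (64-bit for half). is_aligned() guards the cast,
// since x must be ILP*sizeof(x_t)-aligned for this to be valid.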
template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // To keep things simple, the aligned case gets its own code path.
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] += next*next;
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] += next*next;
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
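
// Net effect of L2NormFunctor per block: the partial sum of squares over this
// block's chunk,
//   sum_{i in chunk} x_i^2,
// is accumulated into output[blockIdx.x] (with +=, so partials from successive
// multi_tensor_apply launches that reuse the same block index add up) and, when
// per_tensor is set, also stored per chunk in output_per_tensor. The final
// sqrt over the summed partials is taken later in cleanup()/cleanup_v2().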
template<typename x_t>
struct UnscaleL2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    const float* inv_scale,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // To keep things simple, the aligned case gets its own code path.
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]) * (*inv_scale);
          vals[ii] += next*next;
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]) * (*inv_scale);
            vals[ii] += next*next;
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
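
// UnscaleL2NormFunctor is identical to L2NormFunctor except that every element
// is multiplied by *inv_scale before squaring, so the partials correspond to
// ||inv_scale * x||_2^2. Typically (e.g. under loss scaling in mixed-precision
// training) the caller passes inv_scale = 1 / loss_scale, yielding the norm of
// the unscaled gradients without materializing them.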
// Probably better to template, but since we are not likely to support other
// norms, keep a separate functor for the max (L-inf) norm.
template<typename x_t>
struct MaxNormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // To keep things simple, the aligned case gets its own code path.
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
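
// MaxNormFunctor mirrors L2NormFunctor but folds values with fmaxf instead of
// summing squares, so each block produces a per-chunk max-abs (L-inf) partial.
//
// The cleanup kernels below finish the two-stage reduction: block 0 reduces the
// up-to-320 per-block partials in `output` (320 matches the size of the buffer
// allocated by the host functions), and, when per_tensor is set, block b reduces
// its tensor's row of output_per_tensor.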
__global__ void cleanup(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      *ret = sqrt(final);
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    float val = 0;
    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}
__global__ void cleanup_v2(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor,
  int norm_type,
  float alpha,
  float beta)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if(threadIdx.x == 0)
        *ret = alpha * (*ret) + beta * final;
    }
    else {
      float final = reduce_block_into_lanes(vals, val);
      if(threadIdx.x == 0)
        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
    }
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    if (norm_type == 0) {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));

      float final = reduce_block_into_lanes_max_op(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
    }
    else {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val += output_this_tensor[i];

      float final = reduce_block_into_lanes(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
    }
  }
}
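
// Worked instance of the blend performed by cleanup_v2: for the L-2 case
// (norm_type != 0) with alpha = 1 and beta = 1, the per-tensor update is
//   gn_new = sqrt(gn_old^2 + n^2),
// i.e. the old norm and the fresh norm are combined in quadrature (n^2 is the
// freshly reduced sum of squares). For the L-inf case (norm_type == 0) it is
// the linear blend gn_new = alpha * gn_old + beta * n.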
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with
  // persistence logic, but keeping it simple for now.
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
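
// A minimal usage sketch (illustrative, not part of this file), assuming the
// caller holds a vector of same-device, same-dtype gradients `grads`:
//
//   std::vector<std::vector<at::Tensor>> lists{grads};
//   auto noop = at::zeros({1}, grads[0].options().dtype(at::kInt));
//   auto result = multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop, lists,
//                                          /*per_tensor_python=*/true);
//   at::Tensor total_norm = std::get<0>(result);       // scalar: sqrt(sum of all x^2)
//   at::Tensor per_tensor_norms = std::get<1>(result); // one L2 norm per tensor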
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor inv_scale,
  at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      UnscaleL2NormFunctor<scalar_t_0>(),
      inv_scale.DATA_PTR<float>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with
  // persistence logic, but keeping it simple for now.
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
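
// multi_tensor_unscale_l2norm_cuda is identical to multi_tensor_l2norm_cuda
// except that inv_scale (a single-element float tensor, read on-device) is
// applied to every element before squaring. A caller doing loss scaling would
// typically pass inv_scale = 1 / loss_scale so the returned norms refer to the
// unscaled gradients.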
// Compute and update the grad norm.
// Here we use a per-tensor norm, and blend the new norm (n) and the old norm (gn):
//   L-2:   gn = sqrt(a * gn^2 + b * n^2)
//   L-inf: gn = a * gn + b * n
void multi_tensor_norm_out_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor out,
  const float alpha,
  const float beta,
  const int norm_type)
{
  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
  // We don't need the global (all-tensor) norm here, so at::empty is enough.
  auto output = at::empty({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  for(int t = 0; t < ntensors; t++)
  {
    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
    if(max_chunks_this_tensor > max_chunks_per_tensor)
      max_chunks_per_tensor = max_chunks_this_tensor;
  }

  // Although each element is written once and then read, the buffer still needs
  // to be zeroed, since the trailing (padding) elements also participate in the
  // cleanup reduction.
  output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

  if (norm_type == 0) {
    DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        MaxNormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  else {
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        L2NormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with
  // persistence logic, but keeping it simple for now.
  auto ret = at::empty({1}, output.options());

  // Add the following device guard since it sometimes happens that the tensors
  // are on one device while the CUDA stream is on another, which results in an
  // illegal memory access error.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v2<<<ntensors, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    output_per_tensor.DATA_PTR<float>(),
    ret.DATA_PTR<float>(),
    out.DATA_PTR<float>(),
    true,
    max_chunks_per_tensor,
    norm_type,
    alpha,
    beta);

  return;
}
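
// A minimal usage sketch for multi_tensor_norm_out_cuda (illustrative values,
// reusing `lists` and `noop` from the sketch above multi_tensor_l2norm_cuda):
//
//   auto out = at::zeros({(int64_t)lists[0].size()},
//                        lists[0][0].options().dtype(at::kFloat));
//   multi_tensor_norm_out_cuda(/*chunk_size=*/65536, noop, lists, out,
//                              /*alpha=*/0.f, /*beta=*/1.f, /*norm_type=*/2);
//
// With alpha = 0 and beta = 1 the blend reduces to out[i] = fresh L2 norm of
// tensor i; a nonzero alpha mixes in the previous contents of out as described
// in the comment above the function.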