multi_tensor_axpby_kernel.cu 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #include <ATen/ATen.h>
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. // Another possibility:
  6. // #include <torch/all.h>
  7. #include <assert.h>
  8. #include "type_shim.h"
  9. #include "multi_tensor_apply.cuh"
  10. #define BLOCK_SIZE 512
  11. #define ILP 4
  // Returns true when the address p is a multiple of ILP*sizeof(T) bytes,
  // i.e. when ILP consecutive T values starting at p form one aligned unit of
  // ILP*sizeof(T) bytes (the unit size load_store moves at once).
  template<typename T>
  __device__ __forceinline__ bool is_aligned(T* p)
  {
    const uint64_t address = reinterpret_cast<uint64_t>(p);
    const uint64_t unit_bytes = ILP * sizeof(T);
    return address % unit_bytes == 0;
  }
  // Copies one group of ILP consecutive T values as a single object of
  // ILP*sizeof(T) bytes: group number src_offset of src is written to group
  // number dst_offset of dst. Offsets count ILP-element groups, not single
  // elements. Because the access goes through a type aligned to ILP*alignof(T),
  // both dst and src must be aligned to ILP*alignof(T) bytes.
  template<typename T>
  __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset)
  {
    using LT = typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type;
    LT* dst_units = reinterpret_cast<LT*>(dst);
    LT* src_units = reinterpret_cast<LT*>(src);
    dst_units[dst_offset] = src_units[src_offset];
  }
  // Per-chunk functor for multi_tensor_apply<3>: computes out = a*x + b*y
  // elementwise. Tensor list 0 of the metadata holds x, list 1 holds y and
  // list 2 holds out. Inputs are converted to float for the arithmetic and the
  // float result is assigned into out_t.
  //
  // Launch layout: each thread block handles one (tensor, chunk) pair looked
  // up through tl.block_to_tensor[blockIdx.x] / tl.block_to_chunk[blockIdx.x];
  // the block's threads stride over the chunk by blockDim.x (in units of ILP
  // elements on the vectorized path).
  //
  // Non-finite detection: arg_to_check selects which inputs are inspected
  // (-1: x and y, 0: x only, 1: y only, any other value: none). If an
  // inspected element is inf/nan, *noop_gmem is set to 1. Outputs are written
  // unconditionally, so infs/nans propagate into out.
  template<typename x_t, typename y_t, typename out_t>
  struct AxpbyFunctor
  {
    // Folds the finiteness of one (x, y) element pair into `finite`,
    // inspecting only the operand(s) selected by arg_to_check.
    static __device__ __forceinline__ bool accumulate_finite(
      bool finite,
      x_t xv,
      y_t yv,
      int arg_to_check)
    {
      if(arg_to_check == -1)
        finite = finite && (isfinite(xv) && isfinite(yv));
      if(arg_to_check == 0)
        finite = finite && isfinite(xv);
      if(arg_to_check == 1)
        finite = finite && isfinite(yv);
      return finite;
    }

    __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<3>& tl,
      float a,
      float b,
      int arg_to_check)
    {
      // I'd like this kernel to propagate infs/nans, so it deliberately does
      // not bail out when the flag is already set:
      // if(*noop_gmem == 1)
      //   return;

      int tensor_loc = tl.block_to_tensor[blockIdx.x];
      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      // Move each pointer to the start of this block's chunk. Afterwards n is
      // the number of elements from the chunk start to the end of the tensor,
      // which can exceed chunk_size, so loops bound by both.
      x_t* x = (x_t*)tl.addresses[0][tensor_loc];
      x += chunk_idx*chunk_size;

      y_t* y = (y_t*)tl.addresses[1][tensor_loc];
      y += chunk_idx*chunk_size;

      out_t* out = (out_t*)tl.addresses[2][tensor_loc];
      out += chunk_idx*chunk_size;

      n -= chunk_idx*chunk_size;

      bool finite = true;
      // load_store() accesses these arrays through a type aligned to
      // ILP*alignof(T) bytes, so declare them with that alignment instead of
      // relying on the compiler keeping them in registers.
      alignas(ILP*alignof(x_t)) x_t r_x[ILP];
      alignas(ILP*alignof(y_t)) y_t r_y[ILP];
      alignas(ILP*alignof(out_t)) out_t r_out[ILP];

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
      {
        // Vectorized path: each iteration moves one group of ILP elements.
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_x, x, 0, i_start);
          load_store(r_y, y, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
            finite = accumulate_finite(finite, r_x[ii], r_y[ii], arg_to_check);
          }
          // store
          load_store(out, r_out, i_start, 0);
        }
      }
      else
      {
        // Scalar path: each thread handles ILP elements spaced blockDim.x
        // apart. Non-divergent exit condition for __syncthreads, not necessary here.
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
        {
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            // Out-of-range slots stay 0, which is finite and never stored.
            r_x[ii] = 0;
            r_y[ii] = 0;
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              r_x[ii] = x[i];
              r_y[ii] = y[i];
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
            finite = accumulate_finite(finite, r_x[ii], r_y[ii], arg_to_check);
          }
          // see note in multi_tensor_scale_kernel.cu
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
              out[i] = r_out[ii];
          }
        }
      }

      // Every racing writer stores the same value (1), so the race is benign.
      if(!finite)
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
    }
  };
  // Host entry point: runs AxpbyFunctor through multi_tensor_apply<3> to
  // compute out = a*x + b*y for every (x, y, out) tensor triple.
  //
  // chunk_size:   elements handled per thread block (forwarded to
  //               multi_tensor_apply).
  // noop_flag:    flag tensor; presumably forwarded to the functor as
  //               noop_gmem, which is set to 1 when a checked input is inf/nan.
  // tensor_lists: exactly three non-empty lists {x tensors, y tensors,
  //               out tensors}. Each list is dispatched on the scalar type of
  //               its FIRST tensor only; presumably all tensors in a list
  //               share that dtype - TODO confirm callers guarantee this.
  // a, b:         float coefficients.
  // arg_to_check: -1 checks x and y, 0 checks x, 1 checks y, other values
  //               disable the finiteness check.
  //
  // Throws c10::Error if tensor_lists is malformed; launch failures are
  // reported through AT_CUDA_CHECK(cudaGetLastError()).
  void multi_tensor_axpby_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float a,
    float b,
    int arg_to_check)
  {
    using namespace at;

    // The dispatch below reads tensor_lists[k][0] for k = 0, 1, 2, so reject
    // malformed input up front instead of reading out of bounds.
    TORCH_CHECK(tensor_lists.size() == 3,
                "multi_tensor_axpby_cuda: expected 3 tensor lists (x, y, out), got ",
                tensor_lists.size());
    for(size_t list_idx = 0; list_idx < tensor_lists.size(); list_idx++)
      TORCH_CHECK(!tensor_lists[list_idx].empty(),
                  "multi_tensor_axpby_cuda: tensor_lists[", list_idx, "] is empty");

    // x, y and out are dispatched independently, so the output dtype need not
    // match either input (e.g. half inputs with a float output, or vice versa).
    // If build times suffer, think about where to put this dispatch,
    // and what logic should be moved out of multi_tensor_apply.
    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
        DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
          multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
            a,
            b,
            arg_to_check); )))

    AT_CUDA_CHECK(cudaGetLastError());

    // AT_CUDA_CHECK(cudaDeviceSynchronize());
  }
  138. AT_CUDA_CHECK(cudaGetLastError());
  139. // AT_CUDA_CHECK(cudaDeviceSynchronize());
  140. }