// multi_tensor_scale_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>
#include <cstdint>      // uint64_t (is_aligned)
#include <type_traits>  // std::aligned_storage (load_store)
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"
  12. #define BLOCK_SIZE 512
  13. #define ILP 4
  14. template<typename T>
  15. __device__ __forceinline__ bool is_aligned(T* p){
  16. return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
  17. }
  18. template<typename T>
  19. __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  20. typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  21. ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
  22. }
  23. template<typename in_t, typename out_t>
  24. struct ScaleFunctor
  25. {
  26. __device__ __forceinline__ void operator()(
  27. int chunk_size,
  28. volatile int* noop_gmem,
  29. TensorListMetadata<2>& tl,
  30. float scale)
  31. {
  32. // I'd like this kernel to propagate infs/nans.
  33. // if(*noop_gmem == 1)
  34. // return;
  35. int tensor_loc = tl.block_to_tensor[blockIdx.x];
  36. int chunk_idx = tl.block_to_chunk[blockIdx.x];
  37. int n = tl.sizes[tensor_loc];
  38. in_t* in = (in_t*)tl.addresses[0][tensor_loc];
  39. in += chunk_idx*chunk_size;
  40. out_t* out = (out_t*)tl.addresses[1][tensor_loc];
  41. out += chunk_idx*chunk_size;
  42. n -= chunk_idx*chunk_size;
  43. bool finite = true;
  44. in_t r_in[ILP];
  45. out_t r_out[ILP];
  46. // to make things simple, we put aligned case in a different code path
  47. if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
  48. {
  49. for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
  50. {
  51. // load
  52. load_store(r_in, in, 0 , i_start);
  53. #pragma unroll
  54. for(int ii = 0; ii < ILP; ii++)
  55. {
  56. r_out[ii] = static_cast<float>(r_in[ii]) * scale;
  57. finite = finite && isfinite(r_in[ii]);
  58. }
  59. // store
  60. load_store(out, r_out, i_start, 0);
  61. }
  62. }
  63. else
  64. {
  65. // Non-divergent exit condition for __syncthreads, not necessary here
  66. for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
  67. {
  68. #pragma unroll
  69. for(int ii = 0; ii < ILP; ii++)
  70. {
  71. r_in[ii] = 0;
  72. int i = i_start + threadIdx.x + ii*blockDim.x;
  73. if(i < n && i < chunk_size)
  74. r_in[ii] = in[i];
  75. }
  76. // note for clarification to future michael:
  77. // From a pure memory dependency perspective, there's likely no point unrolling
  78. // the write loop, since writes just fire off once their LDGs arrive.
  79. // Put another way, the STGs are dependent on the LDGs, but not on each other.
  80. // There is still compute ILP benefit from unrolling the loop though.
  81. #pragma unroll
  82. for(int ii = 0; ii < ILP; ii++)
  83. {
  84. r_out[ii] = static_cast<float>(r_in[ii]) * scale;
  85. finite = finite && isfinite(r_in[ii]);
  86. }
  87. #pragma unroll
  88. for(int ii = 0; ii < ILP; ii++)
  89. {
  90. int i = i_start + threadIdx.x + ii*blockDim.x;
  91. if(i < n && i < chunk_size)
  92. out[i] = r_out[ii];
  93. }
  94. }
  95. }
  96. if(!finite)
  97. *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
  98. }
  99. };
  100. void multi_tensor_scale_cuda(
  101. int chunk_size,
  102. at::Tensor noop_flag,
  103. std::vector<std::vector<at::Tensor>> tensor_lists,
  104. float scale)
  105. {
  106. using namespace at;
  107. // The output (downscaled) type is always float.
  108. // If build times suffer, think about where to put this dispatch,
  109. // and what logic should be moved out of multi_tensor_apply.
  110. DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
  111. DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
  112. multi_tensor_apply<2>(
  113. BLOCK_SIZE,
  114. chunk_size,
  115. noop_flag,
  116. tensor_lists,
  117. ScaleFunctor<scalar_t_0, scalar_t_1>(),
  118. scale); ))
  119. AT_CUDA_CHECK(cudaGetLastError());
  120. // AT_CUDA_CHECK(cudaDeviceSynchronize());
  121. }