multi_tensor_apply.cuh 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #include <ATen/ATen.h>
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. #include <c10/cuda/CUDAGuard.h>
  6. #include "compat.h"
  7. #include <assert.h>
  8. // #include <iostream>
  // This header is the one-stop shop for all your multi-tensor apply needs.
  //
  // Per-launch capacity tables, indexed by (depth - 1), where depth is the number
  // of parallel tensor lists handed to multi_tensor_apply. Every launch ships a
  // TensorListMetadata<depth> by value as a kernel argument, so these limits keep
  // that struct inside the kernel-parameter budget.
  // TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson).

  // Maximum number of tensors (per list) described by a single launch.
  constexpr int depth_to_max_tensors[6] = {
    110,  // depth 1
    64,   // depth 2
    48,   // depth 3
    36,   // depth 4
    30,   // depth 5
    24,   // depth 6
  };

  // Maximum number of CUDA blocks (one per chunk) in a single launch.
  constexpr int depth_to_max_blocks[6] = {
    320,  // depth 1
    320,  // depth 2
    320,  // depth 3
    320,  // depth 4
    320,  // depth 5
    320,  // depth 6
  };
  // Description of one kernel launch's worth of work for multi_tensor_apply.
  // It is passed to the kernel by value, so its size counts against the kernel
  // parameter limit (see depth_to_max_tensors / depth_to_max_blocks).
  //
  // n: number of parallel tensor lists ("depth"); must index the depth tables.
  template<int n> struct TensorListMetadata
  {
    static_assert(n >= 1 && n <= int(sizeof(depth_to_max_tensors) / sizeof(depth_to_max_tensors[0])),
                  "TensorListMetadata: depth n has no entry in depth_to_max_tensors");
    static_assert(n >= 1 && n <= int(sizeof(depth_to_max_blocks) / sizeof(depth_to_max_blocks[0])),
                  "TensorListMetadata: depth n has no entry in depth_to_max_blocks");
    // block_to_tensor stores a tensor slot index in an unsigned char, so one
    // launch may describe at most 256 tensors. (Guarded so an out-of-range n only
    // reports the assertion above.)
    static_assert(n < 1 || n > int(sizeof(depth_to_max_tensors) / sizeof(depth_to_max_tensors[0])) ||
                  depth_to_max_tensors[n-1] <= 256,
                  "TensorListMetadata: depth_to_max_tensors entry overflows unsigned char block_to_tensor");

    // addresses[d][i]: data pointer of the i-th tensor slot, taken from list d.
    void* addresses[n][depth_to_max_tensors[n-1]];
    // sizes[i]: element count (numel) of the tensor in slot i.
    int sizes[depth_to_max_tensors[n-1]];
    // block_to_tensor[b]: tensor slot that block b works on.
    unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
    // block_to_chunk[b]: chunk index, within that tensor, that block b works on.
    int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
    // Index, in the caller's tensor lists, of the tensor stored in slot 0.
    int start_tensor_this_launch;
  };
  // Generic device entry point used by multi_tensor_apply.
  //
  // Each launched block simply forwards everything to the user functor, which
  // owns all per-chunk logic. NOTE(review): the functor presumably selects its
  // chunk via tl.block_to_tensor / tl.block_to_chunk indexed by blockIdx.x —
  // confirm against the functor implementations.
  //
  //   chunk_size - elements per chunk, as chosen by the host.
  //   noop_flag  - device int pointer handed through untouched to the functor.
  //   tl         - launch description (a TensorListMetadata<depth>), by value.
  //   callable   - functor invoked as callable(chunk_size, noop_flag, tl, args...).
  //   args       - extra arguments forwarded to the functor.
  template<typename T, typename U, typename... ArgTypes>
  __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable, ArgTypes... args)
  {
    callable(chunk_size, noop_flag, tl, args...);
  }
  // Applies `callable` to every chunk of every tensor in `tensor_lists`, batching
  // as many chunks per kernel launch as TensorListMetadata<depth> can describe.
  //
  // Each tensor is cut into ceil(numel / chunk_size) chunks, and each chunk
  // becomes one CUDA block. Launches go to the current CUDA stream of the first
  // tensor's device and are asynchronous with respect to the host.
  //
  // Parameters:
  //   block_size   - threads per block for every launch.
  //   chunk_size   - elements per chunk; must be > 0.
  //   noop_flag    - tensor whose int data pointer is forwarded to `callable`.
  //                  NOTE(review): its device is not validated here — presumably
  //                  it must live on the same CUDA device as the inputs; confirm.
  //   tensor_lists - exactly `depth` lists of equal, non-zero length. Tensor t must
  //                  have the same numel in every list, and every tensor must be
  //                  contiguous (channels-last variants are also accepted on
  //                  VERSION_GE_1_5) and on the same CUDA device. Each numel must
  //                  fit in an int.
  //   callable     - device functor, called as callable(chunk_size, noop_flag, tl, args...).
  //   args         - extra arguments forwarded to `callable`.
  //
  // Raises (via TORCH_CHECK) when any precondition above is violated, and via
  // AT_CUDA_CHECK when a launch fails.
  //
  // Zero-element tensors still occupy a tensor slot, so that
  // tl.start_tensor_this_launch + slot indexes the caller's lists. They contribute
  // no blocks.
  template<int depth, typename T, typename... ArgTypes>
  void multi_tensor_apply(
    int64_t block_size,
    int64_t chunk_size,
    const at::Tensor& noop_flag,
    const std::vector<std::vector<at::Tensor>>& tensor_lists,
    T callable,
    ArgTypes... args)
  {
    TORCH_CHECK(tensor_lists.size() == static_cast<size_t>(depth), "tensor_lists.size() != depth");
    TORCH_CHECK(chunk_size > 0, "chunk_size must be positive");
    const size_t len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    const auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (size_t l = 0; l < tensor_lists.size(); l++) // Index loop: list 0 is the reference.
    {
      TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
      for (size_t t = 0; t < len0; t++)
      {
        // TODO: Print which tensor fails.
        const at::Tensor& tensor = tensor_lists[l][t];
        bool contiguous_memory = tensor.is_contiguous();
  #ifdef VERSION_GE_1_5
        contiguous_memory = (contiguous_memory || tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d));
  #endif
        TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
        TORCH_CHECK(tensor.device() == ref_device, "A tensor was not on the same device as the first tensor");
        TORCH_CHECK(tensor.numel() == tensor_lists[0][t].numel(), "Size mismatch");
      }
    }

    const int ntensors = static_cast<int>(len0);
    constexpr int max_tensors = depth_to_max_tensors[depth-1];
    constexpr int max_blocks = depth_to_max_blocks[depth-1];

    TensorListMetadata<depth> tl;
    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;   // blocks queued in tl for the next launch
    int loc_tensor_info = 0;  // tensor slots used in tl for the next launch

    // Launches one kernel covering the loc_block_info blocks currently queued in tl.
    auto launch = [&]() {
      multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
        chunk_size,
        noop_flag.DATA_PTR<int>(),
        tl,
        callable,
        args...);
      AT_CUDA_CHECK(cudaGetLastError());
    };

    for (int t = 0; t < ntensors; t++)
    {
      const int64_t numel = tensor_lists[0][t].numel();
      // tl.sizes is int: refuse tensors whose element count would be truncated.
      TORCH_CHECK(numel == static_cast<int64_t>(static_cast<int>(numel)),
                  "A tensor has too many elements (more than INT_MAX) for multi_tensor_apply");

      tl.sizes[loc_tensor_info] = static_cast<int>(numel);
      for (int d = 0; d < depth; d++)
        tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
      loc_tensor_info++;

      const int64_t chunks_this_tensor = (numel + chunk_size - 1) / chunk_size;
      for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
      {
        tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
        tl.block_to_chunk[loc_block_info] = chunk;
        loc_block_info++;

        if (loc_block_info == max_blocks)
        {
          launch();
          loc_block_info = 0;
          if (chunk == chunks_this_tensor - 1)
          {
            // Tensor t is finished; the next launch starts fresh at tensor t + 1.
            loc_tensor_info = 0;
            tl.start_tensor_this_launch = t + 1;
          }
          else
          {
            // Tensor t still has chunks left: carry it over as slot 0.
            tl.sizes[0] = tl.sizes[loc_tensor_info-1];
            for (int d = 0; d < depth; d++)
              tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
            loc_tensor_info = 1;
            tl.start_tensor_this_launch = t;
          }
        }
      }

      // Every chunk of tensor t is queued. If all tensor slots are now used, flush.
      // This check runs even for zero-element tensors, which queue no blocks, so
      // the next tensor never writes past the end of tl.sizes / tl.addresses.
      if (loc_tensor_info == max_tensors)
      {
        if (loc_block_info > 0)
          launch();
        loc_block_info = 0;
        loc_tensor_info = 0;
        tl.start_tensor_this_launch = t + 1;
      }
    }

    // Flush the remainder. This is separate from the loop so that trailing
    // zero-element tensors cannot prevent earlier queued blocks from launching.
    if (loc_block_info > 0)
      launch();
  }