swin_window_process_kernel.cu

/*
 * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <stdio.h>

// Choose the number of threads per block from the channel dimension C:
// 64 threads below 384 channels, 128 below 1024, 256 otherwise. Each thread
// then strides over C inside the kernels below.
int best_block_dim(int feat_dim){
    int best_dim;
    if (feat_dim < 384){
        best_dim = 64;
    }
    else{
        if (feat_dim < 1024){
            best_dim = 128;
        }
        else{
            best_dim = 256;
        }
    }
    return best_dim;
}

template <typename T>
__global__ void roll_and_window_partition_forward_cuda_kernel(
    T* input,
    T* output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Grid: (window_size, window_size, B * nH * nW); one block per output
    // pixel, each thread striding over the C channels.
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        // Flat offset into output[B*nH*nW, window_size, window_size, C].
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i;
        // Corresponding source pixel in input[B, H, W, C]: undo the window
        // partition and apply the cyclic shift.
        int input_offset = blockIdx.z / (nH * nW) * H * W * C +
            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C +
            (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C +
            i;
        output[offset] = (T)(__ldg(input + input_offset));
    }
}
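
// Index mapping of the kernel above, written out for reference: for batch b,
// window (wh, ww) and in-window position (y, x),
//     output[b*nH*nW + wh*nW + ww, y, x, c]
//         = input[b, (wh*window_size + y - shift_size + H) % H,
//                     (ww*window_size + x - shift_size + W) % W, c].
// Worked example (values chosen only for illustration): with H = W = 8,
// window_size = 4 (so nH = nW = 2) and shift_size = 2, the element of window
// (wh = 1, ww = 0) at (y = 3, x = 1) is read from input row
// (4 + 3 - 2 + 8) % 8 = 5 and column (0 + 1 - 2 + 8) % 8 = 7.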

template <typename T>
__global__ void roll_and_window_partition_backward_cuda_kernel(
    T* grad_in,
    T* grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Grid: (W, H, B); one block per pixel of grad_out[B, H, W, C],
    // each thread striding over the C channels.
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        // Flat offset into grad_out[B, H, W, C].
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i;
        // Gather the gradient from the window element it came from in the forward pass.
        int input_offset =
            (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C +
            (blockIdx.y + shift_size + H) % H % window_size * window_size * C +
            (blockIdx.x + shift_size + W) % W % window_size * C +
            i;
        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
    }
}

template <typename T>
__global__ void window_merge_and_roll_forward_cuda_kernel(
    T* input,
    T* output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Grid: (W, H, B); one block per pixel of output[B, H, W, C],
    // each thread striding over the C channels.
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        // Flat offset into output[B, H, W, C].
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i;
        // Source element in input[B*nH*nW, window_size, window_size, C]; windows
        // are laid out row-major as wh * nW + ww, matching the other kernels.
        int input_offset =
            (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nW + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C +
            (blockIdx.y - shift_size + H) % window_size * window_size * C +
            (blockIdx.x - shift_size + W) % window_size * C +
            i;
        output[offset] = (T)(__ldg(input + input_offset));
    }
}

template <typename T>
__global__ void window_merge_and_roll_backward_cuda_kernel(
    T* grad_in,
    T* grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Grid: (window_size, window_size, B * nH * nW); one block per element of
    // grad_out[B*nH*nW, window_size, window_size, C].
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        // Flat offset into grad_out[B*nH*nW, window_size, window_size, C].
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i;
        // Gather the gradient from the corresponding pixel of grad_in[B, H, W, C].
        int input_offset =
            (blockIdx.z / (nH * nW)) * H * W * C +
            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +
            (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +
            i;
        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
    }
}

// input: [B, H, W, C]
// output: [B*nH*nW, window_size, window_size, C]
at::Tensor roll_and_window_partition_forward_cuda(
    at::Tensor & input,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    // One block per output pixel; threads cover the channel dimension.
    dim3 grid(window_size, window_size, B * nH * nW);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    at::Tensor output;
    if (input.scalar_type() == torch::kFloat16){
        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
    }
    else{
        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
        roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return output;
}
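
// Typical host-side call (a sketch; the tensor values are illustrative only):
//   at::Tensor x = torch::randn({B, H, W, C},
//                               torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   at::Tensor windows = roll_and_window_partition_forward_cuda(
//       x, B, H, W, C, shift_size, window_size);
//   // windows: [B * (H/window_size) * (W/window_size), window_size, window_size, C]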

// grad_in: [B*nH*nW, window_size, window_size, C]
// grad_out: [B, H, W, C]
at::Tensor roll_and_window_partition_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    // One block per pixel of the [B, H, W, C] gradient.
    dim3 grid(W, H, B);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    at::Tensor grad_out;
    if (grad_in.scalar_type() == torch::kFloat16){
        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
    }
    else{
        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.scalar_type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
        roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
            grad_in.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return grad_out;
}

// input: [B*nH*nW, window_size, window_size, C]
// output: [B, H, W, C]
at::Tensor window_merge_and_roll_forward_cuda(
    at::Tensor & input,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    // One block per pixel of the merged [B, H, W, C] output.
    dim3 grid(W, H, B);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    // The output tensor is allocated here rather than passed in.
    at::Tensor output;
    if (input.scalar_type() == torch::kFloat16){
        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
    }
    else{
        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
        window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return output;
}

// grad_in: [B, H, W, C]
// grad_out: [B*nH*nW, window_size, window_size, C]
at::Tensor window_merge_and_roll_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    // One block per element of the windowed gradient.
    dim3 grid(window_size, window_size, B * nH * nW);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    at::Tensor grad_out;
    if (grad_in.scalar_type() == torch::kFloat16){
        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
    }
    else{
        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.scalar_type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
        window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
            grad_in.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return grad_out;
}
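
// These launchers are meant to be exposed to Python through a small binding
// translation unit built alongside this file. The sketch below shows one
// plausible pybind11 binding via torch/extension.h; the module and exported
// function names are illustrative assumptions, not necessarily the ones a
// particular build uses.
//
//   // swin_window_process.cpp (hypothetical companion file)
//   #include <torch/extension.h>
//
//   at::Tensor roll_and_window_partition_forward_cuda(at::Tensor&, int, int, int, int, int, int);
//   at::Tensor roll_and_window_partition_backward_cuda(at::Tensor&, int, int, int, int, int, int);
//   at::Tensor window_merge_and_roll_forward_cuda(at::Tensor&, int, int, int, int, int, int);
//   at::Tensor window_merge_and_roll_backward_cuda(at::Tensor&, int, int, int, int, int, int);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward_cuda,
//             "roll + window partition, forward (CUDA)");
//       m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward_cuda,
//             "roll + window partition, backward (CUDA)");
//       m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward_cuda,
//             "window merge + roll, forward (CUDA)");
//       m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward_cuda,
//             "window merge + roll, backward (CUDA)");
//   }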