123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- /*
- * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
#include <stdio.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
// Thread-block width for a channel dimension of `feat_dim`:
// 256 threads when feat_dim >= 1024, 128 when 384 <= feat_dim < 1024,
// and 64 for anything smaller (including zero or negative values).
int best_block_dim(int feat_dim){
    if (feat_dim >= 1024) {
        return 256;
    }
    return (feat_dim >= 384) ? 128 : 64;
}
// Element offset, inside a contiguous [B, H, W, C] tensor, of channel 0 of the
// source pixel for row `h`, column `w` of flat window `win` in the partitioned
// [B*nH*nW, window_size, window_size, C] layout.
// The image is rolled first: rolled[r] = in[(r - shift_size) mod H], and the
// same along W (the torch.roll sign convention).
// Windows are numbered row-major within each image, with nW windows per window-row.
// Host-callable so the index math can be checked without a GPU.
static __host__ __device__ __forceinline__ long long roll_and_window_partition_forward_src_offset(
    const int win, const int h, const int w,
    const int H, const int W, const int C,
    const int shift_size, const int window_size, const int nH, const int nW){
    const int b = win / (nH * nW);
    const int win_in_img = win % (nH * nW);
    const int row = win_in_img / nW * window_size + h;  // rolled row in [0, H)
    const int col = win_in_img % nW * window_size + w;  // rolled col in [0, W)
    // Double modulo wraps into [0, n) for any sign or magnitude of shift_size.
    const int src_row = ((row - shift_size) % H + H) % H;
    const int src_col = ((col - shift_size) % W + W) % W;
    return ((static_cast<long long>(b) * H + src_row) * W + src_col) * C;
}

// Fused torch.roll + window partition (forward).
//   input : contiguous [B, H, W, C]
//   output: contiguous [B*nH*nW, window_size, window_size, C]
// Launch: grid = (window_size, window_size, B*nH*nW).
//   - blockIdx.x / blockIdx.y are the column / row inside a window.
//   - blockIdx.z is the flat window index.
//   - The threads of a block stride over the C channels, so any blockDim.x >= 1 works.
// Uses no shared memory. Assumes H == nH*window_size and W == nW*window_size.
// Offsets are 64-bit, so tensors with more than 2^31 elements index correctly.
template <typename T>
__global__ void roll_and_window_partition_forward_cuda_kernel(
    const T* __restrict__ input,
    T* __restrict__ output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Destination pixel: window blockIdx.z, row blockIdx.y, column blockIdx.x.
    const long long dst =
        ((static_cast<long long>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C;
    // The source pixel is the same for every channel, so resolve it once.
    const long long src = roll_and_window_partition_forward_src_offset(
        static_cast<int>(blockIdx.z), static_cast<int>(blockIdx.y), static_cast<int>(blockIdx.x),
        H, W, C, shift_size, window_size, nH, nW);
    // const __restrict__ lets the compiler route these reads through the read-only cache.
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        output[dst + i] = input[src + i];
    }
}
// Element offset, inside a contiguous [B*nH*nW, window_size, window_size, C]
// window-gradient tensor, of channel 0 of the element that the forward roll +
// partition filled from image pixel (b, h, w).
// The forward read rolled pixel r from image pixel (r - shift_size) mod H, so
// image row h sits at rolled row (h + shift_size) mod H (likewise along W).
// Windows are numbered row-major within each image, with nW windows per window-row.
// Host-callable so the index math can be checked without a GPU.
static __host__ __device__ __forceinline__ long long roll_and_window_partition_backward_src_offset(
    const int b, const int h, const int w,
    const int H, const int W, const int C,
    const int shift_size, const int window_size, const int nH, const int nW){
    // Double modulo wraps into [0, n) for any sign or magnitude of shift_size.
    const int row = ((h + shift_size) % H + H) % H;
    const int col = ((w + shift_size) % W + W) % W;
    const long long win =
        static_cast<long long>(b) * nH * nW + (row / window_size) * nW + col / window_size;
    return ((win * window_size + row % window_size) * window_size + col % window_size) * C;
}

// Backward of the fused roll + window partition.
//   grad_in : contiguous [B*nH*nW, window_size, window_size, C]
//   grad_out: contiguous [B, H, W, C]
// Launch: grid = (W, H, B).
//   - blockIdx.x / blockIdx.y / blockIdx.z are the column / row / batch index of
//     the grad_out pixel.
//   - Threads stride over the C channels, so any blockDim.x >= 1 works.
// Uses no shared memory. Assumes H == nH*window_size and W == nW*window_size.
// Offsets are 64-bit, so tensors with more than 2^31 elements index correctly.
template <typename T>
__global__ void roll_and_window_partition_backward_cuda_kernel(
    const T* __restrict__ grad_in,
    T* __restrict__ grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Destination pixel: batch blockIdx.z, row blockIdx.y, column blockIdx.x.
    const long long dst =
        ((static_cast<long long>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C;
    // The source element is the same for every channel, so resolve it once.
    const long long src = roll_and_window_partition_backward_src_offset(
        static_cast<int>(blockIdx.z), static_cast<int>(blockIdx.y), static_cast<int>(blockIdx.x),
        H, W, C, shift_size, window_size, nH, nW);
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        grad_out[dst + i] = grad_in[src + i];
    }
}
// Element offset, inside a contiguous [B*nH*nW, window_size, window_size, C]
// window tensor, of channel 0 of the element that lands at image pixel (b, h, w).
// The windows are first merged back into [B, H, W, C] and the result is rolled:
// out[r] = merged[(r - shift_size) mod H], and the same along W (the torch.roll
// sign convention).
// Windows are numbered row-major within each image, with nW windows per window-row.
// Host-callable so the index math can be checked without a GPU.
static __host__ __device__ __forceinline__ long long window_merge_and_roll_forward_src_offset(
    const int b, const int h, const int w,
    const int H, const int W, const int C,
    const int shift_size, const int window_size, const int nH, const int nW){
    // Double modulo wraps into [0, n) for any sign or magnitude of shift_size.
    const int row = ((h - shift_size) % H + H) % H;
    const int col = ((w - shift_size) % W + W) % W;
    // Consecutive window-rows are nW windows apart. Scaling by nH would only be
    // correct for square images.
    const long long win =
        static_cast<long long>(b) * nH * nW + (row / window_size) * nW + col / window_size;
    return ((win * window_size + row % window_size) * window_size + col % window_size) * C;
}

// Fused window merge + torch.roll (forward).
//   input : contiguous [B*nH*nW, window_size, window_size, C]
//   output: contiguous [B, H, W, C]
// Launch: grid = (W, H, B).
//   - blockIdx.x / blockIdx.y / blockIdx.z are the column / row / batch index of
//     the output pixel.
//   - Threads stride over the C channels, so any blockDim.x >= 1 works.
// Uses no shared memory. Assumes H == nH*window_size and W == nW*window_size.
// Offsets are 64-bit, so tensors with more than 2^31 elements index correctly.
template <typename T>
__global__ void window_merge_and_roll_forward_cuda_kernel(
    const T* __restrict__ input,
    T* __restrict__ output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Destination pixel: batch blockIdx.z, row blockIdx.y, column blockIdx.x.
    const long long dst =
        ((static_cast<long long>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C;
    // The source element is the same for every channel, so resolve it once.
    const long long src = window_merge_and_roll_forward_src_offset(
        static_cast<int>(blockIdx.z), static_cast<int>(blockIdx.y), static_cast<int>(blockIdx.x),
        H, W, C, shift_size, window_size, nH, nW);
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        output[dst + i] = input[src + i];
    }
}
// Element offset, inside a contiguous [B, H, W, C] gradient tensor, of channel 0
// of the pixel that the forward merge + roll wrote from row `h`, column `w` of
// flat window `win`.
// The forward wrote merged row m to output row (m + shift_size) mod H
// (likewise along W).
// Windows are numbered row-major within each image, with nW windows per window-row.
// Host-callable so the index math can be checked without a GPU.
static __host__ __device__ __forceinline__ long long window_merge_and_roll_backward_src_offset(
    const int win, const int h, const int w,
    const int H, const int W, const int C,
    const int shift_size, const int window_size, const int nH, const int nW){
    const int b = win / (nH * nW);
    const int win_in_img = win % (nH * nW);
    const int row = win_in_img / nW * window_size + h;  // merged row in [0, H)
    const int col = win_in_img % nW * window_size + w;  // merged col in [0, W)
    // Double modulo wraps into [0, n) for any sign or magnitude of shift_size.
    const int src_row = ((row + shift_size) % H + H) % H;
    const int src_col = ((col + shift_size) % W + W) % W;
    return ((static_cast<long long>(b) * H + src_row) * W + src_col) * C;
}

// Backward of the fused window merge + roll.
//   grad_in : contiguous [B, H, W, C]
//   grad_out: contiguous [B*nH*nW, window_size, window_size, C]
// Launch: grid = (window_size, window_size, B*nH*nW).
//   - blockIdx.x / blockIdx.y are the column / row inside a window.
//   - blockIdx.z is the flat window index.
//   - Threads stride over the C channels, so any blockDim.x >= 1 works.
// Uses no shared memory. Assumes H == nH*window_size and W == nW*window_size.
// Offsets are 64-bit, so tensors with more than 2^31 elements index correctly.
template <typename T>
__global__ void window_merge_and_roll_backward_cuda_kernel(
    const T* __restrict__ grad_in,
    T* __restrict__ grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // Destination element: window blockIdx.z, row blockIdx.y, column blockIdx.x.
    const long long dst =
        ((static_cast<long long>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C;
    // The source pixel is the same for every channel, so resolve it once.
    const long long src = window_merge_and_roll_backward_src_offset(
        static_cast<int>(blockIdx.z), static_cast<int>(blockIdx.y), static_cast<int>(blockIdx.x),
        H, W, C, shift_size, window_size, nH, nW);
    for (int i = threadIdx.x; i < C; i += blockDim.x) {
        grad_out[dst + i] = grad_in[src + i];
    }
}
// Fused torch.roll + window partition (host entry point).
//   input : [B, H, W, C] CUDA tensor of float16, float32 or float64. A contiguous
//           copy is made if it is not already contiguous.
//   return: [B*nH*nW, window_size, window_size, C] with nH = H / window_size and
//           nW = W / window_size. Same dtype and device as `input`, created with
//           requires_grad(true).
// Pixel (r, c) of the rolled image is input pixel ((r - shift_size) mod H,
// (c - shift_size) mod W), the torch.roll sign convention.
// The kernel is enqueued on the current PyTorch CUDA stream of input's device.
// TORCH_CHECK throws on any of:
//   - a non-CUDA input or inconsistent sizes;
//   - a grid beyond CUDA's 65535 limit on gridDim.y/z;
//   - a failed kernel launch.
at::Tensor roll_and_window_partition_forward_cuda(
    at::Tensor & input,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    TORCH_CHECK(input.is_cuda(),
                "roll_and_window_partition_forward_cuda: input must be a CUDA tensor");
    TORCH_CHECK(B >= 0 && H >= 0 && W >= 0 && C >= 0,
                "roll_and_window_partition_forward_cuda: B, H, W and C must be non-negative");
    TORCH_CHECK(window_size > 0 && H % window_size == 0 && W % window_size == 0,
                "roll_and_window_partition_forward_cuda: H and W must be divisible by window_size");
    const int64_t expected_numel = static_cast<int64_t>(B) * H * W * C;
    TORCH_CHECK(input.numel() == expected_numel,
                "roll_and_window_partition_forward_cuda: input has ", input.numel(),
                " elements, expected B*H*W*C = ", expected_numel);

    // Allocate and launch on input's device, which need not be the current device.
    const c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes a dense row-major [B, H, W, C] buffer.
    const at::Tensor in = input.contiguous();

    const int nH = H / window_size;
    const int nW = W / window_size;
    const int64_t num_windows = static_cast<int64_t>(B) * nH * nW;
    // Same dtype/device as the input; keeps the requires_grad(true) flag on the result.
    at::Tensor output = torch::empty({num_windows, window_size, window_size, C},
                                     in.options().requires_grad(true));
    if (output.numel() == 0) {
        return output;  // nothing to copy; a zero-sized grid is an invalid launch
    }
    TORCH_CHECK(num_windows <= 65535 && window_size <= 65535,
                "roll_and_window_partition_forward_cuda: grid (", window_size, ", ", window_size,
                ", ", num_windows, ") exceeds the 65535 limit on gridDim.y/z");

    const dim3 grid(window_size, window_size, static_cast<unsigned int>(num_windows));
    const dim3 block(best_block_dim(C));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(in.scalar_type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
        roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            in.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    // Launches do not report errors directly; surface bad configurations here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "roll_and_window_partition_forward_cuda_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}
// Backward of roll_and_window_partition_forward_cuda (host entry point).
//   grad_in: [B*nH*nW, window_size, window_size, C] CUDA tensor of float16,
//            float32 or float64. A contiguous copy is made if needed.
//   return : [B, H, W, C] with the same dtype and device as grad_in
//            (requires_grad false).
// grad_out[b, h, w] is taken from the window position that the forward filled
// from input[b, h, w], which undoes both the partition and the roll.
// The kernel is enqueued on the current PyTorch CUDA stream of grad_in's device.
// TORCH_CHECK throws on any of:
//   - a non-CUDA grad_in or inconsistent sizes;
//   - a grid beyond CUDA's 65535 limit on gridDim.y/z;
//   - a failed kernel launch.
at::Tensor roll_and_window_partition_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    TORCH_CHECK(grad_in.is_cuda(),
                "roll_and_window_partition_backward_cuda: grad_in must be a CUDA tensor");
    TORCH_CHECK(B >= 0 && H >= 0 && W >= 0 && C >= 0,
                "roll_and_window_partition_backward_cuda: B, H, W and C must be non-negative");
    TORCH_CHECK(window_size > 0 && H % window_size == 0 && W % window_size == 0,
                "roll_and_window_partition_backward_cuda: H and W must be divisible by window_size");
    // B*nH*nW*window_size*window_size*C == B*H*W*C when the divisibility check holds.
    const int64_t expected_numel = static_cast<int64_t>(B) * H * W * C;
    TORCH_CHECK(grad_in.numel() == expected_numel,
                "roll_and_window_partition_backward_cuda: grad_in has ", grad_in.numel(),
                " elements, expected B*H*W*C = ", expected_numel);

    // Allocate and launch on grad_in's device, which need not be the current device.
    const c10::cuda::CUDAGuard device_guard(grad_in.device());
    // The kernel indexes a dense row-major window buffer.
    const at::Tensor g = grad_in.contiguous();

    const int nH = H / window_size;
    const int nW = W / window_size;
    // Same dtype/device as grad_in; options() leaves requires_grad unset (false).
    at::Tensor grad_out = torch::empty({B, H, W, C}, g.options());
    if (grad_out.numel() == 0) {
        return grad_out;  // nothing to copy; a zero-sized grid is an invalid launch
    }
    TORCH_CHECK(H <= 65535 && B <= 65535,
                "roll_and_window_partition_backward_cuda: grid (", W, ", ", H, ", ", B,
                ") exceeds the 65535 limit on gridDim.y/z");

    const dim3 grid(W, H, B);
    const dim3 block(best_block_dim(C));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.scalar_type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
        roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            g.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    // Launches do not report errors directly; surface bad configurations here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "roll_and_window_partition_backward_cuda_kernel launch failed: ", cudaGetErrorString(err));
    return grad_out;
}
// Fused window merge + torch.roll (host entry point).
//   input : [B*nH*nW, window_size, window_size, C] CUDA tensor of float16,
//           float32 or float64. A contiguous copy is made if needed.
//   return: [B, H, W, C] with the same dtype and device as `input`, created
//           with requires_grad(true).
// The windows are merged row-major (nW windows per window-row) into an H x W
// image. That image is then rolled: out[r] = merged[(r - shift_size) mod H]
// (likewise along W), the torch.roll sign convention.
// The kernel is enqueued on the current PyTorch CUDA stream of input's device.
// TORCH_CHECK throws on any of:
//   - a non-CUDA input or inconsistent sizes;
//   - a grid beyond CUDA's 65535 limit on gridDim.y/z;
//   - a failed kernel launch.
at::Tensor window_merge_and_roll_forward_cuda(
    at::Tensor & input,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    TORCH_CHECK(input.is_cuda(),
                "window_merge_and_roll_forward_cuda: input must be a CUDA tensor");
    TORCH_CHECK(B >= 0 && H >= 0 && W >= 0 && C >= 0,
                "window_merge_and_roll_forward_cuda: B, H, W and C must be non-negative");
    TORCH_CHECK(window_size > 0 && H % window_size == 0 && W % window_size == 0,
                "window_merge_and_roll_forward_cuda: H and W must be divisible by window_size");
    // B*nH*nW*window_size*window_size*C == B*H*W*C when the divisibility check holds.
    const int64_t expected_numel = static_cast<int64_t>(B) * H * W * C;
    TORCH_CHECK(input.numel() == expected_numel,
                "window_merge_and_roll_forward_cuda: input has ", input.numel(),
                " elements, expected B*H*W*C = ", expected_numel);

    // Allocate and launch on input's device, which need not be the current device.
    const c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes a dense row-major window buffer.
    const at::Tensor in = input.contiguous();

    const int nH = H / window_size;
    const int nW = W / window_size;
    // Same dtype/device as the input; keeps the requires_grad(true) flag on the result.
    at::Tensor output = torch::empty({B, H, W, C}, in.options().requires_grad(true));
    if (output.numel() == 0) {
        return output;  // nothing to copy; a zero-sized grid is an invalid launch
    }
    TORCH_CHECK(H <= 65535 && B <= 65535,
                "window_merge_and_roll_forward_cuda: grid (", W, ", ", H, ", ", B,
                ") exceeds the 65535 limit on gridDim.y/z");

    const dim3 grid(W, H, B);
    const dim3 block(best_block_dim(C));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(in.scalar_type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
        window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            in.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    // Launches do not report errors directly; surface bad configurations here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "window_merge_and_roll_forward_cuda_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}
// Backward of window_merge_and_roll_forward_cuda (host entry point).
//   grad_in: [B, H, W, C] CUDA tensor of float16, float32 or float64. A
//            contiguous copy is made if needed.
//   return : [B*nH*nW, window_size, window_size, C] with nH = H / window_size and
//            nW = W / window_size. Same dtype and device as grad_in
//            (requires_grad false).
// Each window element takes the gradient of the output pixel it was rolled to:
// merged row m maps to row (m + shift_size) mod H (likewise along W).
// The kernel is enqueued on the current PyTorch CUDA stream of grad_in's device.
// TORCH_CHECK throws on any of:
//   - a non-CUDA grad_in or inconsistent sizes;
//   - a grid beyond CUDA's 65535 limit on gridDim.y/z;
//   - a failed kernel launch.
at::Tensor window_merge_and_roll_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    TORCH_CHECK(grad_in.is_cuda(),
                "window_merge_and_roll_backward_cuda: grad_in must be a CUDA tensor");
    TORCH_CHECK(B >= 0 && H >= 0 && W >= 0 && C >= 0,
                "window_merge_and_roll_backward_cuda: B, H, W and C must be non-negative");
    TORCH_CHECK(window_size > 0 && H % window_size == 0 && W % window_size == 0,
                "window_merge_and_roll_backward_cuda: H and W must be divisible by window_size");
    const int64_t expected_numel = static_cast<int64_t>(B) * H * W * C;
    TORCH_CHECK(grad_in.numel() == expected_numel,
                "window_merge_and_roll_backward_cuda: grad_in has ", grad_in.numel(),
                " elements, expected B*H*W*C = ", expected_numel);

    // Allocate and launch on grad_in's device, which need not be the current device.
    const c10::cuda::CUDAGuard device_guard(grad_in.device());
    // The kernel indexes a dense row-major [B, H, W, C] buffer.
    const at::Tensor g = grad_in.contiguous();

    const int nH = H / window_size;
    const int nW = W / window_size;
    const int64_t num_windows = static_cast<int64_t>(B) * nH * nW;
    // Same dtype/device as grad_in; options() leaves requires_grad unset (false).
    at::Tensor grad_out = torch::empty({num_windows, window_size, window_size, C}, g.options());
    if (grad_out.numel() == 0) {
        return grad_out;  // nothing to copy; a zero-sized grid is an invalid launch
    }
    TORCH_CHECK(num_windows <= 65535 && window_size <= 65535,
                "window_merge_and_roll_backward_cuda: grid (", window_size, ", ", window_size,
                ", ", num_windows, ") exceeds the 65535 limit on gridDim.y/z");

    const dim3 grid(window_size, window_size, static_cast<unsigned int>(num_windows));
    const dim3 block(best_block_dim(C));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.scalar_type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
        window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            g.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    // Launches do not report errors directly; surface bad configurations here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "window_merge_and_roll_backward_cuda_kernel launch failed: ", cudaGetErrorString(err));
    return grad_out;
}
|