123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
# --------------------------------------------------------
# Fused kernel for window process for SwinTransformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
import swin_window_process
class WindowProcess(torch.autograd.Function):
    """Fused cyclic-shift + window-partition for Swin Transformer.

    Wraps the custom CUDA kernel
    ``swin_window_process.roll_and_window_partition_forward`` (and its
    matching backward), which — judging by its name and the file header —
    fuses the ``torch.roll`` shift and the window partitioning into a
    single kernel launch.

    Forward inputs:
        input:  feature tensor (presumably shaped (B, H, W, C) — the
                kernel's geometry args suggest this; confirm in caller).
        B, H, W, C: integer tensor geometry, passed through to the kernel.
        shift_size: cyclic shift amount.
        window_size: attention window side length.
    """

    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(
            input, B, H, W, C, shift_size, window_size)

        # Only integer geometry is needed for backward; no tensors to save.
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.roll_and_window_partition_backward(
            grad_in, B, H, W, C, shift_size, window_size)
        # Exactly one gradient per forward input: only `input` is
        # differentiable; the six integer arguments get None.
        # (The original returned a spurious 8th None, which PyTorch merely
        # tolerates.)
        return grad_out, None, None, None, None, None, None
class WindowProcessReverse(torch.autograd.Function):
    """Fused window-merge + reverse cyclic-shift for Swin Transformer.

    Inverse of :class:`WindowProcess`: wraps
    ``swin_window_process.window_merge_and_roll_forward`` (and its matching
    backward), which — per its name — merges attention windows back into
    the feature map and undoes the cyclic shift in one fused kernel.

    Forward inputs:
        input:  windowed feature tensor (layout defined by the CUDA
                kernel; confirm against the extension source).
        B, H, W, C: integer tensor geometry, passed through to the kernel.
        shift_size: cyclic shift amount to undo.
        window_size: attention window side length.
    """

    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(
            input, B, H, W, C, shift_size, window_size)

        # Only integer geometry is needed for backward; no tensors to save.
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.window_merge_and_roll_backward(
            grad_in, B, H, W, C, shift_size, window_size)
        # Exactly one gradient per forward input: only `input` is
        # differentiable; the six integer arguments get None.
        # (Dead commented-out scratch code and the spurious 8th None from
        # the original have been removed.)
        return grad_out, None, None, None, None, None, None
|