window_process.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # --------------------------------------------------------
  2. # Fused kernel for window process for SwinTransformer
  3. # Copyright (c) 2022 Nvidia
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # --------------------------------------------------------
  6. import torch
  7. import swin_window_process
  8. class WindowProcess(torch.autograd.Function):
  9. @staticmethod
  10. def forward(ctx, input, B, H, W, C, shift_size, window_size):
  11. output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
  12. ctx.B = B
  13. ctx.H = H
  14. ctx.W = W
  15. ctx.C = C
  16. ctx.shift_size = shift_size
  17. ctx.window_size = window_size
  18. return output
  19. @staticmethod
  20. def backward(ctx, grad_in):
  21. B = ctx.B
  22. H = ctx.H
  23. W = ctx.W
  24. C = ctx.C
  25. shift_size = ctx.shift_size
  26. window_size = ctx.window_size
  27. grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
  28. return grad_out, None, None, None, None, None, None, None
  29. class WindowProcessReverse(torch.autograd.Function):
  30. @staticmethod
  31. def forward(ctx, input, B, H, W, C, shift_size, window_size):
  32. output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
  33. ctx.B = B
  34. ctx.H = H
  35. ctx.W = W
  36. ctx.C = C
  37. ctx.shift_size = shift_size
  38. ctx.window_size = window_size
  39. return output
  40. @staticmethod
  41. def backward(ctx, grad_in):
  42. B = ctx.B
  43. H = ctx.H
  44. W = ctx.W
  45. C = ctx.C
  46. shift_size = ctx.shift_size
  47. window_size = ctx.window_size
  48. #grad_out = ctx.saved_tensors[0]
  49. #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()
  50. grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
  51. return grad_out, None, None, None, None, None, None, None