# --------------------------------------------------------
# Fused kernel for window process for SwinTransformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import time
import unittest

import torch

import swin_window_process
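

# Autograd wrapper around the fused CUDA kernel that performs Swin's cyclic
# shift (torch.roll) and window partition in a single pass. The shape and
# shift parameters are stashed on ctx so that backward() can call the
# matching fused backward kernel.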
class WindowProcess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size
        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None
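

# Inverse wrapper: fused window merge followed by the reverse cyclic shift,
# again with a matching fused backward kernel.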
class WindowProcessReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size
        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None
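

# Pure-PyTorch reference implementations, as in the official Swin Transformer
# code; the fused kernels are verified against these for exact equality.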
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def pyt_forward(x, shift_size, window_size):
    # x has shape (B, H, W, C)
    # cyclic shift
    if shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, window_size)
    return x_windows


def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):
    # attn_windows has shape (B*nH*nW, window_size, window_size, C)
    shifted_x = window_reverse(attn_windows, window_size, H, W)
    if shift_size > 0:
        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
    else:
        x = shifted_x
    return x


def copy_one_tensor(input, requires_grad=True):
    # Move to CUDA *before* setting requires_grad so the returned tensor is a
    # leaf whose .grad is populated by backward(); calling .cuda() last would
    # return a non-leaf copy with no .grad.
    input1 = input.clone().detach().cuda().requires_grad_(requires_grad)
    return input1
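

# All tests share one problem size, set in setUp(); the spatial dimensions
# (56x56 feature map, 7x7 windows, C=96) match an early Swin Transformer stage.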
class Test_WindowProcess(unittest.TestCase):
    def setUp(self):
        self.B = 192
        self.H = 56
        self.W = 56
        self.C = 96
        self.shift_size = 2
        self.window_size = 7
        self.nH = self.H // self.window_size
        self.nW = self.W // self.window_size

    def test_roll_and_window_partition_forward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        with torch.no_grad():
            # reference implementation
            expected = pyt_forward(input1, self.shift_size, self.window_size)
            # fused kernel; it applies the signed shift directly, so pass
            # -shift_size to mirror torch.roll(..., shifts=(-shift_size, -shift_size))
            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_roll_and_window_partition_backward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        # reference implementation
        expected = pyt_forward(input1, self.shift_size, self.window_size)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        self.assertTrue(torch.equal(expected, fused_output))
        # the backward kernel is the point of this test: the input gradients
        # from both paths must match exactly as well
        self.assertTrue(torch.equal(input1.grad, input2.grad))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_forward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        with torch.no_grad():
            # reference implementation
            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
            # fused kernel
            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_backward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        # the incoming gradient is plain data; it must not require grad itself
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        # reference implementation
        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        self.assertTrue(torch.equal(expected, fused_output))
        # the input gradients from both paths must match exactly as well
        self.assertTrue(torch.equal(input1.grad, input2.grad))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # official SwinTransformer implementation
        def run_pyt(t=1000):
            for _ in range(t):
                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
                expected.backward(d_loss_tensor)

        # fused op
        def run_fusedop(t=1000):
            for _ in range(t):
                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
                fused_output.backward(d_loss_tensor)

        # synchronize around each timed region so time.time() measures the
        # GPU work rather than just kernel-launch overhead
        torch.cuda.synchronize()
        t1 = time.time()
        run_pyt(t=times)
        torch.cuda.synchronize()
        t2 = time.time()
        run_fusedop(t=times)
        torch.cuda.synchronize()
        t3 = time.time()
        self.assertTrue((t3 - t2) < (t2 - t1))

        print('Run {} times'.format(times))
        print('Original time cost: {}'.format(t2 - t1))
        print('Fused op time cost: {}'.format(t3 - t2))
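
    # The fp16 variants below reuse the fp32 test bodies. Exact equality is
    # still a reasonable expectation: both paths only move elements (roll,
    # permute, copy) and perform no arithmetic that could round differently.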
    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_forward(dtype=dtype)

    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_backward(dtype=dtype)

    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_forward(dtype=dtype)

    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_backward(dtype=dtype)

    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
        self.test_forward_backward_speed(dtype=dtype, times=times)
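

# NOTE: running this file requires a CUDA device and the compiled
# swin_window_process extension (built from the kernel sources that
# accompany this test).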
if __name__ == '__main__':
    print('Tests pass only if the two tensors are exactly equal (torch.equal).\n')
    torch.manual_seed(0)
    unittest.main(verbosity=2)