  1. """Test for fused RoPE functions.
  2. Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
  3. """ # NOQA
  4. import itertools
  5. import torch
  6. from torch.testing._internal import common_utils
  7. from apex.transformer.functional import (
  8. fused_apply_rotary_pos_emb,
  9. fused_apply_rotary_pos_emb_cached,
  10. fused_apply_rotary_pos_emb_thd,
  11. )
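
# The `fused_apply_rotary_pos_emb*` functions imported above are the fused CUDA
# implementations under test; the `apply_rotary_pos_emb*` functions defined below are
# unfused reference implementations (copied from Megatron-Core) used as ground truth.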


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]

    Args:
        x (Tensor): Input tensor

    Returns:
        Tensor: Tensor rotated half
    """
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


# Copied from Megatron-Core for testing.
# https://github.com/NVIDIA/Megatron-LM/blob/5f2877d85cb26e47ce6dcdae4b80adf376abf4e8/megatron/core/models/common/embeddings/rotary_pos_embedding.py#L139
def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary positional embedding to input tensor T.

    See https://kexue.fm/archives/8265 for detailed formulas.

    Args:
        t (Tensor): Input tensor T is of shape [seq_length, ..., dim]
        freqs (Tensor): Rotary positional embedding tensor freqs is of shape
            [seq_length, ..., dim]

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty, so the rotary pos embedding is applied to the whole tensor t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # First part is the cosine component; second part is the sine component,
    # which needs the sign change applied by _rotate_half.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)
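
# Worked illustration (comment only), writing D for rot_dim, c_i for cos(freqs[..., i])
# and s_i for sin(freqs[..., i]): _rotate_half pairs element i of the first half with
# element i + D // 2 of the second half, so apply_rotary_pos_emb computes
#     out[..., i]          = t[..., i] * c_i - t[..., i + D // 2] * s_i
#     out[..., i + D // 2] = t[..., i + D // 2] * c_{i + D//2} + t[..., i] * s_{i + D//2}
# In Megatron-Core the two halves of freqs normally carry the same angles, so each pair
# is rotated in its own 2-D plane; the tests below use random freqs, which exercises the
# same arithmetic.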


def apply_rotary_pos_emb_thd(
    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for the `thd` format.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary positional embedding tensor freqs is of shape [max_s, 1, 1, d]

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.cat(
        [
            apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)
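
# Worked illustration (comment only): with cu_seqlens = [0, 3, 7], t of shape [7, h, d]
# packs two sequences of lengths 3 and 4 along dim 0. Each sequence is split out,
# given a dummy batch dim (unsqueeze(1)), rotated with positions restarted from 0
# (freqs[: seq_len]), and the pieces are concatenated back into the packed layout.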


class TestFusedRoPE(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        self.batch_size = 2
        self.head_num = 64
        self.seq_length = [2048, 4096]
        self.hidden_size = [128, 256]
        self.rotary_percent = [0.5, 1.0]
        self.dtype = [torch.float32, torch.bfloat16, torch.float16]
        self.transpose = [None, (0, 1), (2, 3)]
        self.transpose_output_memory = [False, True]
        self.loss_func = [self._overlapping_grad, self._non_overlapping_grad]
        self.cached = [False, True]
        self.device = torch.cuda.current_device()

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def _overlapping_grad(self, output) -> torch.Tensor:
        return output.sum() * 2

    def _non_overlapping_grad(self, output) -> torch.Tensor:
        t = torch.ones_like(output)
        return torch.sum(output * t)
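
    # The two loss functions above drive backward with different incoming gradients:
    # _overlapping_grad yields d(loss)/d(output) == 2 everywhere, while
    # _non_overlapping_grad yields 1 everywhere, so the fused backward pass is checked
    # against more than one grad_output value.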

    def test_forward_backward(self):
        for (
            dtype,
            seq_length,
            hidden_size,
            rotary_percent,
            transpose,
            transpose_output_memory,
            loss_func,
            cached,
        ) in itertools.product(
            self.dtype,
            self.seq_length,
            self.hidden_size,
            self.rotary_percent,
            self.transpose,
            self.transpose_output_memory,
            self.loss_func,
            self.cached,
        ):
            t = torch.rand(
                (seq_length, self.batch_size, self.head_num, hidden_size),
                dtype=dtype,
                device=self.device,
            )
            if transpose:
                t = t.transpose(*transpose).contiguous().transpose(*transpose)
            t.requires_grad = True

            emb = torch.rand(
                (seq_length, 1, 1, int(hidden_size * rotary_percent)),
                dtype=torch.float32,
                device=self.device,
            )

            # unfused
            output_unfused = apply_rotary_pos_emb(t, emb)
            loss_unfused = loss_func(output_unfused)
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

            # fused
            if cached:
                cos, sin = emb.cos(), emb.sin()
                output_fused = fused_apply_rotary_pos_emb_cached(
                    t, cos, sin, transpose_output_memory=transpose_output_memory
                )
            else:
                output_fused = fused_apply_rotary_pos_emb(
                    t, emb, transpose_output_memory=transpose_output_memory
                )
            loss_fused = loss_func(output_fused)
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

            self.assertEqual(
                output_unfused,
                output_fused,
                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}",
            )
            self.assertEqual(
                grad_unfused,
                grad_fused,
                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}",
            )
            assert (
                output_fused.transpose(0, 1).is_contiguous() is transpose_output_memory
            )
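
    # Note: the transpose(...).contiguous().transpose(...) pattern above keeps the
    # logical shape but makes the underlying storage non-contiguous, so the fused
    # kernels are also exercised on non-contiguous inputs. The trailing assert checks
    # that `transpose_output_memory` controls whether the output's dims 0 and 1 are
    # swapped in memory.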

    def test_thd_forward_backward(self):
        cu_seqlens = torch.tensor(
            [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
            dtype=torch.int32,
            device=self.device,
        )
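        # cu_seqlens marks the cumulative token offsets of 11 variable-length
        # sequences packed into a single 2048-token batch along dim 0.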
        for (
            dtype,
            hidden_size,
            rotary_percent,
            transpose,
            loss_func,
        ) in itertools.product(
            self.dtype,
            self.hidden_size,
            self.rotary_percent,
            [None, [1, 2]],
            self.loss_func,
        ):
            t = torch.rand(
                (cu_seqlens[-1], self.head_num, hidden_size),
                dtype=dtype,
                device=self.device,
            )
            if transpose:
                t = t.transpose(*transpose).contiguous().transpose(*transpose)
            t.requires_grad = True

            emb = torch.rand(
                (cu_seqlens[-1], 1, 1, int(hidden_size * rotary_percent)),
                dtype=torch.float32,
                device=self.device,
            )

            # unfused
            output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
            loss_unfused = loss_func(output_unfused)
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

            # fused
            output_fused = fused_apply_rotary_pos_emb_thd(
                t,
                cu_seqlens,
                emb,
            )
            loss_fused = loss_func(output_fused)
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

            self.assertEqual(
                output_unfused,
                output_fused,
                msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, loss_func={loss_func.__name__}",
            )
            self.assertEqual(
                grad_unfused,
                grad_fused,
                msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, loss_func={loss_func.__name__}",
            )


if __name__ == "__main__":
    common_utils.run_tests()