- """Test for fused RoPE functions.
- Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
- """ # NOQA
- import itertools
- import torch
- from torch.testing._internal import common_utils
- from apex.transformer.functional import (
- fused_apply_rotary_pos_emb,
- fused_apply_rotary_pos_emb_cached,
- fused_apply_rotary_pos_emb_thd,
- )


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even].

    Args:
        x (Tensor): Input tensor

    Returns:
        Tensor: Tensor rotated half
    """
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
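
# Example: for x = [x0, x1, x2, x3] along the last dimension, _rotate_half splits it
# into halves [x0, x1] and [x2, x3] and returns [-x2, -x3, x0, x1].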


# Copied from Megatron-Core for testing.
# https://github.com/NVIDIA/Megatron-LM/blob/5f2877d85cb26e47ce6dcdae4b80adf376abf4e8/megatron/core/models/common/embeddings/rotary_pos_embedding.py#L139
def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary positional embedding to input tensor T.

    Check https://kexue.fm/archives/8265 for detailed formulas.

    Args:
        t (Tensor): Input tensor T is of shape [seq_length, ..., dim]
        freqs (Tensor): Rotary positional embedding tensor freqs is of shape [seq_length, ..., dim]

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty, so the rotary pos embedding is applied to all of tensor t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # The first term is the cosine component; the second is the sine component,
    # whose signs are flipped via _rotate_half.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)
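
# In typical RoPE usage freqs repeats each angle theta_i across the two halves of the
# last dimension, so `(t * cos_) + (_rotate_half(t) * sin_)` rotates every pair
# (t[..., i], t[..., i + rot_dim // 2]) by theta_i:
#   out[..., i]                = t[..., i] * cos(theta_i) - t[..., i + rot_dim // 2] * sin(theta_i)
#   out[..., i + rot_dim // 2] = t[..., i] * sin(theta_i) + t[..., i + rot_dim // 2] * cos(theta_i)
# The tests below use random freqs, which exercises exactly the same arithmetic.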


def apply_rotary_pos_emb_thd(
    t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for the `thd` format.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary positional embedding tensor freqs is of shape [max_s, 1, 1, d]

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.cat(
        [
            apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)
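
# In the packed `thd` layout the sequences of a batch are concatenated along the first
# dimension with no padding. For example, cu_seqlens = [0, 3, 8] describes two sequences
# of lengths 3 and 5 packed into 8 tokens; each sequence is split out, RoPE is applied
# from position 0 (freqs[: seqlen]), and the results are concatenated back together.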


class TestFusedRoPE(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        self.batch_size = 2
        self.head_num = 64
        self.seq_length = [2048, 4096]
        self.hidden_size = [128, 256]
        self.rotary_percent = [0.5, 1.0]
        self.dtype = [torch.float32, torch.bfloat16, torch.float16]
        self.transpose = [None, (0, 1), (2, 3)]
        self.transpose_output_memory = [False, True]
        self.loss_func = [self._overlapping_grad, self._non_overlapping_grad]
        self.cached = [False, True]
        self.device = torch.cuda.current_device()

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def _overlapping_grad(self, output) -> torch.Tensor:
        return output.sum() * 2

    def _non_overlapping_grad(self, output) -> torch.Tensor:
        t = torch.ones_like(output)
        return torch.sum(output * t)
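
    # Both loss helpers feed a constant incoming gradient to the backward pass
    # (all 2s from `output.sum() * 2`, all 1s from `sum(output * ones)`), exercising
    # the backward of the fused and unfused paths through two different autograd graphs.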

    def test_forward_backward(self):
        for (
            dtype,
            seq_length,
            hidden_size,
            rotary_percent,
            transpose,
            transpose_output_memory,
            loss_func,
            cached,
        ) in itertools.product(
            self.dtype,
            self.seq_length,
            self.hidden_size,
            self.rotary_percent,
            self.transpose,
            self.transpose_output_memory,
            self.loss_func,
            self.cached,
        ):
            t = torch.rand(
                (seq_length, self.batch_size, self.head_num, hidden_size),
                dtype=dtype,
                device=self.device,
            )
            if transpose:
                t = t.transpose(*transpose).contiguous().transpose(*transpose)
            t.requires_grad = True

            emb = torch.rand(
                (seq_length, 1, 1, int(hidden_size * rotary_percent)),
                dtype=torch.float32,
                device=self.device,
            )

            # unfused
            output_unfused = apply_rotary_pos_emb(t, emb)
            loss_unfused = loss_func(output_unfused)
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

            # fused
            if cached:
                cos, sin = emb.cos(), emb.sin()
                output_fused = fused_apply_rotary_pos_emb_cached(
                    t, cos, sin, transpose_output_memory=transpose_output_memory
                )
            else:
                output_fused = fused_apply_rotary_pos_emb(
                    t, emb, transpose_output_memory=transpose_output_memory
                )
            loss_fused = loss_func(output_fused)
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

            self.assertEqual(
                output_unfused,
                output_fused,
                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}",
            )
            self.assertEqual(
                grad_unfused,
                grad_fused,
                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}",
            )
            assert (
                output_fused.transpose(0, 1).is_contiguous()
                is transpose_output_memory
            )
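            # With transpose_output_memory=True the fused kernel is expected to lay out
            # the output so that dims 0 and 1 (sequence and batch) are swapped in memory;
            # the plain `assert` above checks that memory-layout contract separately from
            # the numerical comparisons.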

    def test_thd_forward_backward(self):
        cu_seqlens = torch.tensor(
            [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
            dtype=torch.int32,
            device=self.device,
        )
        for (
            dtype,
            hidden_size,
            rotary_percent,
            transpose,
            loss_func,
        ) in itertools.product(
            self.dtype,
            self.hidden_size,
            self.rotary_percent,
            [None, [1, 2]],
            self.loss_func,
        ):
            t = torch.rand(
                (cu_seqlens[-1], self.head_num, hidden_size),
                dtype=dtype,
                device=self.device,
            )
            if transpose:
                t = t.transpose(*transpose).contiguous().transpose(*transpose)
            t.requires_grad = True

            emb = torch.rand(
                (cu_seqlens[-1], 1, 1, int(hidden_size * rotary_percent)),
                dtype=torch.float32,
                device=self.device,
            )

            # unfused
            output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
            loss_unfused = loss_func(output_unfused)
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

            # fused
            output_fused = fused_apply_rotary_pos_emb_thd(
                t,
                cu_seqlens,
                emb,
            )
            loss_fused = loss_func(output_fused)
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

            self.assertEqual(
                output_unfused,
                output_fused,
                msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, loss_func={loss_func.__name__}",
            )
            self.assertEqual(
                grad_unfused,
                grad_fused,
                msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, "
                f"{transpose=}, loss_func={loss_func.__name__}",
            )


if __name__ == "__main__":
    common_utils.run_tests()