- """Test for fused softmax functions.
- Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
- """ # NOQA
- import itertools
- import torch
- from torch.testing._internal import common_utils
- from apex.transformer import AttnMaskType
- from apex.transformer.functional import FusedScaleMaskSoftmax
- def attention_mask_func(attention_scores, attention_mask):
- return attention_scores.masked_fill(attention_mask, -10000.0)


def forward_torch_softmax(input, mask, scale):
    input = input * scale
    mask_output = attention_mask_func(input, mask) if mask is not None else input
    probs = torch.nn.Softmax(dim=-1)(mask_output)
    if mask is not None:
        # A row whose keys are all masked would otherwise come out uniform;
        # zero those rows so the all-masked tests below have a well-defined reference.
        all_k_masked = mask.all(dim=-1)
        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
        probs = probs * zero_attention_mask
    return probs
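
# Example use of the reference above (a minimal sketch mirroring the generic-kernel
# tests further down; the shapes are illustrative only):
#   inputs = torch.normal(0, 2, (2, 16, 4, 8), dtype=torch.float16, device="cuda")
#   masks = torch.randint(0, 2, (2, 1, 4, 8), dtype=torch.bool, device="cuda")
#   probs = forward_torch_softmax(inputs, masks, 1.0)  # same shape as `inputs`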

autocast_dtypes = (
    (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)


class TestFusedScaleMaskSoftmax(common_utils.TestCase):
    def _setup_fused_softmax(
        self,
        input_in_fp16,
        input_in_bf16,
        scale=None,
        softmax_in_fp32=False,
        attn_mask_type=AttnMaskType.padding,
    ):
        fused_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=True,
        )
        torch_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=False,
        )
        return fused_fn, torch_fn
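
    # The two callables above differ only in `scaled_masked_softmax_fusion`:
    # `fused_fn` takes the fused CUDA-kernel path, `torch_fn` the eager PyTorch
    # fallback. Every test below follows the same comparison pattern (sketch):
    #   fused_fn, torch_fn = self._setup_fused_softmax(input_in_fp16, input_in_bf16)
    #   expected = fused_fn(attention_scores, mask)
    #   actual = torch_fn(attention_scores, mask)
    #   self.assertEqual(actual, expected)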

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def test_fused_scale_mask_softmax(self):
        """
        attention_scores.shape in {[4, 12, 24, 24], [32, 12, 4, 214]}
        mask.shape = [batch, 1, seq_q, seq_k]
        """
        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
            ((4, 12, 24, 24), (32, 12, 4, 214)),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.padding,
                    )
                # Skip only this invalid combination; a bare `return` here would
                # silently drop all remaining combinations of the product.
                continue
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.padding,
            )
            attention_scores_0 = (
                torch.randn(shape)
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
            mask_shape = (shape[0],) + (1,) + shape[2:]
            mask = torch.randint(0, 2, mask_shape, device="cuda").bool()

            expected = fused_fn(attention_scores_0, mask)
            actual = torch_fn(attention_scores_1, mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_autocast_fused_scale_mask_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
            )
            attention_scores_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = (
                    attention_scores_0.clone().to(dtype).requires_grad_(True)
                )
            mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()

            expected = torch_fn(attention_scores_1, mask)
            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attention_scores_0, mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_fused_scale_softmax(self):
        """
        attention_scores.shape in {[4, 12, 24, 24], [32, 12, 4, 214]}
        mask = None
        """
        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
            ((4, 12, 24, 24), (32, 12, 4, 214)),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.padding,
                    )
                # As above, skip only the invalid combination instead of returning.
                continue
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.padding,
            )
            attention_scores_0 = (
                torch.randn(shape)
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
            mask = None

            expected = fused_fn(attention_scores_0, mask)
            actual = torch_fn(attention_scores_1, mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_autocast_fused_scale_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
            )
            attention_scores_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = (
                    attention_scores_0.clone().to(dtype).requires_grad_(True)
                )
            mask = None

            expected = torch_fn(attention_scores_1, mask)
            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attention_scores_0, mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_fused_upper_triangle_mask_softmax(self):
        """
        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]
        total_mask[0, 0] is a 24x24 causal mask: the strictly upper-triangular
        elements are True (masked out), while the diagonal and lower-triangular
        elements are False (kept).
        """
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.causal,
                    )
                # As above, skip only the invalid combination instead of returning.
                continue
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.causal,
            )
            attn_weights_0 = (
                torch.randn((4, 12, 24, 24))
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
            total_mask = (
                ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                .unsqueeze(0)
                .unsqueeze(0)
            )
            total_mask = total_mask.repeat((4, 1, 1, 1))

            expected = fused_fn(attn_weights_0, total_mask)
            actual = torch_fn(attn_weights_1, total_mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.randn_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            actual.backward(g0)
            expected.backward(g1)

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
            )
            attn_weights_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attn_weights_1 = (
                    attn_weights_0.clone().to(dtype).requires_grad_(True)
                )
            total_mask = (
                ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                .unsqueeze(0)
                .unsqueeze(0)
            )

            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attn_weights_0, total_mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            expected = torch_fn(attn_weights_1, total_mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.randn_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            actual.backward(g0)
            expected.backward(g1)


class TestGenericFusedSoftmaxKernel(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        self.batch = 2
        self.attn = 16
        self.scale_t = 1.0
        self.dtype = torch.float16
        self.device = torch.cuda.current_device()
        self.thresh = {"atol": 1e-3, "rtol": 1e-3}

        qlen = [1, 2]
        klen = [1, 2, 3, 4, 5, 8, 10, 11, 13, 128, 256, 1200, 1234]
        available_cuda_mem = torch.cuda.memory.mem_get_info(self.device)[0] / (1024 ** 3)
        if available_cuda_mem > 40:
            qlen.extend([1234, 2322, 2348])
            klen.extend([2048, 3123, 4096, 4128, 7234, 8192])
        self.q_k_lens = itertools.product(qlen, klen)

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()
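
    # Sketch of the comparison performed by the tests below (assumes apex was built
    # with the generic scaled-masked-softmax CUDA extension so that
    # `generic_scaled_masked_softmax_cuda` is importable):
    #   out_cuda = generic_scaled_masked_softmax_cuda.forward(inputs, masks, scale_t)
    #   out_ref = forward_torch_softmax(inputs, masks, scale_t)
    #   # the two should agree within the atol/rtol thresholds set in setUp()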

    def test_forward(self, allmasked: bool = False):
        import generic_scaled_masked_softmax_cuda

        for qlen, klen in self.q_k_lens:
            inputs = torch.normal(
                0, 2, (self.batch, self.attn, qlen, klen), dtype=self.dtype, device=self.device
            )
            masks = (
                torch.randint(0, 2, (self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
                if not allmasked
                else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
            )
            softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, masks, self.scale_t)
            softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t)
            self.assertEqual(
                softmax_results_torch.to(self.dtype),
                softmax_results,
                **self.thresh,
                msg=f"(q, k) = ({qlen}, {klen})",
            )

    def test_backward(self, allmasked: bool = False):
        import generic_scaled_masked_softmax_cuda

        prev_thresh = self.thresh
        # Use looser tolerances for the backward comparison.
        self.thresh = {"atol": 1.5e-1, "rtol": 5e-3}
        for qlen, klen in self.q_k_lens:
            inputs = torch.normal(
                0, 2, (self.batch, self.attn, qlen, klen), dtype=self.dtype, device=self.device
            )
            backward = torch.rand_like(inputs, dtype=torch.float16, device=self.device)
            masks = (
                torch.randint(0, 2, (self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
                if not allmasked
                else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
            )
            softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, masks, self.scale_t)
            back_grad = generic_scaled_masked_softmax_cuda.backward(backward, softmax_results, self.scale_t)

            inputs.requires_grad = True
            softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t)
            softmax_results_torch.backward(backward)
            self.assertEqual(back_grad, inputs.grad, **self.thresh, msg=f"(q, k) = ({qlen}, {klen})")
        self.thresh = prev_thresh

    def test_allmasked(self):
        self.test_forward(True)

    def test_allmask_backward(self):
        self.test_backward(True)


if __name__ == "__main__":
    common_utils.run_tests()