  1. """Test for fused softmax functions.
  2. Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
  3. """ # NOQA
  4. import itertools
  5. import torch
  6. from torch.testing._internal import common_utils
  7. from apex.transformer import AttnMaskType
  8. from apex.transformer.functional import FusedScaleMaskSoftmax
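

# Mask function shared by the FusedScaleMaskSoftmax wrappers and the PyTorch
# reference below: positions where ``attention_mask`` is True are filled with
# -10000.0 so they are effectively zeroed out by the softmax.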
def attention_mask_func(attention_scores, attention_mask):
    return attention_scores.masked_fill(attention_mask, -10000.0)
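

# Pure-PyTorch reference softmax used by the generic-kernel tests below.
# Rows where every key position is masked are zeroed out afterwards, so this
# helper expects a non-None ``mask``.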
def forward_torch_softmax(input, mask, scale):
    input = input * scale
    mask_output = attention_mask_func(input, mask) if mask is not None else input
    probs = torch.nn.Softmax(dim=-1)(mask_output)
    all_k_masked = mask.all(axis=-1)
    zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
    probs = probs * zero_attention_mask
    return probs
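

# Autocast tests run in fp16, and additionally in bf16 when the device supports it.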
autocast_dtypes = (
    (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)


class TestFusedScaleMaskSoftmax(common_utils.TestCase):
    def _setup_fused_softmax(
        self,
        input_in_fp16,
        input_in_bf16,
        scale=None,
        softmax_in_fp32=False,
        attn_mask_type=AttnMaskType.padding,
    ):
        fused_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=True,
        )
        torch_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=False,
        )
        return fused_fn, torch_fn

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def test_fused_scale_mask_softmax(self):
        """
        attention_scores.shape = [4, 12, 24, 24]
        mask.shape = [4, 1, 24, 24]
        """
        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
            ((4, 12, 24, 24), (32, 12, 4, 214)),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.padding,
                    )
                return
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.padding,
            )

            attention_scores_0 = (
                torch.randn(shape)
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
            mask_shape = (shape[0],) + (1,) + shape[2:]
            mask = torch.randint(0, 2, mask_shape, device="cuda").bool()

            expected = fused_fn(attention_scores_0, mask)
            actual = torch_fn(attention_scores_1, mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_autocast_fused_scale_mask_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
            )

            attention_scores_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = (
                    attention_scores_0.clone().to(dtype).requires_grad_(True)
                )
            mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()

            expected = torch_fn(attention_scores_1, mask)
            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attention_scores_0, mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_fused_scale_softmax(self):
        """
        attention_scores.shape = [4, 12, 24, 24]
        mask = None
        """
        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
            ((4, 12, 24, 24), (32, 12, 4, 214)),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.padding,
                    )
                return
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.padding,
            )

            attention_scores_0 = (
                torch.randn(shape)
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
            mask = None

            expected = fused_fn(attention_scores_0, mask)
            actual = torch_fn(attention_scores_1, mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_autocast_fused_scale_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
            )

            attention_scores_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attention_scores_1 = (
                    attention_scores_0.clone().to(dtype).requires_grad_(True)
                )
            mask = None

            expected = torch_fn(attention_scores_1, mask)
            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attention_scores_0, mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.rand_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            expected.backward(g0)
            actual.backward(g1)

    def test_fused_upper_triangle_mask_softmax(self):
        """
        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]
        total_mask[0, 0] is a 24x24 matrix whose upper-triangular elements are
        True and whose lower-triangular elements and diagonal are False.
        """
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            (torch.half, torch.bfloat16), (None, 2.0), (False, True),
        ):
            msg = f"{dtype}-{scale}-{softmax_in_fp32}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            if not (scale is None or softmax_in_fp32):
                with self.assertRaises(RuntimeError, msg=msg):
                    self._setup_fused_softmax(
                        input_in_fp16,
                        input_in_bf16,
                        scale,
                        softmax_in_fp32,
                        AttnMaskType.causal,
                    )
                return
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16,
                input_in_bf16,
                scale,
                softmax_in_fp32,
                AttnMaskType.causal,
            )

            attn_weights_0 = (
                torch.randn((4, 12, 24, 24))
                .to(device="cuda", dtype=dtype)
                .requires_grad_(True)
            )
            with torch.no_grad():
                attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
            total_mask = (
                ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                .unsqueeze(0)
                .unsqueeze(0)
            )
            total_mask = total_mask.repeat((4, 1, 1, 1))

            expected = fused_fn(attn_weights_0, total_mask)
            actual = torch_fn(attn_weights_1, total_mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.randn_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            actual.backward(g0)
            expected.backward(g1)

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        for dtype in autocast_dtypes:
            msg = f"dtype: {dtype}"
            input_in_fp16 = dtype == torch.half
            input_in_bf16 = dtype == torch.bfloat16
            fused_fn, torch_fn = self._setup_fused_softmax(
                input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
            )

            attn_weights_0 = (
                torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
            )
            with torch.no_grad():
                attn_weights_1 = (
                    attn_weights_0.clone().to(dtype).requires_grad_(True)
                )
            total_mask = (
                ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                .unsqueeze(0)
                .unsqueeze(0)
            )

            with torch.cuda.amp.autocast(dtype=dtype):
                actual = fused_fn(attn_weights_0, total_mask)
                self.assertEqual(actual.dtype, dtype, msg=msg)
            expected = torch_fn(attn_weights_1, total_mask)
            self.assertEqual(actual, expected, msg=msg)

            g0 = torch.randn_like(actual)
            with torch.no_grad():
                g1 = g0.clone()
            actual.backward(g0)
            expected.backward(g1)
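

# The tests below call the generic_scaled_masked_softmax_cuda extension
# directly and compare it against forward_torch_softmax over a range of
# query/key sequence lengths.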
class TestGenericFusedSoftmaxKernel(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        self.batch = 2
        self.attn = 16
        self.scale_t = 1.0
        self.dtype = torch.float16
        self.device = torch.cuda.current_device()
        self.thresh = {"atol": 1e-3, "rtol": 1e-3}

        qlen = [1, 2]
        klen = [1, 2, 3, 4, 5, 8, 10, 11, 13, 128, 256, 1200, 1234]
        available_cuda_mem = torch.cuda.memory.mem_get_info(self.device)[0] / (1024 ** 3)
        if available_cuda_mem > 40:
            qlen.extend([1234, 2322, 2348])
            klen.extend([2048, 3123, 4096, 4128, 7234, 8192])
        self.q_k_lens = itertools.product(qlen, klen)

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()
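
    # Compare the extension's forward kernel against the PyTorch reference at
    # the tolerances configured in setUp.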
    def test_forward(self, allmasked: bool = False):
        import generic_scaled_masked_softmax_cuda

        for qlen, klen in self.q_k_lens:
            inputs = torch.normal(0, 2, (self.batch, self.attn, qlen, klen), dtype=self.dtype, device=self.device)
            masks = (
                torch.randint(0, 2, (self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
                if not allmasked
                else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
            )
            softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, masks, self.scale_t)
            softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t)
            self.assertEqual(
                softmax_results_torch.to(self.dtype), softmax_results, **self.thresh, msg=f"(q, k) = ({qlen}, {klen})"
            )

    def test_backward(self, allmasked: bool = False):
        import generic_scaled_masked_softmax_cuda

        prev_thresh = self.thresh
        self.thresh = {"atol": 1.5e-1, "rtol": 5e-3}
        for qlen, klen in self.q_k_lens:
            inputs = torch.normal(0, 2, (self.batch, self.attn, qlen, klen), dtype=self.dtype, device=self.device)
            backward = torch.rand_like(inputs, dtype=torch.float16, device=self.device)
            masks = (
                torch.randint(0, 2, (self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
                if not allmasked
                else torch.ones((self.batch, 1, qlen, klen), dtype=torch.bool, device=self.device)
            )
            softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, masks, self.scale_t)
            back_grad = generic_scaled_masked_softmax_cuda.backward(backward, softmax_results, self.scale_t)

            inputs.requires_grad = True
            softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t)
            softmax_results_torch.backward(backward)
            self.assertEqual(back_grad, inputs.grad, **self.thresh, msg=f"(q, k) = ({qlen}, {klen})")
        self.thresh = prev_thresh
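
    # Re-run the forward and backward checks with every key position masked out.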
    def test_allmasked(self):
        self.test_forward(True)

    def test_allmask_backward(self):
        self.test_backward(True)


if __name__ == "__main__":
    common_utils.run_tests()