import logging
import unittest
import typing

import torch
import torch.nn as nn
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase


logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)


# N.B.(mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that TF32 matmul
# can be less precise, causing `self.assertEqual` to fail.
torch.backends.cuda.matmul.allow_tf32 = False


class TensorParallelLayerTestBase:

    BATCH_SIZE: int = 8
    SEQUENCE_LENGTH: int = 128
    VOCAB_SIZE: int = 1024
    HIDDEN_SIZE: int = 256
    INPUT_SIZE_COEFF: int = 256
    OUTPUT_SIZE_COEFF: int = 256
    SEED: int = 123456

    @property
    def tensor_shape(self) -> typing.Sequence[int]:
        return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE]

    @torch.no_grad()
    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
    def test_all_gather_parity(self) -> None:
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15")
        from torch.distributed.distributed_c10d import all_gather, _all_gather_base  # NOQA

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
            with torch.no_grad():
                tensor = tensor_model_parallel_rank * torch.ones(
                    self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
                numel = tensor.numel()
                numel_gathered = tensor_model_parallel_world_size * numel
                gathered = torch.empty(
                    torch.Size((numel_gathered,)),
                    device=cur_tensor_model_device,
                    dtype=torch.float32,
                    requires_grad=False,
                )
                chunks = [
                    gathered[i * numel : (i + 1) * numel]
                    for i in range(tensor_model_parallel_world_size)
                ]
                all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
                gathered_for_base = torch.empty(
                    torch.Size((numel_gathered,)),
                    device=cur_tensor_model_device,
                    dtype=torch.float32,
                    requires_grad=False,
                )
                _all_gather_base(
                    gathered_for_base,
                    tensor,
                    group=parallel_state.get_tensor_model_parallel_group(),
                )
                msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
                self.assertEqual(gathered, gathered_for_base, msg=msg)
            parallel_state.destroy_model_parallel()

    @torch.no_grad()
    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
    def test_reduce_scatter_parity(self) -> None:
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15")
        from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base  # NOQA

        for tensor_model_parallel_world_size in range(2, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
            with torch.no_grad():
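                # On every rank, the i-th chunk of the input below is filled with the value i,
                # so both reduce-scatter variants should produce identical outputs.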
                input = torch.cat([
                    i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
                    for i in range(tensor_model_parallel_world_size)
                ])
                input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)]
                output = torch.empty(
                    self.tensor_shape,
                    device=cur_tensor_model_device,
                    dtype=torch.float32,
                    requires_grad=False,
                )
                reduce_scatter(
                    output,
                    input_list,
                    group=parallel_state.get_tensor_model_parallel_group(),
                )
                output_for_base = torch.empty(
                    self.tensor_shape,
                    device=cur_tensor_model_device,
                    dtype=torch.float32,
                    requires_grad=False,
                )
                _reduce_scatter_base(
                    output_for_base,
                    input,
                    group=parallel_state.get_tensor_model_parallel_group(),
                )
                msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
                self.assertEqual(output, output_for_base, msg=msg)
                self.assertEqual(input, torch.cat(input_list), msg=msg)
            parallel_state.destroy_model_parallel()

    def test_parallel_embedding(self) -> None:
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            set_random_seed(self.SEED + 1)
            input_tensor = torch.randint(
                0,
                self.VOCAB_SIZE,
                (
                    self.BATCH_SIZE,
                    self.SEQUENCE_LENGTH,
                ),
                device="cuda",
            )
            loss_weight = torch.randn(
                (
                    self.BATCH_SIZE,
                    self.SEQUENCE_LENGTH,
                    self.HIDDEN_SIZE,
                ),
                device="cuda",
            )

            set_random_seed(self.SEED)
            embedding_torch = nn.Embedding(
                self.VOCAB_SIZE,
                self.HIDDEN_SIZE,
            ).cuda()
            output_torch = embedding_torch(input_tensor)
            loss_torch = torch.mul(output_torch, loss_weight).sum()
            loss_torch.backward()

            # N.B.(mkozuki): With affine weight initialization on GPU,
            # it's super difficult to keep the consistency with nn.Embedding.
            # Thus, turning on `use_cpu_initialization`.
            set_random_seed(self.SEED)
            embedding_vocab_parallel = layers.VocabParallelEmbedding(
                self.VOCAB_SIZE,
                self.HIDDEN_SIZE,
                init_method=nn.init.normal_,
                use_cpu_initialization=True,
            ).cuda()
            output_vocab_parallel = embedding_vocab_parallel(input_tensor)
            loss_vocab_parallel = torch.mul(output_vocab_parallel, loss_weight).sum()
            loss_vocab_parallel.backward()

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            self.assertEqual(output_torch, output_vocab_parallel, msg=msg)
            self.assertEqual(loss_torch, loss_vocab_parallel, msg=msg)

            splitted_weight_torch = torch.split(
                embedding_torch.weight.grad,
                self.VOCAB_SIZE // tensor_model_parallel_world_size,
                0,
            )[parallel_state.get_tensor_model_parallel_rank()]
            self.assertEqual(
                splitted_weight_torch,
                embedding_vocab_parallel.weight.grad,
                msg=msg,
            )

            parallel_state.destroy_model_parallel()

    def _affine_weight_init_test_impl(
        self, init_device: str, is_column_parallel: bool
    ) -> None:
        dim = int(not is_column_parallel)
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            weight_shape = (
                (self.OUTPUT_SIZE_COEFF, input_size)
                if is_column_parallel
                else (output_size, self.INPUT_SIZE_COEFF)
            )
            weight = torch.empty(weight_shape)
            set_random_seed(self.SEED)

            sharding_dim_size = (
                self.OUTPUT_SIZE_COEFF if is_column_parallel else self.INPUT_SIZE_COEFF
            )

            if init_device == "cpu":
                layers._initialize_affine_weight_cpu(
                    weight,
                    output_size,
                    input_size,
                    sharding_dim_size,
                    dim,
                    nn.init.normal_,
                    params_dtype=torch.float32,
                )
            else:
                layers._initialize_affine_weight_gpu(
                    weight, torch.nn.init.normal_, dim
                )
            # Target
            set_random_seed(self.SEED)
            if init_device == "cpu":
                main_weight = torch.empty(output_size, input_size)
                nn.init.normal_(main_weight)
                curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
                    parallel_state.get_tensor_model_parallel_rank()
                ]
            else:
                curr_weight = torch.empty(*weight_shape)
                nn.init.normal_(curr_weight)
            self.assertEqual(
                curr_weight,
                weight,
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}")
            parallel_state.destroy_model_parallel()

    def test_affine_weight_init_column_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)

    def test_affine_weight_init_column_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)

    def test_affine_weight_init_row_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)

    def test_affine_weight_init_row_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)

    def test_row_parallel_linear(self) -> None:
        self._row_parallel_linear_test_impl(False, False, False)

    def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
        self._row_parallel_linear_test_impl(True, False, False)

    def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
        self._row_parallel_linear_test_impl(True, True, False)

    # fails on native ucc and torch ucc: ucc does not support reduce scatter
    @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
    def test_row_parallel_linear_sequence_parallel(self) -> None:
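        """Check RowParallelLinear with `sequence_parallel_enabled=True` against the reference path."""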
        self._row_parallel_linear_test_impl(False, False, True)

    # TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl`.
    # Note that `input_is_parallel` is unique to `RowParallelLinear`, which could make the merge complicated.
    def _row_parallel_linear_test_impl(
        self,
        gradient_accumulation_fusion: bool,
        accumulation_in_fp16: bool,
        sequence_parallel_enabled: bool,
    ) -> None:
        tensor_shape = (
            self.SEQUENCE_LENGTH,
            self.BATCH_SIZE,
            self.HIDDEN_SIZE,
        )
        for tensor_model_parallel_world_size in range(
            1 + int(sequence_parallel_enabled), self.world_size + 1
        ):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )

            set_random_seed(self.SEED)
            linear = layers.RowParallelLinear(
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
                accumulation_in_fp16=accumulation_in_fp16,
                sequence_parallel_enabled=sequence_parallel_enabled,
                # n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True`
                # by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\
                # db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204
                input_is_parallel=True,
            ).cuda()
            if accumulation_in_fp16:
                linear = linear.half()
            # Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled.
            if gradient_accumulation_fusion:
                with torch.no_grad():
                    linear.weight.main_grad = torch.zeros_like(linear.weight)

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"

            with torch.no_grad():
                orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda")
                orig_loss_weight = torch.randn(tensor_shape, device="cuda")
                input_tensor = orig_input_tensor.chunk(
                    chunks=tensor_model_parallel_world_size,
                    dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()].contiguous()
                if sequence_parallel_enabled:
                    loss_weight = orig_loss_weight.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()]
                else:
                    loss_weight = orig_loss_weight
            if accumulation_in_fp16:
                orig_input_tensor = orig_input_tensor.half()
                input_tensor = input_tensor.half()
                loss_weight = loss_weight.half()
            input_tensor.requires_grad_()
            output, _ = linear(input_tensor)
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()
            self.assertIsNotNone(input_tensor.grad, msg=msg)

            ref_linear = nn.Linear(
                in_features=self.HIDDEN_SIZE,
                out_features=self.HIDDEN_SIZE,
                bias=False,
                device="cuda",
            )
            with torch.no_grad():
                dldy = orig_loss_weight.clone()
                x = orig_input_tensor.clone()
                ref_linear.weight.copy_(linear.master_weight)
                if accumulation_in_fp16:
                    ref_linear = ref_linear.half()
            x.requires_grad_()
            expected_output = ref_linear(x)
            expected_loss = torch.mul(expected_output, dldy).sum()
            expected_loss.backward()

            if not accumulation_in_fp16:
                if sequence_parallel_enabled:
                    self.assertEqual(
                        x=output,
                        y=expected_output.chunk(
                            chunks=tensor_model_parallel_world_size,
                            dim=0,
                        )[parallel_state.get_tensor_model_parallel_rank()],
                        msg=msg,
                    )
                else:
                    self.assertEqual(
                        x=output,
                        y=expected_output,
                        msg=msg,
                    )

            grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be enlarged by tensor model parallelism.
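            # Therefore the weight gradient is compared against the reference only in the single-rank case.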
            if tensor_model_parallel_world_size == 1:
                self.assertEqual(
                    x=getattr(linear.weight, grad_attr_name),
                    y=ref_linear.weight.grad.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()],
                    msg=msg,
                )

            parallel_state.destroy_model_parallel()

    def test_column_parallel_linear(self):
        self._column_parallel_linear_test_impl(False, False, False, False)

    def test_column_parallel_linear_async(self):
        self._column_parallel_linear_test_impl(True, False, False, False)

    def test_column_parallel_linear_gradient_accumulation_fusion(self):
        self._column_parallel_linear_test_impl(False, True, False, False)

    def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self):
        self._column_parallel_linear_test_impl(False, True, True, False)

    def test_column_parallel_linear_sequence_parallel(self):
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("Backward's reduce_scatter fails as of 2022/06/15")
        self._column_parallel_linear_test_impl(False, False, False, True)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs")
    def test_column_parallel_linear_exception(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.",
        ):
            self._column_parallel_linear_test_impl(True, False, False, True)

    def _column_parallel_linear_test_impl(
        self,
        async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
        accumulation_in_fp16: bool,
        sequence_parallel_enabled: bool,
    ):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if async_tensor_model_parallel_allreduce and sequence_parallel_enabled:
                if tensor_model_parallel_world_size == 1:
                    continue
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )

            input_tensor_shape = self.tensor_shape
            expected_output_shape = self.tensor_shape
            # When sequence parallelism is enabled, `gather_output` is disabled, i.e.,
            # the matmul output isn't gathered along the feature/hidden (last) dimension.
            if sequence_parallel_enabled:
                expected_output_shape[-1] //= tensor_model_parallel_world_size

            # tensor's shape is [sequence length, batch size, hidden size]
            set_random_seed(self.SEED)
            linear = layers.ColumnParallelLinear(
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
                bias=False,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
                gather_output=not sequence_parallel_enabled,
                no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
                accumulation_in_fp16=accumulation_in_fp16,
                sequence_parallel_enabled=sequence_parallel_enabled,
            ).cuda()
            if accumulation_in_fp16:
                linear = linear.half()

            # Simulate the situation where fusion of weight grad calculation and gradient accumulation happens.
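            # (With fusion enabled, weight gradients are accumulated into `weight.main_grad`,
            # which is the attribute the assertion below reads via `grad_attr_name`.)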
            if gradient_accumulation_fusion:
                with torch.no_grad():
                    linear.weight.main_grad = torch.zeros_like(linear.weight)

            orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True)
            if accumulation_in_fp16:
                orig_input_tensor = orig_input_tensor.half()
            if sequence_parallel_enabled:
                input_tensor = list(
                    orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0)
                )[parallel_state.get_tensor_model_parallel_rank()]
            else:
                input_tensor = orig_input_tensor
            output, _ = linear(input_tensor)

            # The order of dimensions is expected to be (sequence, batch, hidden).
            self.assertEqual(output.shape, expected_output_shape, msg=msg)

            orig_loss_weight = torch.randn(input_tensor_shape, device="cuda")
            if accumulation_in_fp16:
                orig_loss_weight = orig_loss_weight.half()
            if sequence_parallel_enabled:
                loss_weight = orig_loss_weight.chunk(
                    tensor_model_parallel_world_size, dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()]
            else:
                loss_weight = orig_loss_weight
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()

            with torch.no_grad():
                dldy = orig_loss_weight.clone()
                x = orig_input_tensor.clone()
                ref_linear = nn.Linear(
                    in_features=self.HIDDEN_SIZE,
                    out_features=self.HIDDEN_SIZE,
                    bias=False,
                    device="cuda",
                )
                if accumulation_in_fp16:
                    ref_linear = ref_linear.half()
                # NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set.
                ref_linear.weight.copy_(linear.master_weight)

            x.requires_grad_()
            expected_output = ref_linear(x)
            if sequence_parallel_enabled:
                chunk = expected_output.chunk(
                    tensor_model_parallel_world_size, dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()]
                self.assertEqual(
                    x=output,
                    y=chunk,
                    msg=msg,
                )
            else:
                self.assertEqual(
                    x=output,
                    y=expected_output,
                    msg=msg,
                )

            expected_loss = torch.mul(expected_output, dldy).sum()
            expected_loss.backward()

            grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be enlarged by tensor model parallelism.
            if tensor_model_parallel_world_size == 1:
                self.assertEqual(
                    x=getattr(linear.weight, grad_attr_name),
                    y=ref_linear.weight.grad.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()],
                    msg=msg,
                )

            parallel_state.destroy_model_parallel()


class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
    pass


class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()
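
# Usage note (an assumption about the surrounding test harness, not verified here):
# this file is typically launched directly, e.g.
#   python test_layers.py
# with at least 2 GPUs visible; the Nccl/Ucc distributed test base classes are
# expected to handle per-GPU process spawning and backend setup.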