"""Distributed tests for apex.transformer.tensor_parallel.mappings.

Checks that _reduce, _split_along_last_dim, and _gather_along_last_dim behave
correctly for every tensor-model-parallel size that divides the world size,
over both the NCCL and UCC backends.
"""
import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)


class MappingTestBase:
    def test_reduce(self):
        # Try every tensor-model-parallel size that evenly divides the world size.
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            # All-reduce sums the identical tensor across the tensor-parallel group.
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_parallel_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_split(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            tensors = [
                torch.randn(10, 1) for _ in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, 1)
            # Splitting along the last dim should give each rank back its own chunk.
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_gather(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            device = f"cuda:{self.rank}"
            # Gathering each rank's id along the last dim yields [0, 1, ..., size - 1].
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                list(range(tensor_model_parallel_world_size)),
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()


class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()