import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTestBase:
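    # An all-reduce of a constant tensor across the tensor-parallel group
    # multiplies every element by the group size.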
    def test_reduce(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_parallel_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
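
    # _split_along_last_dim hands each rank its own shard of the last
    # dimension: rank i should get back exactly the i-th slice of the
    # concatenated input.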
    def test_split(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            tensors = [
                torch.randn(10, 1)
                for _ in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
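
    # _gather_along_last_dim is the inverse of the split: every rank
    # contributes its one-element shard and receives the full concatenation
    # [0, 1, ..., world_size - 1] back.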
    def test_gather(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            device = f"cuda:{self.rank}"
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                [rank for rank in range(tensor_model_parallel_world_size)],
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
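
# Concrete test classes: the same cases run once over the NCCL backend and
# once over UCC, with process-group setup supplied by the respective
# distributed test bases.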
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()