import logging

import torch
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class TransformerUtilsTest(NcclDistributedTestBase):
    def test_split_tensor_along_last_dim(self):
        # Try every tensor-model-parallel size that evenly divides the world size.
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )

            # Splitting a (100, 100, 100) tensor into 10 chunks along the last
            # dimension should yield 10 views, each with a last dimension of 10.
            device = "cpu"
            input_tensor = torch.randn((100, 100, 100), device=device)
            splits = utils.split_tensor_along_last_dim(input_tensor, 10)
            last_dim_shapes = torch.tensor(
                [int(split.size()[-1]) for split in splits]
            )

            self.assertTrue(
                torch.equal(last_dim_shapes, torch.full((10,), 10)),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )

            parallel_state.destroy_model_parallel()


if __name__ == "__main__":
    common_utils.run_tests()