# test_transformer_utils.py
  1. import logging
  2. import torch
  3. from torch.testing._internal import common_utils
  4. logging.getLogger("torch").setLevel(logging.WARNING)
  5. from apex.transformer import parallel_state
  6. from apex.transformer.tensor_parallel import utils
  7. from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
  8. logging.getLogger("apex").setLevel(logging.WARNING)
  9. class TransformerUtilsTest(NcclDistributedTestBase):
  10. def test_split_tensor_along_last_dim(self):
  11. for tensor_model_paralell_world_size in range(1, self.world_size + 1):
  12. if self.world_size % tensor_model_paralell_world_size > 0:
  13. continue
  14. parallel_state.initialize_model_parallel(
  15. tensor_model_parallel_size_=tensor_model_paralell_world_size
  16. )
  17. device = "cpu"
  18. input_tensor = torch.randn((100, 100, 100), device=device)
  19. splits = utils.split_tensor_along_last_dim(input_tensor, 10)
  20. last_dim_shapes = torch.tensor(
  21. [int(split.size()[-1]) for split in splits]
  22. )
  23. self.assertTrue(
  24. torch.equal(last_dim_shapes, torch.full((10,), 10),),
  25. msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}",
  26. )
  27. parallel_state.destroy_model_parallel()
  28. if __name__ == "__main__":
  29. common_utils.run_tests()