# test_mapping.py

"""Distributed tests for the tensor-parallel communication mappings in apex.transformer."""
import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

# Silence verbose framework logging during the distributed runs.
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)


class MappingTestBase:
    def test_reduce(self):
        # _reduce all-reduces (sums) a tensor across the tensor-model-parallel group,
        # so every rank should end up with 50 * group size in each element.
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            # Only exercise group sizes that evenly divide the world size.
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size
            )
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_paralell_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_split(self):
        # _split_along_last_dim chunks the last dimension and keeps the slice that
        # belongs to this rank's position in the tensor-model-parallel group.
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size
            )
            tensors = [
                torch.randn(10, 1)
                for _ in range(tensor_model_paralell_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_gather(self):
        # _gather_along_last_dim concatenates each rank's tensor along the last
        # dimension, so gathering the rank index yields [0, 1, ..., group size - 1].
        for tensor_model_paralell_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_paralell_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_paralell_world_size
            )
            device = f"cuda:{self.rank}"
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                [rank for rank in range(tensor_model_paralell_world_size)],
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}",
            )
            parallel_state.destroy_model_parallel()


# Run the shared test cases against both the NCCL and UCC distributed backends.
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()