# test_parallel_state.py
  1. import logging
  2. import os
  3. from torch.testing._internal import common_utils
  4. logging.getLogger("torch").setLevel(logging.WARNING)
  5. from apex.transformer import parallel_state
  6. from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
  7. from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
  8. logging.getLogger("apex").setLevel(logging.WARNING)
  9. os.environ["BACKEND"] = "NCCL"
  10. DATA_PARALLEL_WORLD_SIZE: int = 1
  11. def calc_expected_tensor_model_paralell_rank(
  12. rank: int, tensor_model_parallel_world_size: int,
  13. ) -> int:
  14. return rank % tensor_model_parallel_world_size
  15. class ParallelStateTestBase:
  16. def test_initialize_model_parallel(self) -> None:
  17. self.assertFalse(parallel_state.model_parallel_is_initialized())
  18. for tensor_model_parallel_world_size in range(1, self.world_size + 1):
  19. msg = f"tensor_model_parallel_world_siz: {tensor_model_parallel_world_size}"
  20. if self.world_size % tensor_model_parallel_world_size:
  21. continue
  22. pipeline_model_parallel_world_size = (
  23. self.world_size // tensor_model_parallel_world_size
  24. )
  25. parallel_state.initialize_model_parallel(
  26. tensor_model_parallel_size_=tensor_model_parallel_world_size,
  27. pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
  28. )
  29. self.assertEqual(
  30. tensor_model_parallel_world_size,
  31. parallel_state.get_tensor_model_parallel_world_size(),
  32. msg=msg,
  33. )
  34. expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
  35. self.rank, tensor_model_parallel_world_size
  36. )
  37. self.assertEqual(
  38. expected_tensor_model_parallel_rank,
  39. parallel_state.get_tensor_model_parallel_rank(),
  40. msg=msg,
  41. )
  42. expected_tensor_model_parallel_src_rank = (
  43. self.rank // tensor_model_parallel_world_size
  44. ) * tensor_model_parallel_world_size
  45. self.assertEqual(
  46. expected_tensor_model_parallel_src_rank,
  47. parallel_state.get_tensor_model_parallel_src_rank(),
  48. msg=msg,
  49. )
  50. parallel_state.destroy_model_parallel()
  51. self.assertFalse(parallel_state.model_parallel_is_initialized(), msg=msg)
  52. def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
  53. if self.world_size < 4:
  54. self.skipTest("requires >= 4 GPUs")
  55. self.assertFalse(parallel_state.model_parallel_is_initialized())
  56. tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
  57. pipeline_model_parallel_world_size = (
  58. self.world_size // tensor_model_parallel_world_size
  59. )
  60. virtual_pipeline_model_parallel_world_size = 2
  61. pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
  62. parallel_state.initialize_model_parallel(
  63. tensor_model_parallel_size_=tensor_model_parallel_world_size,
  64. pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
  65. virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
  66. pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
  67. )
  68. self.assertEqual(
  69. calc_expected_tensor_model_paralell_rank(
  70. self.rank, tensor_model_parallel_world_size
  71. ),
  72. parallel_state.get_tensor_model_parallel_rank(),
  73. )
  74. self.assertEqual(
  75. pipeline_model_parallel_world_size,
  76. parallel_state.get_pipeline_model_parallel_world_size(),
  77. )
  78. self.assertEqual(
  79. virtual_pipeline_model_parallel_world_size,
  80. parallel_state.get_virtual_pipeline_model_parallel_world_size(),
  81. )
  82. expected_pipeline_rank = (
  83. self.rank - (self.rank % tensor_model_parallel_world_size)
  84. ) % pipeline_model_parallel_world_size
  85. self.assertEqual(
  86. expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank(),
  87. )
  88. # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
  89. # `initialize_model_parallel`, it's set to 0.
  90. self.assertEqual(
  91. 0, parallel_state.get_virtual_pipeline_model_parallel_rank(),
  92. )
  93. self.assertEqual(
  94. pipeline_model_parallel_split_rank,
  95. parallel_state.get_pipeline_model_parallel_split_rank(),
  96. )
  97. fake_split_rank = 77
  98. parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
  99. self.assertEqual(
  100. fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank()
  101. )
  102. # relative position embedding groups check
  103. self.assertEqual(
  104. expected_pipeline_rank < pipeline_model_parallel_split_rank,
  105. parallel_state.is_rank_in_encoder_relative_position_embedding_group(),
  106. )
  107. self.assertEqual(
  108. expected_pipeline_rank >= pipeline_model_parallel_split_rank,
  109. parallel_state.is_rank_in_decoder_relative_position_embedding_group(),
  110. )
  111. parallel_state.destroy_model_parallel()
  112. def test_initialize_model_parallel_decoder_only(self) -> None:
  113. """Initialize model parallelism for decoder-only Transformers like GPT-3"""
  114. self.assertFalse(parallel_state.model_parallel_is_initialized())
  115. for tensor_model_parallel_world_size in range(1, self.world_size + 1):
  116. msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
  117. if self.world_size % tensor_model_parallel_world_size:
  118. continue
  119. pipeline_model_parallel_world_size = (
  120. self.world_size // tensor_model_parallel_world_size
  121. )
  122. parallel_state.initialize_model_parallel(
  123. tensor_model_parallel_size_=tensor_model_parallel_world_size,
  124. pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
  125. pipeline_model_parallel_split_rank_=0,
  126. )
  127. self.assertEqual(
  128. tensor_model_parallel_world_size,
  129. parallel_state.get_tensor_model_parallel_world_size(),
  130. msg=msg,
  131. )
  132. expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank(
  133. self.rank, tensor_model_parallel_world_size
  134. )
  135. self.assertEqual(
  136. expected_tensor_model_parallel_rank,
  137. parallel_state.get_tensor_model_parallel_rank(),
  138. msg=msg,
  139. )
  140. expected_tensor_model_parallel_src_rank = (
  141. self.rank // tensor_model_parallel_world_size
  142. ) * tensor_model_parallel_world_size
  143. self.assertEqual(
  144. expected_tensor_model_parallel_src_rank,
  145. parallel_state.get_tensor_model_parallel_src_rank(),
  146. msg=msg,
  147. )
  148. parallel_state.destroy_model_parallel()
  149. self.assertFalse(parallel_state.model_parallel_is_initialized(), msg=msg)
  150. class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass
  151. class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass
  152. if __name__ == "__main__":
  153. common_utils.run_tests()