test_microbatches.py

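"""Tests for the apex.transformer microbatch calculator.

Checks the reported micro batch size, number of microbatches, and current global
batch size across data-parallel sizes, both with a constant schedule and with a
batch-size ramp-up schedule, under the NCCL and UCC distributed backends.
"""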
import logging
from typing import List, Optional

from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
    _reconfigure_microbatch_calculator,
    get_micro_batch_size,
    get_num_microbatches,
    get_current_global_batch_size,
    update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class MicrobatchCalculatorTestBase:

    GLOBAL_BATCH_SIZE: int = 1024
    MICRO_BATCH_SIZE: int = 1

    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        for data_parallel_size in range(1, self.world_size + 1):

            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
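            # `rampup_batch_size` is unpacked as [initial global batch size,
            # batch-size increment step, threshold]; the threshold is presumably
            # the number of consumed samples over which the ramp-up happens.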
            if rampup_batch_size:
                expected_global_batch_size = rampup_batch_size[0]
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]

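            # Only exercise data-parallel sizes that evenly divide the world size;
            # sizes larger than 1 must also be even.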
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            msg = f"data_parallel_size: {data_parallel_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size(), msg=msg)

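            # Rebuild the global microbatch calculator for the current rank,
            # batch sizes, and data-parallel size, as a training script would
            # do at setup time.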
            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )

            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size, msg=msg)
            self.assertEqual(get_num_microbatches(), expected_global_batch_size / expected_micro_batch_size / data_parallel_size, msg=msg)
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size, msg=msg)

            # Make sure `global_batch_size` equals the final global batch size after
            # a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for i in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE, msg=msg)
            parallel_state.destroy_model_parallel()

    def test_constant_microbatch_calculator(self):
        self._test(rampup_batch_size=None)
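
    # Ramp up from a global batch size of 256 in steps of 128, with a threshold of
    # 500, toward the final GLOBAL_BATCH_SIZE of 1024 (see the unpacking in `_test`).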
    def test_dynamic_microbatch_calculator(self):
        self._test(rampup_batch_size=[256, 128, 500])
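

# Run the same tests under both the NCCL and UCC process-group backends.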
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass


class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass


if __name__ == "__main__":
    common_utils.run_tests()