test_random.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import logging
  2. import torch
  3. from torch.testing._internal import common_utils
  4. logging.getLogger("torch").setLevel(logging.WARNING)
  5. from apex.transformer import parallel_state
  6. from apex.transformer import tensor_parallel
  7. from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
  8. from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
  9. logging.getLogger("apex").setLevel(logging.WARNING)
  10. class TransformerRandomTestBase:
  11. def test_set_cuda_rng_state(self):
  12. for tensor_model_parallel_world_size in range(1, self.world_size + 1):
  13. if self.world_size % tensor_model_parallel_world_size:
  14. continue
  15. msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
  16. parallel_state.initialize_model_parallel(
  17. tensor_model_parallel_size_=tensor_model_parallel_world_size
  18. )
  19. size, seed = 123, 1234
  20. torch.cuda.manual_seed(seed)
  21. tensor = torch.cuda.FloatTensor(size)
  22. rng_state = torch.cuda.get_rng_state()
  23. rng_state_clone = rng_state.clone()
  24. for _ in range(5):
  25. torch.randn(size, out=tensor)
  26. result_1 = tensor.clone()
  27. self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
  28. self.assertGreater(
  29. torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0,
  30. msg=msg,
  31. )
  32. new_rng_state = torch.cuda.get_rng_state()
  33. self.assertGreater(new_rng_state.sub(rng_state).max(), 0, msg=msg)
  34. tensor_parallel.random._set_cuda_rng_state(rng_state)
  35. for _ in range(5):
  36. torch.randn(size, out=tensor)
  37. tensor_parallel.random._set_cuda_rng_state(rng_state)
  38. for _ in range(5):
  39. torch.randn(size, out=tensor)
  40. result_2 = tensor.clone()
  41. self.assertEqual(result_2, result_1, msg=msg)
  42. self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
  43. parallel_state.destroy_model_parallel()
  44. def test_cuda_rng_tracker(self):
  45. for tensor_model_parallel_world_size in range(1, self.world_size + 1):
  46. if self.world_size % tensor_model_parallel_world_size:
  47. continue
  48. msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
  49. parallel_state.initialize_model_parallel(
  50. tensor_model_parallel_size_=tensor_model_parallel_world_size
  51. )
  52. seed_1, seed_2, size = 1234, 4321, [12, 21]
  53. tensor = torch.cuda.FloatTensor(size)
  54. torch.cuda.manual_seed(seed_1)
  55. torch.randn(size, out=tensor)
  56. target_11 = tensor.clone()
  57. torch.randn(size, out=tensor)
  58. target_12 = tensor.clone()
  59. torch.cuda.manual_seed(seed_2)
  60. torch.randn(size, out=tensor)
  61. targt_21 = tensor.clone()
  62. torch.randn(size, out=tensor)
  63. target_22 = tensor.clone()
  64. torch.cuda.manual_seed(seed_1)
  65. tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)
  66. torch.randn(size, out=tensor)
  67. result_11 = tensor.clone()
  68. with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
  69. torch.randn(size, out=tensor)
  70. result_21 = tensor.clone()
  71. torch.randn(size, out=tensor)
  72. result_12 = tensor.clone()
  73. with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
  74. torch.randn(size, out=tensor)
  75. result_22 = tensor.clone()
  76. self.assertEqual(target_11, result_11, msg=msg)
  77. self.assertEqual(target_12, result_12, msg=msg)
  78. self.assertEqual(targt_21, result_21, msg=msg)
  79. self.assertEqual(target_22, result_22, msg=msg)
  80. self.assertNotEqual(result_11, result_21, msg=msg)
  81. self.assertNotEqual(result_21, result_22, msg=msg)
  82. tensor_parallel.random.get_cuda_rng_tracker().reset()
  83. parallel_state.destroy_model_parallel()
  84. class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): pass
  85. class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): pass
  86. if __name__ == "__main__":
  87. common_utils.run_tests()