123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import logging
- import torch
- from torch.testing._internal import common_utils
- logging.getLogger("torch").setLevel(logging.WARNING)
- from apex.transformer import parallel_state
- from apex.transformer import tensor_parallel
- from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
- from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
- logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerRandomTestBase:
    """Backend-agnostic checks for ``apex.transformer.tensor_parallel.random``.

    This class is a mixin: it is combined with a distributed test base (see
    the concrete subclasses in this file), which is expected to supply
    ``self.world_size`` and the ``assert*`` helpers used below.

    Every test iterates over each tensor-model-parallel size that evenly
    divides the world size, initializes model parallelism for that size and
    tears it down again before moving on.
    """

    def test_set_cuda_rng_state(self):
        """Check that restoring a saved CUDA RNG state replays the same numbers.

        A state snapshot is taken right after seeding, random numbers are
        drawn, the snapshot is restored through
        ``tensor_parallel.random._set_cuda_rng_state`` and the same number of
        draws must then reproduce the first result.
        """
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            try:
                size, seed = 123, 1234
                torch.cuda.manual_seed(seed)
                # Explicit factory function instead of the legacy
                # ``torch.cuda.FloatTensor(size)`` constructor.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Snapshot the state, plus an independent copy used to verify
                # that the snapshot itself is never modified.
                rng_state = torch.cuda.get_rng_state()
                rng_state_clone = rng_state.clone()

                # Advance the generator.
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_1 = tensor.clone()

                # The snapshot must be untouched while the live state moved on.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
                self.assertGreater(
                    torch.cuda.get_rng_state().sub(rng_state_clone).max(),
                    0,
                    msg=msg,
                )

                # The live state must differ from the snapshot.
                new_rng_state = torch.cuda.get_rng_state()
                self.assertGreater(new_rng_state.sub(rng_state).max(), 0, msg=msg)

                # Restore the snapshot and replay the draws. This is done twice
                # on purpose: each restore has to land on the same starting
                # point, so the final draw of the second run must match too.
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_2 = tensor.clone()

                self.assertEqual(result_2, result_1, msg=msg)

                # Restoring must not have modified the snapshot either.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
            finally:
                # Tear down even when an assertion fails so the model-parallel
                # state does not leak out of this iteration.
                parallel_state.destroy_model_parallel()

    def test_cuda_rng_tracker(self):
        """Check that a forked tracker state is independent of the default one.

        Reference sequences are drawn directly after seeding with ``seed_1``
        and ``seed_2``.  Then the default generator is seeded with ``seed_1``
        and a tracker state named ``"test"`` is added with ``seed_2``; draws
        made outside ``fork("test")`` must match the ``seed_1`` references and
        draws made inside it must match the ``seed_2`` references, even when
        the two kinds of draws are interleaved.
        """
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            try:
                seed_1, seed_2, size = 1234, 4321, [12, 21]
                # ``torch.cuda.FloatTensor([12, 21])`` would treat the list as
                # data and build a 2-element tensor that every ``randn`` call
                # below then has to resize; allocate the real shape instead.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Reference draws for the first seed.
                torch.cuda.manual_seed(seed_1)
                torch.randn(size, out=tensor)
                target_11 = tensor.clone()
                torch.randn(size, out=tensor)
                target_12 = tensor.clone()

                # Reference draws for the second seed.
                torch.cuda.manual_seed(seed_2)
                torch.randn(size, out=tensor)
                target_21 = tensor.clone()
                torch.randn(size, out=tensor)
                target_22 = tensor.clone()

                # Default generator uses seed_1; the tracked "test" state is
                # added with seed_2.
                torch.cuda.manual_seed(seed_1)
                tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)

                # Interleave default draws with draws inside the fork.
                torch.randn(size, out=tensor)
                result_11 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_21 = tensor.clone()

                torch.randn(size, out=tensor)
                result_12 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_22 = tensor.clone()

                # Each stream must reproduce its own reference sequence ...
                self.assertEqual(target_11, result_11, msg=msg)
                self.assertEqual(target_12, result_12, msg=msg)
                self.assertEqual(target_21, result_21, msg=msg)
                self.assertEqual(target_22, result_22, msg=msg)
                # ... and the draws must actually differ from one another.
                self.assertNotEqual(result_11, result_21, msg=msg)
                self.assertNotEqual(result_21, result_22, msg=msg)
            finally:
                # Clear the tracker and tear down model parallelism even when
                # an assertion fails, so later iterations and tests do not
                # inherit the "test" state or a stale parallel setup.
                tensor_parallel.random.get_cuda_rng_tracker().reset()
                parallel_state.destroy_model_parallel()
class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase):
    """Run the ``TransformerRandomTestBase`` checks on ``NcclDistributedTestBase``."""
class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase):
    """Run the ``TransformerRandomTestBase`` checks on ``UccDistributedTestBase``."""
if __name__ == "__main__":
    # Executed as a script: hand control to the PyTorch internal test runner.
    common_utils.run_tests()
|