test_data.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import logging
  2. import torch.testing
  3. from torch.testing._internal import common_utils
  4. logging.getLogger("torch").setLevel(logging.WARNING)
  5. from apex.transformer import parallel_state
  6. from apex.transformer.tensor_parallel import data as data_utils
  7. from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
  8. from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
# Sets the "torch" logger to WARNING again; the same call already runs above,
# before the apex imports.
# NOTE(review): looks redundant unless one of those imports resets the level --
# confirm before removing.
logging.getLogger("torch").setLevel(logging.WARNING)
  10. class BroadcastDataTestBase:
  11. def test_broadcast_data(self):
  12. tensor_model_parallel_world_size: int = self.world_size // (
  13. 1 + self.world_size > 1
  14. )
  15. parallel_state.initialize_model_parallel(
  16. tensor_model_parallel_size_=tensor_model_parallel_world_size
  17. )
  18. target_key_size = {
  19. "key1": [7, 11],
  20. "key2": [8, 2, 1],
  21. "key3": [13],
  22. "key4": [5, 1, 2],
  23. "key5": [5, 12],
  24. }
  25. keys = [k for k in target_key_size]
  26. data = {}
  27. data_t = {}
  28. with torch.no_grad():
  29. for key in target_key_size:
  30. data[key] = torch.randint(0, 1000, size=target_key_size[key])
  31. data_t[key] = data[key].clone()
  32. # "key_x" is supposed to be ignored.
  33. data["key_x"] = torch.rand(5)
  34. data_t["key_x"] = data["key_x"].clone()
  35. if parallel_state.get_tensor_model_parallel_rank() != 0:
  36. data = None
  37. data_utils._check_data_types(keys, data_t, torch.int64)
  38. key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)
  39. for key in keys:
  40. self.assertEqual(target_key_size[key], key_size[key])
  41. broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
  42. for key in keys:
  43. self.assertEqual(broadcasted_data[key], data_t[key].cuda())
  44. parallel_state.destroy_model_parallel()
  45. class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass
  46. class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass
# Script entry point: when this file is executed directly, hand control to
# ``common_utils.run_tests()`` from ``torch.testing._internal``.
if __name__ == "__main__":
    common_utils.run_tests()