import logging

import torch.testing
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)


class BroadcastDataTestBase:
    def test_broadcast_data(self):
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + self.world_size > 1
        )
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size
        )

        target_key_size = {
            "key1": [7, 11],
            "key2": [8, 2, 1],
            "key3": [13],
            "key4": [5, 1, 2],
            "key5": [5, 12],
        }
        keys = [k for k in target_key_size]

        data = {}
        data_t = {}
        with torch.no_grad():
            for key in target_key_size:
                data[key] = torch.randint(0, 1000, size=target_key_size[key])
                data_t[key] = data[key].clone()
            # "key_x" is supposed to be ignored.
            data["key_x"] = torch.rand(5)
            data_t["key_x"] = data["key_x"].clone()
        # Only tensor-model-parallel rank 0 supplies the data; the other ranks
        # receive it via broadcast_data below.
        if parallel_state.get_tensor_model_parallel_rank() != 0:
            data = None

        data_utils._check_data_types(keys, data_t, torch.int64)
        key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)

        for key in keys:
            self.assertEqual(target_key_size[key], key_size[key])

        # Broadcast the int64 tensors from tensor-model-parallel rank 0 and
        # verify every rank ends up with the original values.
        broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
        for key in keys:
            self.assertEqual(broadcasted_data[key], data_t[key].cuda())

        parallel_state.destroy_model_parallel()


class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase):
    pass


class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()