import logging
from typing import Tuple

import torch
import torch.nn.functional as F
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


def torch_cross_entropy(
    batch_size: int,
    seq_length: int,
    vocab_size: int,
    logits_scale: float,
    seed: int,
    label_smoothing: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference loss and input grad from plain ``F.cross_entropy`` over the full vocab."""
    set_random_seed(seed)
    identity = IdentityLayer(
        (batch_size, seq_length, vocab_size), scale=logits_scale
    ).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = (
        F.cross_entropy(
            logits.view(-1, logits.size()[-1]),
            target.view(-1),
            reduction="none",
            label_smoothing=label_smoothing,
        )
        .view_as(target)
        .mean()
    )
    loss.backward()
    return loss, identity.weight.grad


def tensor_sharded_cross_entropy(
    batch_size, seq_length, vocab_size, logits_scale, seed, label_smoothing=0.0
):
    """Loss and input grad from the vocab-parallel kernel on logits sharded over the vocab dim."""
    set_random_seed(seed)
    identity = IdentityLayer(
        (batch_size, seq_length, vocab_size), scale=logits_scale
    ).cuda()
    logits = identity()
    logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    logits_parallel_ = logits_parallel.clone().detach()
    loss = cross_entropy.vocab_parallel_cross_entropy(
        logits_parallel, target, label_smoothing=label_smoothing
    ).mean()
    loss.backward()
    # Check that the kernel did not mutate its input in place.
    assert torch.equal(logits_parallel_, logits_parallel)
    return loss, identity.weight.grad


class VocabParallelCrossEntropyTestBase:
    def test_cross_entropy(self):
        batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
        logits_scale = 1000.0
        seed = 1234
        # Try every tensor-model-parallel size that evenly divides the world size.
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
            loss_torch, grad_torch = torch_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed
            )
            (
                loss_tensor_parallel,
                grad_tensor_parallel,
            ) = tensor_sharded_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed
            )
            self.assertEqual(
                loss_torch,
                loss_tensor_parallel,
                msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}",
            )
            self.assertEqual(
                grad_torch,
                grad_tensor_parallel,
                msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()


class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase):
    pass


class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()
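
# Usage sketch (assumptions: the distributed test bases build on
# torch.testing._internal.common_distributed.MultiProcessTestCase, so each test
# spawns its own worker processes, and this module is saved as
# test_cross_entropy.py on a machine with multiple CUDA devices):
#
#   python test_cross_entropy.py
#
# A single variant can be selected with unittest's -k filter, e.g.:
#
#   python test_cross_entropy.py -k NcclVocabParallelCrossEntropyTest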