# test_cross_entropy.py

import logging
from typing import Tuple

import torch
import torch.nn.functional as F
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel import cross_entropy
from apex.transformer.testing.commons import set_random_seed, IdentityLayer
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)
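

# torch_cross_entropy computes the reference loss and input gradient with plain
# F.cross_entropy over the full, unsharded vocabulary. Both this function and
# the sharded variant below call set_random_seed(seed) first, so the seeded
# IdentityLayer produces identical logits and targets in the two runs.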
def torch_cross_entropy(
    batch_size: int, seq_length: int, vocab_size: int, logits_scale: float, seed: int, label_smoothing: float = 0.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    set_random_seed(seed)
    identity = IdentityLayer(
        (batch_size, seq_length, vocab_size), scale=logits_scale
    ).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    # Per-token losses are reshaped back to (batch, seq) before averaging.
    loss = (
        F.cross_entropy(
            logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none", label_smoothing=label_smoothing
        )
        .view_as(target)
        .mean()
    )
    loss.backward()
    return loss, identity.weight.grad
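

# tensor_sharded_cross_entropy repeats the computation with the logits scattered
# along the vocab dimension across the tensor-model-parallel group. As a hedged
# note on the kernel under test (based on the Megatron-LM-style implementation
# apex mirrors): each rank is expected to reduce its local logit max and local
# sum of exponentials across the group, so no rank ever materializes the full
# softmax over the whole vocabulary.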
def tensor_sharded_cross_entropy(
    batch_size, seq_length, vocab_size, logits_scale, seed, label_smoothing=0.0
):
    set_random_seed(seed)
    identity = IdentityLayer(
        (batch_size, seq_length, vocab_size), scale=logits_scale
    ).cuda()
    logits = identity()
    # Shard the logits along the last (vocab) dimension across the TP group.
    logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
    target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size)
    logits_parallel_ = logits_parallel.clone().detach()
    loss = cross_entropy.vocab_parallel_cross_entropy(logits_parallel, target, label_smoothing=label_smoothing).mean()
    loss.backward()
    # The kernel must not mutate its input logits in place.
    assert torch.equal(logits_parallel_, logits_parallel)
    return loss, identity.weight.grad
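

# The test sweeps every tensor-parallel size that evenly divides the launch
# world size; vocab_size scales with the partition count so each rank always
# holds vocab_size_per_partition logits. Loss and input gradient must match the
# unsharded reference (up to assertEqual's default tolerances). The large
# logits_scale stresses the numerical stability of the sharded softmax.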
class VocabParallelCrossEntropyTestBase:
    def test_cross_entropy(self):
        batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11
        logits_scale = 1000.0
        seed = 1234
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Skip tensor-parallel sizes that do not divide the world size.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size
            loss_torch, grad_torch = torch_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed
            )
            (
                loss_tensor_parallel,
                grad_tensor_parallel,
            ) = tensor_sharded_cross_entropy(
                batch_size, sequence_length, vocab_size, logits_scale, seed
            )
            self.assertEqual(
                loss_torch,
                loss_tensor_parallel,
                msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}",
            )
            self.assertEqual(
                grad_torch,
                grad_tensor_parallel,
                msg=f"tensor_model_parallel_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()


class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase):
    pass


class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase):
    pass
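

# Minimal usage sketch (an assumption about the test bases, not stated in this
# file): Nccl/UccDistributedTestBase derive from torch.testing's multi-process
# test case, so running `python test_cross_entropy.py` on a multi-GPU node
# should spawn one process per device and run test_cross_entropy once under
# each communication backend.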
if __name__ == "__main__":
    common_utils.run_tests()