# test_p2p_comm.py

import logging
import unittest
from typing import Optional

import torch
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.DEBUG)
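
# Note: NcclDistributedTestBase / UccDistributedTestBase come from apex.transformer.testing.
# The tests below assume they behave like torch's multi-process test cases: one worker
# process per GPU, with `self.rank` exposed as the worker's (and thus pipeline stage's) rank.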


# [P2P Ops Involved in Pipeline Model Parallel forward/backward]
# **forward_backward_pipelining_without_interleaving**
# - send_forward / recv_forward
# - send_backward / recv_backward
# - send_forward_recv_backward
# - send_backward_recv_forward
# **forward_backward_pipelining_with_interleaving**
# - send_backward_recv_backward
# - recv_backward
# - recv_forward
# - send_forward_backward_recv_forward_backward
# - send_forward_recv_forward
class P2PCommTestBase:

    numel = 4
    shape = (2, 2)
    dtype = torch.float32

    @property
    def world_size(self):
        return min(2, torch.cuda.device_count())

    def _init_model_parallel(self):
        # Pure pipeline parallelism: tensor model parallel size 1 and a pipeline spanning
        # the whole world, so each rank is exactly one pipeline stage.
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=1,
            pipeline_model_parallel_size_=self.world_size,
            virtual_pipeline_model_parallel_size_=None,
        )

    def create_tensor(self, value: Optional[int] = None):
        # A (2, 2) tensor filled with `value` (typically the sender's rank) so the receiver
        # can verify which rank the data came from.
        return torch.tensor(
            [value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype)

    # Brief: simulate the warm-up phase; tests `recv_forward` & `send_forward`.
    def test_no_interleaving_warmup(self):
        self.assertEqual(self.world_size, 2)
        self._init_model_parallel()
        input_tensor = None
        if parallel_state.is_pipeline_first_stage():
            tensor = self.create_tensor(self.rank)
            print(tensor)
            p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
        else:
            input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)

        if parallel_state.is_pipeline_first_stage():
            self.assertIsNone(input_tensor)
        else:
            expected_input_tensor = self.create_tensor(self.rank - 1)
            self.assertEqual(input_tensor, expected_input_tensor)

    # Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`.
    def test_send_forward_recv_forward(self):
        self._init_model_parallel()
        prev_tensor = None
        tensor = self.create_tensor(self.rank)
        if parallel_state.is_pipeline_first_stage():
            p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
        elif parallel_state.is_pipeline_last_stage():
            prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype)
        else:
            prev_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor=tensor,
                recv_prev=True,
                tensor_shape=self.shape,
                dtype=self.dtype,
            )

        if parallel_state.is_pipeline_first_stage():
            self.assertIsNone(prev_tensor)
        else:
            expected_prev_tensor = self.create_tensor(self.rank - 1)
            self.assertEqual(prev_tensor, expected_prev_tensor)

    # Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`.
    def test_send_backward_recv_backward(self):
        self._init_model_parallel()
        tensor = self.create_tensor(self.rank)
        next_tensor = None
        if parallel_state.is_pipeline_first_stage():
            next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype)
        elif parallel_state.is_pipeline_last_stage():
            p2p_communication.send_backward(input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
        else:
            next_tensor = p2p_communication.send_backward_recv_backward(
                input_tensor_grad=tensor,
                recv_next=True,
                tensor_shape=self.shape,
                dtype=self.dtype,
            )

        if parallel_state.is_pipeline_last_stage():
            self.assertIsNone(next_tensor)
        else:
            expected_next_tensor = self.create_tensor(self.rank + 1)
            self.assertEqual(next_tensor, expected_next_tensor)
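
    # The header lists `send_forward_recv_backward` / `send_backward_recv_forward`, which the
    # tests above do not exercise. Below is a minimal sketch of such a test for the two-stage
    # pipeline used here, assuming the two functions take the same keyword arguments as the
    # helpers above (check `p2p_communication` for the exact signatures). It is prefixed with
    # an underscore so the test runner does not collect it automatically.
    def _sketch_send_forward_recv_backward(self):
        self._init_model_parallel()
        tensor = self.create_tensor(self.rank)
        if parallel_state.is_pipeline_last_stage():
            # Last stage: send the "gradient" upstream and receive the "activation" from above.
            received = p2p_communication.send_backward_recv_forward(
                input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
            expected = self.create_tensor(self.rank - 1)
        else:
            # First stage: send the "activation" downstream and receive the "gradient" from below.
            received = p2p_communication.send_forward_recv_backward(
                output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
            expected = self.create_tensor(self.rank + 1)
        self.assertEqual(received, expected)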


# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo.
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >= 2 GPUs")
class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): pass
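
# A note on running this file (an assumption about the test harness, not stated in the file
# itself): with at least 2 GPUs and a UCC-enabled build, `python test_p2p_comm.py` invokes
# `common_utils.run_tests()`, and the distributed test base launches one process per GPU
# for each test in UccP2PCommTest.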


if __name__ == "__main__":
    common_utils.run_tests()