# test_gpt_minimal.py
  1. from functools import partial
  2. from typing import List
  3. import time
  4. import torch
  5. import unittest
  6. from apex.transformer._ucc_util import HAS_UCC
  7. from apex.transformer import parallel_state
  8. from apex.transformer.enums import ModelType
  9. from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
  10. from apex.transformer.pipeline_parallel.utils import (
  11. average_losses_across_data_parallel_group, unwrap_model, setup_microbatch_calculator,
  12. get_ltor_masks_and_position_ids
  13. )
  14. from apex.transformer.pipeline_parallel.schedules.common import (
  15. _get_params_for_weight_decay_optimization, build_model
  16. )
  17. from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
  18. forward_backward_pipelining_without_interleaving,
  19. )
  20. from apex.transformer.testing.standalone_gpt import gpt_model_provider
  21. from apex.transformer.testing import global_vars
  22. from apex.transformer.testing.distributed_test_base import UccDistributedTestBase, NcclDistributedTestBase
  23. from torch.testing._internal import common_utils
  24. from torch.testing._internal.common_device_type import instantiate_device_type_tests
  25. class GptTestBase:
  26. def _download_fancy_data(self):
  27. text = """
  28. An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
  29. """
  30. text = text * 1024
  31. encoded = text.encode("ascii", "replace")
  32. ints = [int(encoded[i]) for i in range(len(encoded))]
  33. return torch.tensor(ints)
  34. # build a batch given sequence_len and batch size
  35. def _generate_fancy_data_labels(self, sequence_len, batch_size):
  36. temps = list()
  37. for i in range(batch_size):
  38. if self.inds is None or self.data_idx >= len(self.inds):
  39. # hack as use of RNG will fall out of sync due to pipelines being different
  40. model_parallel_cuda_manual_seed(self.MANUAL_SEED)
  41. self.inds = torch.randperm(effective_length, device="cuda")
  42. self.MANUAL_SEED += 1
  43. self.data_idx = 0
  44. data_idx_ = self.data_idx
  45. offset = self.inds[data_idx_]
  46. self.data_idx += 1
  47. curr = fancy_data[offset: offset +
  48. sequence_len + 1].clone().detach()
  49. temps.append(curr)
  50. temp = torch.stack(temps, dim=0).cuda()
  51. return temp
  52. def _get_batch(self, int_tensors: List[torch.Tensor]):
  53. data = int_tensors[0]
  54. # Unpack.
  55. tokens_ = data.long()
  56. labels = tokens_[:, 1:].contiguous()
  57. tokens = tokens_[:, :-1].contiguous()
  58. # Get the masks and position ids.
  59. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
  60. tokens,
  61. self.N_VOCAB, # tokenizer.eod,
  62. False, # args.reset_position_ids,
  63. False, # args.reset_attention_mask,
  64. False, # args.eod_mask_loss,
  65. )
  66. return tokens, labels, loss_mask, attention_mask, position_ids
  67. # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
  68. def _loss_func(self, loss_mask, output_tensor):
  69. losses = output_tensor.float()
  70. loss_mask = loss_mask.view(-1).float()
  71. loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
  72. # Reduce loss for logging.
  73. averaged_loss = average_losses_across_data_parallel_group([loss])
  74. return loss, {"lm loss": averaged_loss[0]}
  75. # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
  76. def _fwd_step_func(self, batch, model):
  77. """Forward step."""
  78. tokens, labels, loss_mask, attention_mask, position_ids = self._get_batch(
  79. batch)
  80. output_tensor = model(tokens, position_ids,
  81. attention_mask, labels=labels)
  82. return output_tensor, partial(self._loss_func, loss_mask)
  83. def _train(self, model, optim, pipeline_model_parallel_size, async_comm):
  84. args = global_vars.get_args()
  85. fwd_bwd_func = forward_backward_pipelining_without_interleaving
  86. tensor_shape = (args.seq_length, args.micro_batch_size,
  87. args.hidden_size)
  88. runtime = 0
  89. # training loop
  90. for i in range(3):
  91. since = time.time()
  92. if torch.distributed.get_rank() == 0:
  93. print("begin iter", i)
  94. batch = [
  95. self._generate_fancy_data_labels(
  96. args.seq_length, args.global_batch_size)
  97. for _ in range(pipeline_model_parallel_size)
  98. ]
  99. if torch.distributed.get_rank() == 0:
  100. print("finished making batch...")
  101. optim.zero_grad()
  102. fwd_bwd_func(
  103. self._fwd_step_func,
  104. batch,
  105. model,
  106. forward_only=False,
  107. tensor_shape=tensor_shape,
  108. async_comm=async_comm,
  109. sequence_parallel_enabled=args.sequence_parallel,
  110. )
  111. if torch.distributed.get_rank() == 0:
  112. print("finished forward step")
  113. # All-reduce layernorm parameters across model parallel nodes
  114. # when sequence parallelism is used
  115. if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel:
  116. for model_module in model:
  117. unwrapped_model = unwrap_model(model_module)
  118. for param in unwrapped_model.parameters():
  119. if getattr(param, 'sequence_parallel_enabled', False):
  120. grad = param.grad
  121. torch.distributed.all_reduce(
  122. grad, group=parallel_state.get_tensor_model_parallel_group())
  123. optim.step()
  124. if torch.distributed.get_rank() == 0:
  125. print("finished iter", i)
  126. runtime += time.time() - since
  127. return runtime / 3.0
  128. @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
  129. def test_gpt(self):
  130. self.MANUAL_SEED = 42
  131. self.inds = None
  132. self.data_idx = 0
  133. self.N_VOCAB = 128
  134. init = True
  135. tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size >= 4 else 1
  136. pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size
  137. override_args = {
  138. "micro_batch_size": 2,
  139. "num_layers": 16,
  140. "hidden_size": 256,
  141. "num_attention_heads": 8,
  142. "max_position_embeddings": 512,
  143. "seq_length": 512,
  144. "global_batch_size": 128,
  145. "pipeline_model_parallel_size": pipeline_model_parallel_size,
  146. "tensor_model_parallel_size": tensor_model_parallel_size,
  147. "world_size": self.world_size,
  148. "rank": self.rank,
  149. }
  150. global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
  151. args = global_vars.get_args()
  152. for async_comm in (False,) if args.sequence_parallel else (False, True):
  153. global fancy_data
  154. global effective_length
  155. if init:
  156. init = False
  157. fancy_data = self._download_fancy_data()
  158. args = global_vars.get_args()
  159. args.model_type = ModelType.encoder_or_decoder
  160. effective_length = fancy_data.size(0) // args.seq_length
  161. effective_length = fancy_data.size(0) - args.seq_length
  162. args.padded_vocab_size = 128
  163. setup_microbatch_calculator(
  164. args.rank,
  165. args.rampup_batch_size,
  166. args.global_batch_size,
  167. args.micro_batch_size,
  168. args.data_parallel_size,
  169. )
  170. print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
  171. parallel_state.initialize_model_parallel(
  172. tensor_model_parallel_size_=args.tensor_model_parallel_size,
  173. pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
  174. default_backend="nccl",
  175. p2p_backend=self.DISTRIBUTED_BACKEND,
  176. )
  177. model_parallel_cuda_manual_seed(0)
  178. model = build_model(
  179. gpt_model_provider,
  180. wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
  181. virtual_pipeline_model_parallel_size=None,
  182. cpu_offload=args.cpu_offload,
  183. )
  184. assert isinstance(model, list), model
  185. _param_groups = _get_params_for_weight_decay_optimization(model)
  186. optim = torch.optim.Adam(_param_groups)
  187. runtime = self._train(
  188. model, optim, args.pipeline_model_parallel_size, async_comm)
  189. parallel_state.destroy_model_parallel()
  190. torch.cuda.synchronize()
  191. class NcclGptTest(GptTestBase, NcclDistributedTestBase):
  192. @property
  193. def world_size(self) -> int:
  194. return min(torch.cuda.device_count(), 8)
  195. @unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc")
  196. class UccGptTest(GptTestBase, UccDistributedTestBase):
  197. @property
  198. def world_size(self) -> int:
  199. return min(torch.cuda.device_count(), 8)
# Dispatch to PyTorch's shared test runner so these tests integrate with
# torch.testing's retry/filtering infrastructure.
if __name__ == "__main__":
    common_utils.run_tests()