from functools import partial
from typing import List
import time
import unittest

import torch

from apex.transformer._ucc_util import HAS_UCC
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
    unwrap_model,
    setup_microbatch_calculator,
    get_ltor_masks_and_position_ids,
)
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
    build_model,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase, NcclDistributedTestBase
from torch.testing._internal import common_utils
from torch.testing._internal.common_device_type import instantiate_device_type_tests


class GptTestBase:

    def _download_fancy_data(self):
        text = """
        An original sentence not subject to any license restrictions,
        copyright, or royalty payments. Nothing to see here. Commercial or
        non-commercial use. Research or non-research purposes. The quick
        brown fox jumps over the lazy dog. Lorem ipsum.
        """
        text = text * 1024
        encoded = text.encode("ascii", "replace")
        ints = [int(encoded[i]) for i in range(len(encoded))]
        return torch.tensor(ints)

    # Build a batch given sequence_len and batch_size.
    def _generate_fancy_data_labels(self, sequence_len, batch_size):
        temps = list()
        for i in range(batch_size):
            if self.inds is None or self.data_idx >= len(self.inds):
                # Hack: reseed here, as RNG use would otherwise fall out of
                # sync across ranks because the pipeline stages differ.
                model_parallel_cuda_manual_seed(self.MANUAL_SEED)
                self.inds = torch.randperm(effective_length, device="cuda")
                self.MANUAL_SEED += 1
                self.data_idx = 0
            data_idx_ = self.data_idx
            offset = self.inds[data_idx_]
            self.data_idx += 1
            curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
            temps.append(curr)
        temp = torch.stack(temps, dim=0).cuda()
        return temp

    def _get_batch(self, int_tensors: List[torch.Tensor]):
        data = int_tensors[0]
        # Unpack.
        tokens_ = data.long()
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()
        # Get the masks and position ids.
        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
            tokens,
            self.N_VOCAB,  # tokenizer.eod
            False,  # args.reset_position_ids
            False,  # args.reset_attention_mask
            False,  # args.eod_mask_loss
        )
        return tokens, labels, loss_mask, attention_mask, position_ids

    # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
    def _loss_func(self, loss_mask, output_tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
        # Reduce loss for logging.
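        # Note: the data-parallel all-reduce below only affects the logged
        # value; the loss returned for the backward pass is the local
        # masked mean computed above.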
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"lm loss": averaged_loss[0]}

    # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
    def _fwd_step_func(self, batch, model):
        """Forward step."""
        tokens, labels, loss_mask, attention_mask, position_ids = self._get_batch(batch)
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        return output_tensor, partial(self._loss_func, loss_mask)

    def _train(self, model, optim, pipeline_model_parallel_size, async_comm):
        args = global_vars.get_args()
        fwd_bwd_func = forward_backward_pipelining_without_interleaving
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
        runtime = 0
        # Training loop.
        for i in range(3):
            since = time.time()
            if torch.distributed.get_rank() == 0:
                print("begin iter", i)
            batch = [
                self._generate_fancy_data_labels(args.seq_length, args.global_batch_size)
                for _ in range(pipeline_model_parallel_size)
            ]
            if torch.distributed.get_rank() == 0:
                print("finished making batch...")
            optim.zero_grad()
            fwd_bwd_func(
                self._fwd_step_func,
                batch,
                model,
                forward_only=False,
                tensor_shape=tensor_shape,
                async_comm=async_comm,
                sequence_parallel_enabled=args.sequence_parallel,
            )
            if torch.distributed.get_rank() == 0:
                print("finished forward step")
            # All-reduce layernorm parameters across tensor-model-parallel
            # ranks when sequence parallelism is used.
            if parallel_state.get_tensor_model_parallel_world_size() > 1 and args.sequence_parallel:
                for model_module in model:
                    unwrapped_model = unwrap_model(model_module)
                    for param in unwrapped_model.parameters():
                        if getattr(param, "sequence_parallel_enabled", False):
                            grad = param.grad
                            torch.distributed.all_reduce(
                                grad, group=parallel_state.get_tensor_model_parallel_group()
                            )
            optim.step()
            if torch.distributed.get_rank() == 0:
                print("finished iter", i)
            runtime += time.time() - since
        return runtime / 3.0

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_gpt(self):
        self.MANUAL_SEED = 42
        self.inds = None
        self.data_idx = 0
        self.N_VOCAB = 128
        init = True

        tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size >= 4 else 1
        pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size

        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "world_size": self.world_size,
            "rank": self.rank,
        }
        global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
        args = global_vars.get_args()

        for async_comm in (False,) if args.sequence_parallel else (False, True):
            global fancy_data
            global effective_length
            if init:
                init = False
                fancy_data = self._download_fancy_data()
                args = global_vars.get_args()
                args.model_type = ModelType.encoder_or_decoder
                # Leave room for a full sequence plus one label token at the
                # largest offset.
                effective_length = fancy_data.size(0) - args.seq_length
            args.padded_vocab_size = 128
            setup_microbatch_calculator(
                args.rank,
                args.rampup_batch_size,
                args.global_batch_size,
                args.micro_batch_size,
                args.data_parallel_size,
            )
            print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=args.tensor_model_parallel_size,
                pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
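                # self.DISTRIBUTED_BACKEND is "nccl" or "ucc" depending on the
                # test subclass; only the pipeline point-to-point backend
                # varies, while the default (tensor-parallel) backend stays NCCL.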
default_backend="nccl", p2p_backend=self.DISTRIBUTED_BACKEND, ) model_parallel_cuda_manual_seed(0) model = build_model( gpt_model_provider, wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1, virtual_pipeline_model_parallel_size=None, cpu_offload=args.cpu_offload, ) assert isinstance(model, list), model _param_groups = _get_params_for_weight_decay_optimization(model) optim = torch.optim.Adam(_param_groups) runtime = self._train( model, optim, args.pipeline_model_parallel_size, async_comm) parallel_state.destroy_model_parallel() torch.cuda.synchronize() class NcclGptTest(GptTestBase, NcclDistributedTestBase): @property def world_size(self) -> int: return min(torch.cuda.device_count(), 8) @unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc") class UccGptTest(GptTestBase, UccDistributedTestBase): @property def world_size(self) -> int: return min(torch.cuda.device_count(), 8) if __name__ == "__main__": common_utils.run_tests()
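# Usage note (assumption: the Nccl/Ucc distributed test bases spawn their own
# per-rank worker processes, as torch's multiprocess test cases do): run this
# file directly with `python`; each test skips unless more than 2 GPUs are
# visible, and world_size is capped at 8 ranks.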