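# Minimal GPT pretraining smoke test for apex.transformer.
# It builds a small GPT model with the standalone test provider, then runs a few
# iterations of the non-interleaved pipeline-parallel schedule on synthetic
# character-level data, once with NCCL and (when available) once with UCC as the
# point-to-point backend.
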
from functools import partial
from typing import List
import time
import unittest

import torch

from apex.transformer._ucc_util import HAS_UCC
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
    unwrap_model,
    setup_microbatch_calculator,
    get_ltor_masks_and_position_ids,
)
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
    build_model,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.distributed_test_base import (
    UccDistributedTestBase,
    NcclDistributedTestBase,
)

from torch.testing._internal import common_utils
from torch.testing._internal.common_device_type import instantiate_device_type_tests


class GptTestBase:
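    """GPT pipeline-parallel training test, shared by the backend-specific classes at the bottom of the file."""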

    def _download_fancy_data(self):
        text = """
        An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
        """
        text = text * 1024
        encoded = text.encode("ascii", "replace")
        ints = [int(encoded[i]) for i in range(len(encoded))]
        return torch.tensor(ints)

    # build a batch given sequence_len and batch_size
    def _generate_fancy_data_labels(self, sequence_len, batch_size):
        temps = list()
        for i in range(batch_size):
            if self.inds is None or self.data_idx >= len(self.inds):
                # hack: reseed here, because RNG use would otherwise fall out of sync
                # across the different pipeline stages
                model_parallel_cuda_manual_seed(self.MANUAL_SEED)
                self.inds = torch.randperm(effective_length, device="cuda")
                self.MANUAL_SEED += 1
                self.data_idx = 0
            data_idx_ = self.data_idx
            offset = self.inds[data_idx_]
            self.data_idx += 1
            curr = fancy_data[offset : offset + sequence_len + 1].clone().detach()
            temps.append(curr)
        temp = torch.stack(temps, dim=0).cuda()
        return temp
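
    # Split a batch of token ids into model inputs and next-token labels, and build the
    # causal (left-to-right) attention mask, loss mask, and position ids for them.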
    def _get_batch(self, int_tensors: List[torch.Tensor]):
        data = int_tensors[0]
        # Unpack.
        tokens_ = data.long()
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()
        # Get the masks and position ids.
        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
            tokens,
            self.N_VOCAB,  # tokenizer.eod
            False,  # args.reset_position_ids
            False,  # args.reset_attention_mask
            False,  # args.eod_mask_loss
        )
        return tokens, labels, loss_mask, attention_mask, position_ids
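
    # Masked LM loss: mean of the per-token losses selected by loss_mask, plus a copy
    # averaged across the data-parallel group for logging.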
    # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75
    def _loss_func(self, loss_mask, output_tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"lm loss": averaged_loss[0]}

    # Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86
    def _fwd_step_func(self, batch, model):
        """Forward step."""
        tokens, labels, loss_mask, attention_mask, position_ids = self._get_batch(batch)
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        return output_tensor, partial(self._loss_func, loss_mask)
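
    # Run three iterations of the non-interleaved pipeline schedule (forward, backward,
    # optimizer step) and return the average wall-clock time per iteration.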
    def _train(self, model, optim, pipeline_model_parallel_size, async_comm):
        args = global_vars.get_args()
        fwd_bwd_func = forward_backward_pipelining_without_interleaving
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
        runtime = 0
        # training loop
        for i in range(3):
            since = time.time()
            if torch.distributed.get_rank() == 0:
                print("begin iter", i)
            batch = [
                self._generate_fancy_data_labels(args.seq_length, args.global_batch_size)
                for _ in range(pipeline_model_parallel_size)
            ]
            if torch.distributed.get_rank() == 0:
                print("finished making batch...")
            optim.zero_grad()
            fwd_bwd_func(
                self._fwd_step_func,
                batch,
                model,
                forward_only=False,
                tensor_shape=tensor_shape,
                async_comm=async_comm,
                sequence_parallel_enabled=args.sequence_parallel,
            )
            if torch.distributed.get_rank() == 0:
                print("finished forward step")
            # All-reduce layernorm parameters across model parallel nodes
            # when sequence parallelism is used
            if (
                parallel_state.get_tensor_model_parallel_world_size() > 1
                and global_vars.get_args().sequence_parallel
            ):
                for model_module in model:
                    unwrapped_model = unwrap_model(model_module)
                    for param in unwrapped_model.parameters():
                        if getattr(param, "sequence_parallel_enabled", False):
                            grad = param.grad
                            torch.distributed.all_reduce(
                                grad, group=parallel_state.get_tensor_model_parallel_group()
                            )
            optim.step()
            if torch.distributed.get_rank() == 0:
                print("finished iter", i)
            runtime += time.time() - since
        return runtime / 3.0
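
    # Derive tensor/pipeline parallel sizes from the available world size, then build and
    # train the model once per communication mode (async point-to-point communication is
    # only exercised when sequence parallelism is disabled).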
    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_gpt(self):
        self.MANUAL_SEED = 42
        self.inds = None
        self.data_idx = 0
        self.N_VOCAB = 128
        init = True

        tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size >= 4 else 1
        pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size

        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "world_size": self.world_size,
            "rank": self.rank,
        }

        global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
        args = global_vars.get_args()

        for async_comm in (False,) if args.sequence_parallel else (False, True):
            global fancy_data
            global effective_length
            if init:
                init = False
                fancy_data = self._download_fancy_data()
                args = global_vars.get_args()
                args.model_type = ModelType.encoder_or_decoder
                # Valid starting offsets are 0 .. size - seq_length - 1, so that a slice of
                # seq_length + 1 tokens (inputs plus shifted labels) always fits.
                effective_length = fancy_data.size(0) - args.seq_length

            args.padded_vocab_size = 128
            setup_microbatch_calculator(
                args.rank,
                args.rampup_batch_size,
                args.global_batch_size,
                args.micro_batch_size,
                args.data_parallel_size,
            )
            print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=args.tensor_model_parallel_size,
                pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
                default_backend="nccl",
                p2p_backend=self.DISTRIBUTED_BACKEND,
            )
            model_parallel_cuda_manual_seed(0)

            model = build_model(
                gpt_model_provider,
                wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
                virtual_pipeline_model_parallel_size=None,
                cpu_offload=args.cpu_offload,
            )
            assert isinstance(model, list), model
            _param_groups = _get_params_for_weight_decay_optimization(model)
            optim = torch.optim.Adam(_param_groups)

            runtime = self._train(model, optim, args.pipeline_model_parallel_size, async_comm)

            parallel_state.destroy_model_parallel()
        torch.cuda.synchronize()
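

# Concrete test classes: the same test body, parameterized by the distributed backend
# used for point-to-point (pipeline) communication.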
class NcclGptTest(GptTestBase, NcclDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)


@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc")
class UccGptTest(GptTestBase, UccDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)
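

# Note: running this file directly hands control to PyTorch's test runner; the distributed
# test bases are expected to spawn one process per rank (up to world_size) for each test.
# Assuming unittest's -k filtering applies, a single backend can be selected with,
# e.g., "-k NcclGptTest".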
if __name__ == "__main__":
    common_utils.run_tests()