import torch
import unittest

from apex.transformer.testing import global_vars
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization, build_model
)
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group, unwrap_model, setup_microbatch_calculator
)
from apex.transformer.log_util import set_logging_level
from apex.transformer import tensor_parallel, parallel_state
from apex.transformer.enums import ModelType
from apex.transformer._ucc_util import HAS_UCC
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase, NcclDistributedTestBase
import logging
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
set_logging_level("WARNING")


class BertTestBase:
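    """Minimal BERT convergence test for apex.transformer.

    Builds a small BERT via bert_model_provider, trains it with the
    pipeline-parallel forward/backward schedules on synthetic masked-character
    data, and asserts that late-iteration losses stay below a loose bound.
    """
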
    def _download_fancy_data(self):
        text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
        text = text * 1024
        encoded = text.encode("ascii", "replace")
        ints = [int(encoded[i]) for i in range(len(encoded))]
        return torch.tensor(ints)
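
    # Masked-LM batch construction: positions where `mask` is 0 are replaced
    # with token id 124 and the untouched sequence serves as the label; the
    # batch assembled on tensor-parallel rank 0 is broadcast to the rest of
    # the tensor-parallel group.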
    # build a batch given sequence_len and batch size
    def _generate_fancy_data_labels(self, sequence_len, batch_size):
        temps = []
        for i in range(batch_size):
            if self.inds is None or self.data_idx >= len(self.inds):
                # hack as use of RNG will fall out of sync due to pipelines being different
                torch.manual_seed(self.MANUAL_SEED)
                self.inds = torch.randperm(self.effective_length, device="cuda")
                self.masks = (
                    torch.rand(
                        len(self.inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                    )
                    >= self.MASK_PROB
                ).long()
                self.MANUAL_SEED += 1
                self.data_idx = 0
                if self.rank == 0:
                    print("new epoch", len(self.inds))
                    print("my start", self.inds[0:5])
                    print("masks_checksum:", torch.sum(self.masks))
            if self.EASY_MODE:
                data_idx_ = self.data_idx % self.EASY_MODE_SIZ
            else:
                data_idx_ = self.data_idx
            offset = self.inds[data_idx_]  # * SEQUENCE_LEN
            self.data_idx += 1
            curr = self.fancy_data[offset: offset + sequence_len].clone().detach()
            temps.append(curr)
        temp = torch.stack(temps, dim=0).cuda()
        mask = self.masks[self.data_idx // batch_size]
        mask_not = torch.logical_not(mask).long()
        data = mask * temp + mask_not * 124
        label = temp
        if parallel_state.get_tensor_model_parallel_rank() == 0:
            data_dict = {"text": data, "label": label, "mask_not": mask_not}
        else:
            data_dict = None
        keys = ["text", "label", "mask_not"]
        broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
        return (
            broadcasted_data["text"].long(),
            broadcasted_data["label"].long(),
            broadcasted_data["mask_not"],
        )
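
    # Forward-step callback for the pipeline schedule: runs the model on one
    # microbatch and returns the output together with a loss function that
    # applies the loss mask, averages across the data-parallel group, and
    # checks the loss against a loose upper bound late in training.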
    def _fwd_step_func(self, batch, model):
        data, label, loss_mask = batch
        y = model(data, torch.ones_like(data), lm_labels=label)

        def loss_func(output_tensor):
            output_tensor, _ = output_tensor
            lm_loss_ = output_tensor.float()
            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
            averaged_loss = average_losses_across_data_parallel_group([lm_loss])
            if self.data_idx >= 1536:
                # NOTE (patwang): The loss cutoff might be excessively high, but roughly one in
                # five unlucky random seeds does cause the loss to spike to just under 8.0.
                self.assertLess(averaged_loss, 8.0)
            return lm_loss, {"avg": averaged_loss}

        return y, loss_func
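
    # Training loop: 16 iterations, each generating a fresh synthetic batch,
    # running the chosen forward/backward schedule, all-reducing
    # sequence-parallel layer-norm gradients across the tensor-parallel group
    # when enabled, and taking an optimizer step.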
    def _train(
        self, model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm
    ):
        args = global_vars.get_args()
        sequence_len = args.seq_length
        micro_batch_size = args.micro_batch_size
        hidden_size = args.hidden_size
        global_batch_size = args.global_batch_size
        forward_backward_func = get_forward_backward_func(
            virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
        )
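        # The returned callable implements the no-pipelining, 1F1B, or
        # interleaved schedule, depending on the (virtual) pipeline-parallel sizes.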
        tensor_shape = (sequence_len, micro_batch_size, hidden_size)
        for _ in range(16):
            batch = self._generate_fancy_data_labels(sequence_len, global_batch_size)
            optim.zero_grad()
            forward_backward_func(
                self._fwd_step_func,
                batch,
                model,
                forward_only=False,
                tensor_shape=tensor_shape,
                async_comm=async_comm,
                sequence_parallel_enabled=args.sequence_parallel,
            )
            # All-reduce layernorm parameters across model parallel nodes
            # when sequence parallelism is used.
            if parallel_state.get_tensor_model_parallel_world_size() > 1 and args.sequence_parallel:
                for model_module in model:
                    unwrapped_model = unwrap_model(model_module)
                    for param in unwrapped_model.parameters():
                        if getattr(param, 'sequence_parallel_enabled', False):
                            grad = param.grad
                            torch.distributed.all_reduce(
                                grad, group=parallel_state.get_tensor_model_parallel_group()
                            )
            optim.step()
    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 GPUs")
    def test_bert_without_interleaving(self):
        self._test_bert(virtual_pipeline_model_parallel_size=None)

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 GPUs")
    def test_bert_with_interleaving(self):
        if self.DISTRIBUTED_BACKEND == 'ucc':
            self.skipTest('skip interleaving with ucc')
        self._test_bert(virtual_pipeline_model_parallel_size=2)
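
    # End-to-end driver: sets small-BERT global args, initializes the
    # microbatch calculator and model-parallel state, builds the (optionally
    # virtual-pipeline-chunked) model, and trains it with Adam.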
    def _test_bert(self, virtual_pipeline_model_parallel_size):
        self.MANUAL_SEED = 42
        self.inds = None
        self.masks = None
        self.data_idx = 0
        self.MASK_PROB = 0.1
        self.EASY_MODE = False
        self.EASY_MODE_SIZ = 32

        tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size > 4 else 1
        pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size
        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "bert_binary_head": False,
            "world_size": self.world_size,
            "rank": self.rank,
        }
        global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
        args = global_vars.get_args()
        self.fancy_data = self._download_fancy_data()
        # Offsets are drawn from the corpus minus one sequence length so a full
        # window can always be sliced.
        self.effective_length = self.fancy_data.size(0) - args.seq_length
        if self.rank == 0:
            print(
                f'testing backend: {self.DISTRIBUTED_BACKEND} with '
                f'virtual_pipeline_model_parallel_size: {virtual_pipeline_model_parallel_size}'
            )
        async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
        self.data_idx = 0
        args.padded_vocab_size = 128  # needed by the standalone BERT model
        args.model_type = ModelType.encoder_or_decoder
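
        # Derive the number of microbatches from the global/micro batch sizes
        # and the data-parallel size before any pipeline schedule runs.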
        setup_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            args.data_parallel_size,
        )
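        # Create the tensor-, pipeline-, and data-parallel process groups;
        # NCCL is the default backend, while pipeline point-to-point
        # communication uses the backend under test (NCCL or UCC).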
        parallel_state.initialize_model_parallel(
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            default_backend="nccl",
            p2p_backend=self.DISTRIBUTED_BACKEND,
        )
        tensor_parallel.random.model_parallel_cuda_manual_seed(0)
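
        # build_model returns a list of model chunks: a single entry normally,
        # or one per virtual pipeline stage when interleaving is enabled; the
        # chunks are wrapped in DDP only when the data-parallel size exceeds 1.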
        model = build_model(
            bert_model_provider,
            wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            cpu_offload=args.cpu_offload,
        )
        assert isinstance(model, list)
        assert len(model) == (
            1
            if virtual_pipeline_model_parallel_size is None
            else virtual_pipeline_model_parallel_size
        )
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        self._train(
            model,
            optim,
            virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_size,
            async_comm,
        )
        torch.cuda.synchronize()
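

# Concrete test cases: the same BERT scenario run over the NCCL and (when
# PyTorch is built with UCC support) UCC backends, with world size capped at 8 GPUs.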
class NcclBertTest(BertTestBase, NcclDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)


@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc")
class UccBertTest(BertTestBase, UccDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)


if __name__ == "__main__":
    common_utils.run_tests()