# test_bert_minimal.py

import logging
import unittest

import torch
from torch.testing._internal import common_utils

from apex.transformer import tensor_parallel, parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.log_util import set_logging_level
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
    build_model,
)
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
    unwrap_model,
    setup_microbatch_calculator,
)
from apex.transformer.testing import global_vars
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.testing.distributed_test_base import (
    NcclDistributedTestBase,
    UccDistributedTestBase,
)
from apex.transformer._ucc_util import HAS_UCC

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
set_logging_level("WARNING")
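
# Shared test logic for both backends defined below: build a tiny synthetic
# character-level dataset, construct a pipeline-parallel BERT via build_model(),
# run a handful of training iterations, and check that the averaged LM loss stays bounded.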
class BertTestBase:
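    # Produce a long 1-D tensor of ASCII codes by repeating a fixed sentence;
    # this stands in for a real dataset so the test needs no external downloads.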
    def _download_fancy_data(self):
        text = """
An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
"""
        text = text * 1024
        encoded = text.encode("ascii", "replace")
        ints = [int(encoded[i]) for i in range(len(encoded))]
        return torch.tensor(ints)
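
    # Masking detail: positions where the random draw falls below MASK_PROB are replaced
    # with token id 124 in the input while the label keeps the original character; the
    # finished batch is broadcast from tensor-parallel rank 0 so all ranks see identical data.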
    # build a batch given sequence_len and batch size
    def _generate_fancy_data_labels(self, sequence_len, batch_size):
        temps = []
        for i in range(batch_size):
            if self.inds is None or self.data_idx >= len(self.inds):
                # hack as use of RNG will fall out of sync due to pipelines being different
                torch.manual_seed(self.MANUAL_SEED)
                self.inds = torch.randperm(self.effective_length, device="cuda")
                self.masks = (
                    torch.rand(
                        len(self.inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                    )
                    >= self.MASK_PROB
                ).long()
                self.MANUAL_SEED += 1
                self.data_idx = 0
                if self.rank == 0:
                    print("new epoch", len(self.inds))
                    print("my start", self.inds[0:5])
                    print("masks_checksum:", torch.sum(self.masks))
            if self.EASY_MODE:
                data_idx_ = self.data_idx % self.EASY_MODE_SIZ
            else:
                data_idx_ = self.data_idx
            offset = self.inds[data_idx_]  # * SEQUENCE_LEN
            self.data_idx += 1
            curr = self.fancy_data[offset : offset + sequence_len].clone().detach()
            temps.append(curr)
        temp = torch.stack(temps, dim=0).cuda()
        mask = self.masks[self.data_idx // batch_size]
        mask_not = torch.logical_not(mask).long()
        data = mask * temp + mask_not * 124
        label = temp
        if parallel_state.get_tensor_model_parallel_rank() == 0:
            data_dict = {"text": data, "label": label, "mask_not": mask_not}
        else:
            data_dict = None
        keys = ["text", "label", "mask_not"]
        broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
        return (
            broadcasted_data["text"].long(),
            broadcasted_data["label"].long(),
            broadcasted_data["mask_not"],
        )
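
    # Forward step used by the pipeline schedule: runs the model on a batch and returns
    # the output together with a loss closure that computes the masked LM loss and
    # averages it across the data-parallel group.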
    def _fwd_step_func(self, batch, model):
        data, label, loss_mask = batch
        y = model(data, torch.ones_like(data), lm_labels=label)

        def loss_func(output_tensor):
            output_tensor, _ = output_tensor
            lm_loss_ = output_tensor.float()
            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
            averaged_loss = average_losses_across_data_parallel_group([lm_loss])
            if self.data_idx >= 1536:
                # NOTE (patwang): Loss cutoff might be excessively high, but roughly one in five
                # unlucky random seeds causes the loss to spike to just under 8.0.
                self.assertLess(averaged_loss, 8.0)
            return lm_loss, {"avg": averaged_loss}

        return y, loss_func
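
    # Run 16 iterations: each builds a fresh batch, runs the chosen forward/backward
    # pipeline schedule, optionally all-reduces layer-norm gradients across the
    # tensor-parallel group (sequence parallelism), and steps the optimizer.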
    def _train(
        self, model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm
    ):
        args = global_vars.get_args()
        sequence_len = args.seq_length
        micro_batch_size = args.micro_batch_size
        hidden_size = args.hidden_size
        global_batch_size = args.global_batch_size
        forward_backward_func = get_forward_backward_func(
            virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
        )
        tensor_shape = (sequence_len, micro_batch_size, hidden_size)
        for _ in range(16):
            batch = self._generate_fancy_data_labels(sequence_len, global_batch_size)
            optim.zero_grad()
            forward_backward_func(
                self._fwd_step_func,
                batch,
                model,
                forward_only=False,
                tensor_shape=tensor_shape,
                async_comm=async_comm,
                sequence_parallel_enabled=args.sequence_parallel,
            )
            # All-reduce layer-norm gradients across model-parallel ranks
            # when sequence parallelism is used.
            if parallel_state.get_tensor_model_parallel_world_size() > 1 and args.sequence_parallel:
                for model_module in model:
                    unwrapped_model = unwrap_model(model_module)
                    for param in unwrapped_model.parameters():
                        if getattr(param, "sequence_parallel_enabled", False):
                            grad = param.grad
                            torch.distributed.all_reduce(
                                grad, group=parallel_state.get_tensor_model_parallel_group()
                            )
            optim.step()

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_bert_without_interleaving(self):
        self._test_bert(virtual_pipeline_model_parallel_size=None)

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_bert_with_interleaving(self):
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("skip interleaving with ucc")
        self._test_bert(virtual_pipeline_model_parallel_size=2)
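
    # Shared driver: pick tensor/pipeline parallel sizes from the world size, override the
    # global args with a small BERT config, initialize model parallelism, build the model,
    # and train it for a few iterations with the requested (virtual) pipeline size.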
    def _test_bert(self, virtual_pipeline_model_parallel_size):
        self.MANUAL_SEED = 42
        self.inds = None
        self.masks = None
        self.data_idx = 0
        self.MASK_PROB = 0.1
        self.EASY_MODE = False
        self.EASY_MODE_SIZ = 32

        tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size > 4 else 1
        pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size
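
        # A deliberately small BERT configuration so the test fits on a handful of GPUs
        # while still exercising the parallel schedules.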
        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "bert_binary_head": False,
            "world_size": self.world_size,
            "rank": self.rank,
        }

        global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
        args = global_vars.get_args()

        self.fancy_data = self._download_fancy_data()
        # Any start offset up to (dataset length - seq_length) yields a full sequence.
        self.effective_length = self.fancy_data.size(0) - args.seq_length

        if self.rank == 0:
            print(
                f"testing backend: {self.DISTRIBUTED_BACKEND} with "
                f"virtual_pipeline_model_parallel_size: {virtual_pipeline_model_parallel_size}"
            )
        async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
        self.data_idx = 0
        args.padded_vocab_size = 128  # needed by the standalone BERT model
        args.model_type = ModelType.encoder_or_decoder
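
        # Model-parallel groups must exist before build_model() is called; the micro-batch
        # calculator is consulted later by the pipeline schedule.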
        setup_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            args.data_parallel_size,
        )
        parallel_state.initialize_model_parallel(
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            default_backend="nccl",
            p2p_backend=self.DISTRIBUTED_BACKEND,
        )
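
        # build_model() returns a list of modules: one entry per virtual pipeline stage,
        # or a single entry when interleaving is disabled.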
        tensor_parallel.random.model_parallel_cuda_manual_seed(0)
        model = build_model(
            bert_model_provider,
            wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            cpu_offload=args.cpu_offload,
        )
        assert isinstance(model, list)
        assert len(model) == (
            1
            if virtual_pipeline_model_parallel_size is None
            else virtual_pipeline_model_parallel_size
        )
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        self._train(
            model,
            optim,
            virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_size,
            async_comm,
        )
        torch.cuda.synchronize()
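
# Concrete test classes: the same scenarios run once with the NCCL backend and once
# with UCC (the latter only when PyTorch was built with native UCC support).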
class NcclBertTest(BertTestBase, NcclDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)


@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc")
class UccBertTest(BertTestBase, UccDistributedTestBase):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)


if __name__ == "__main__":
    common_utils.run_tests()